You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
448 lines
13 KiB
448 lines
13 KiB
4 years ago
|
#include <SQLiteCpp/SQLiteCpp.h>
#include <libtelegram/libtelegram.h>
#include <boost/algorithm/string.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <locale>
#include <random>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
|
||
|
|
||
|
/**
 * Read the Telegram API token from "apikey.txt" in the current working
 * directory.
 *
 * The token is the first whitespace-delimited word of the file.  If the file
 * cannot be opened, or holds no token at all, an error is printed and the
 * process exits with status 1 instead of handing an empty key to the caller.
 *
 * @return the API token read from the file.
 */
std::string get_token() {
    // Read-only access is all that is needed here.
    std::ifstream token_file("apikey.txt");
    if (!token_file) {
        std::perror("No apikey.txt file provided");
        std::exit(1);
    }

    std::string apikey;
    if (!(token_file >> apikey)) {
        // The file exists but is empty or whitespace-only; errno carries no
        // useful information here, so report it directly.
        std::cerr << "apikey.txt does not contain an API token" << std::endl;
        std::exit(1);
    }
    return apikey;
}
|
||
|
|
||
|
class SQLiteMarkov {
|
||
|
public:
|
||
|
int order;
|
||
|
private:
|
||
|
std::string db_filename;
|
||
|
SQLite::Database data;
|
||
|
|
||
|
const int WORD_SIZE = 400;
|
||
|
const int MAX_REPLY_LENGTH = 500;
|
||
|
const std::string INSERT_STATEMENT;
|
||
|
const std::string UPDATE_STATEMENT;
|
||
|
const std::string QUERY_STATEMENT;
|
||
|
|
||
|
std::string get_filepath(std::string fp) {
|
||
|
std::string filePath(__FILE__);
|
||
|
return filePath.substr( 0, filePath.length() - std::string("bot.cpp").length())
|
||
|
+ fp;
|
||
|
}
|
||
|
|
||
|
SQLite::Database open_db() {
|
||
|
try {
|
||
|
return SQLite::Database (db_filename, SQLite::OPEN_READWRITE);
|
||
|
} catch (std::exception& e)
|
||
|
{
|
||
|
std::cout << "exception: " << e.what() << std::endl;
|
||
|
}
|
||
|
|
||
|
std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
|
||
|
|
||
|
return SQLite::Database (db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE);
|
||
|
}
|
||
|
|
||
|
SQLite::Database Initialise_Data() {
|
||
|
SQLite::Database db = open_db();
|
||
|
|
||
|
SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'");
|
||
|
try
|
||
|
{
|
||
|
if (query.executeStep()) {
|
||
|
std::cout << "database ready." << std::endl;
|
||
|
return db;
|
||
|
}
|
||
|
}
|
||
|
catch (std::exception& e)
|
||
|
{
|
||
|
std::cout << "exception: " << e.what() << std::endl;
|
||
|
}
|
||
|
|
||
|
// create table sql statement construction
|
||
|
std::string ins = "CREATE TABLE markov (\n";
|
||
|
for (int i = 0; i < order + 1; i ++) {
|
||
|
ins += "word_" + std::to_string(i) + " VARCHAR(" +
|
||
|
std::to_string(WORD_SIZE) + ") NOT NULL,\n";
|
||
|
}
|
||
|
ins += "count INTEGER NOT NULL, \n PRIMARY KEY (";
|
||
|
for (int i = 0; i < order + 1; i ++) {
|
||
|
ins += "word_" + std::to_string(i);
|
||
|
if (i != order) {
|
||
|
ins += ",";
|
||
|
}
|
||
|
}
|
||
|
ins += " )\n);";
|
||
|
|
||
|
// run create table instruction
|
||
|
SQLite::Transaction transaction(db);
|
||
|
db.exec(ins);
|
||
|
transaction.commit();
|
||
|
|
||
|
return db;
|
||
|
}
|
||
|
|
||
|
std::string get_update_statement_string() {
|
||
|
std::string update_template = "UPDATE markov SET count = ? WHERE\n";
|
||
|
|
||
|
for (int i = 0; i < order + 1; i++) {
|
||
|
update_template += "word_" + std::to_string(i);
|
||
|
update_template += " = ? ";
|
||
|
if (i != order) {
|
||
|
update_template += "AND ";
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return update_template;
|
||
|
}
|
||
|
|
||
|
std::string get_select_statement_string() {
|
||
|
std::string update_template = "SELECT * FROM markov WHERE " ;
|
||
|
|
||
|
for (int i = 0; i < order; i++) {
|
||
|
update_template += "word_" + std::to_string(i);
|
||
|
update_template += " = ? ";
|
||
|
if (i != order - 1) {
|
||
|
update_template += "AND ";
|
||
|
}
|
||
|
}
|
||
|
|
||
|
update_template += ";";
|
||
|
|
||
|
return update_template;
|
||
|
}
|
||
|
|
||
|
std::string get_insert_statement_string() {
|
||
|
std::string update_template = "INSERT INTO markov VALUES (" ;
|
||
|
|
||
|
for (int i = 0; i < order + 1; i++) {
|
||
|
update_template += "?" ;
|
||
|
if (i != order) {
|
||
|
update_template += ", ";
|
||
|
} else {
|
||
|
update_template += ", ?)";
|
||
|
}
|
||
|
}
|
||
|
return update_template;
|
||
|
}
|
||
|
|
||
|
bool update_count(std::vector<std::string> words, int count) {
|
||
|
std::string update_template = UPDATE_STATEMENT;
|
||
|
|
||
|
|
||
|
SQLite::Statement increment(data, update_template);
|
||
|
|
||
|
increment.bind(1, count);
|
||
|
|
||
|
for (int i = 0; i < order + 1; i++) {
|
||
|
std::cout << i << ": " << words[i];
|
||
|
increment.bind(i + 2, words[i]);
|
||
|
}
|
||
|
|
||
|
return increment.executeStep();
|
||
|
}
|
||
|
|
||
|
bool insert_record(std::vector<std::string> words, int count) {
|
||
|
|
||
|
std::string update_template = INSERT_STATEMENT;
|
||
|
|
||
|
|
||
|
SQLite::Statement increment(data, update_template);
|
||
|
for (int i = 0; i < order + 1; i++) {
|
||
|
std::cout << i << ": " << words[i];
|
||
|
increment.bind(i + 1, words[i]);
|
||
|
}
|
||
|
|
||
|
increment.bind(order + 2, count);
|
||
|
|
||
|
return increment.executeStep();
|
||
|
}
|
||
|
|
||
|
|
||
|
bool add_to_db(std::vector<std::string> words) {
|
||
|
|
||
|
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
|
||
|
|
||
|
std::string query_template = "SELECT count FROM markov WHERE ";
|
||
|
for (int i = 0; i < order + 1; i++) {
|
||
|
query_template += "word_" + std::to_string(i);
|
||
|
query_template += " = ? ";
|
||
|
if (i != order) {
|
||
|
query_template += "AND ";
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
SQLite::Statement update(data, query_template);
|
||
|
|
||
|
for (int i = 1; i < order + 2; i++) {
|
||
|
update.bind(i, words[i-1]);
|
||
|
}
|
||
|
|
||
|
int count = 0;
|
||
|
if (update.executeStep()) {
|
||
|
count = update.getColumn(0);
|
||
|
count += 1;
|
||
|
update_count(words, count);
|
||
|
} else {
|
||
|
count = 1;
|
||
|
insert_record(words, count);
|
||
|
}
|
||
|
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
std::string get_next_word(std::vector<std::string> words) {
|
||
|
if (words.size() != order) {
|
||
|
throw "Invalid prompt vector size";
|
||
|
return "";
|
||
|
}
|
||
|
|
||
|
SQLite::Statement query(data, QUERY_STATEMENT);
|
||
|
|
||
|
for (int i = 0; i < order; i++) {
|
||
|
query.bind(i+1, words[i]);
|
||
|
}
|
||
|
|
||
|
int total = 0;
|
||
|
while (query.executeStep()) {
|
||
|
int count = query.getColumn(order + 1);
|
||
|
total += count;
|
||
|
}
|
||
|
|
||
|
total += 1;
|
||
|
int threshold = rand() % total;
|
||
|
|
||
|
// count up to threshold
|
||
|
// pretty sure the order doesn't matter statistically?
|
||
|
// Obviously it will have an impact but I think the distribution should
|
||
|
// be the same whether the list is ordered by frequency or not
|
||
|
|
||
|
query.reset();
|
||
|
total = 0;
|
||
|
|
||
|
std::vector<std::string> message;
|
||
|
|
||
|
while (query.executeStep()) {
|
||
|
int count = query.getColumn(order + 1);
|
||
|
total += count;
|
||
|
if (total >= threshold) {
|
||
|
std::string next_word = query.getColumn(order);
|
||
|
return next_word;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
std::cout << "No matches found" << std::endl;
|
||
|
throw "No next word found";
|
||
|
return "";
|
||
|
}
|
||
|
|
||
|
|
||
|
public:
|
||
|
|
||
|
SQLiteMarkov(int ord, std::string dbname): order(ord),
|
||
|
db_filename(get_filepath(dbname)),
|
||
|
data(Initialise_Data()),
|
||
|
INSERT_STATEMENT(get_insert_statement_string()),
|
||
|
QUERY_STATEMENT(get_select_statement_string()),
|
||
|
UPDATE_STATEMENT(get_update_statement_string())
|
||
|
{;};
|
||
|
|
||
|
SQLiteMarkov(int ord): order(ord),
|
||
|
db_filename(get_filepath("markov.sqlite")),
|
||
|
data(Initialise_Data()),
|
||
|
INSERT_STATEMENT(get_insert_statement_string()),
|
||
|
QUERY_STATEMENT(get_select_statement_string()),
|
||
|
UPDATE_STATEMENT(get_update_statement_string())
|
||
|
{;};
|
||
|
|
||
|
SQLiteMarkov(): order(1),
|
||
|
db_filename(get_filepath("markov.sqlite")),
|
||
|
data(Initialise_Data()),
|
||
|
INSERT_STATEMENT(get_insert_statement_string()),
|
||
|
QUERY_STATEMENT(get_select_statement_string()),
|
||
|
UPDATE_STATEMENT(get_update_statement_string())
|
||
|
{;};
|
||
|
|
||
|
SQLiteMarkov(std::string dbname): order(1),
|
||
|
db_filename(get_filepath(dbname)),
|
||
|
data(Initialise_Data()),
|
||
|
INSERT_STATEMENT(get_insert_statement_string()),
|
||
|
QUERY_STATEMENT(get_select_statement_string()),
|
||
|
UPDATE_STATEMENT(get_update_statement_string())
|
||
|
{;};
|
||
|
|
||
|
bool add_ngrams(std::vector<std::string> words) {
|
||
|
// assertions?
|
||
|
if (words.size() > order + 1) {
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
for (int i = 0; i < order + 1; i++) {
|
||
|
if (words[i].length() > SQLiteMarkov::WORD_SIZE) {
|
||
|
return true;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return add_to_db(words);
|
||
|
|
||
|
}
|
||
|
|
||
|
std::string get_continuation(std::vector<std::string> prompt) {
|
||
|
std::string new_words = "";
|
||
|
std::vector<std::string> words(prompt);
|
||
|
int count = 0;
|
||
|
|
||
|
do {
|
||
|
std::string next_word = "invalid";
|
||
|
|
||
|
try {
|
||
|
next_word = get_next_word(words);
|
||
|
} catch (std::exception &e) {
|
||
|
std::cout << "exception: " << e.what() << std::endl;
|
||
|
std::cout << "Failed Get Next Word" << std::endl;
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
if (next_word == "MESSAGE_END") {
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
new_words += " " + next_word;
|
||
|
std::regex end_punc("[\\.\\?\\!]");
|
||
|
|
||
|
if (std::regex_search(std::to_string(next_word.back()), end_punc)) {
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
count++;
|
||
|
words.push_back(next_word);
|
||
|
words.erase(words.begin());
|
||
|
|
||
|
} while (count < MAX_REPLY_LENGTH);
|
||
|
|
||
|
return new_words;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
|
||
|
class MarkovHandler {
|
||
|
|
||
|
private:
|
||
|
SQLiteMarkov markov;
|
||
|
|
||
|
const bool MAKE_LOWER = true;
|
||
|
|
||
|
public:
|
||
|
|
||
|
bool test() {
|
||
|
std::vector<std::string> words {"so", "long"};
|
||
|
add_ngrams("so long and thanks\nfor\tall the-fishes bro man");
|
||
|
//markov.add_ngrams(words);
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
bool process_string(std::string& message) {
|
||
|
if (MAKE_LOWER) {
|
||
|
for (int i = 0; i < message.length(); i++) {
|
||
|
if (std::isupper(message[i])) {
|
||
|
message[i] = std::tolower(message[i], std::locale());
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
message += " MESSAGE_END";
|
||
|
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
std::vector<std::string> split_string(std::string message) {
|
||
|
|
||
|
std::istringstream ss(message);
|
||
|
|
||
|
std::vector<std::string> words{
|
||
|
std::istream_iterator<std::string>{ss},
|
||
|
std::istream_iterator<std::string>{}
|
||
|
};
|
||
|
|
||
|
return words;
|
||
|
}
|
||
|
|
||
|
bool add_ngrams(std::string message) {
|
||
|
process_string(message);
|
||
|
std::vector<std::string> words (split_string(message));
|
||
|
|
||
|
for (int i = 0; i < words.size(); i++) {
|
||
|
if (i >= markov.order) {
|
||
|
std::vector<std::string> add_words;
|
||
|
|
||
|
int count = 0;
|
||
|
for (int j = i - markov.order; j <= i; j++, count++) {
|
||
|
add_words.push_back(words[j]);
|
||
|
std::cout << j << ": " << add_words[count] << std::endl;
|
||
|
}
|
||
|
|
||
|
markov.add_ngrams(add_words);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
std::string continue_message(std::string message) {
|
||
|
std::vector<std::string> words {split_string(message)};
|
||
|
words.erase(words.begin(), words.begin() + words.size() - markov.order);
|
||
|
|
||
|
return markov.get_continuation(words);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
class Bot {
|
||
|
private:
|
||
|
std::string apikey;
|
||
|
telegram::sender sender;
|
||
|
MarkovHandler markov;
|
||
|
|
||
|
public:
|
||
|
telegram::listener::poll listener;
|
||
|
|
||
|
Bot() : apikey(get_token()), sender(apikey), listener(sender) {
|
||
|
echo();
|
||
|
};
|
||
|
|
||
|
Bot(std::string key) : apikey(key), sender(apikey), listener(sender) {
|
||
|
echo();
|
||
|
};
|
||
|
|
||
|
void echo() {
|
||
|
listener.set_callback_message([&](telegram::types::message const &message){
|
||
|
|
||
|
std::string message_text = message.text.value();
|
||
|
markov.add_ngrams(message_text);
|
||
|
|
||
|
std::string reply = markov.continue_message(message_text);
|
||
|
sender.send_message(message.chat.id, reply);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
void run() {
|
||
|
listener.run();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// Program entry point: construct the bot with its default constructor, run
// it, and report success once run() returns.
int main() {
    Bot bot;
    bot.run();
    return EXIT_SUCCESS;
}
|