alistair
4 years ago
6 changed files with 487 additions and 419 deletions
@ -1,3 +1,6 @@
@@ -1,3 +1,6 @@
|
||||
[submodule "libtelegram"] |
||||
path = libtelegram |
||||
url = https://github.com/slowriot/libtelegram.git |
||||
[submodule "SQLiteCpp"] |
||||
path = SQLiteCpp |
||||
url = https://github.com/SRombauts/SQLiteCpp.git |
||||
|
@ -0,0 +1,78 @@
@@ -0,0 +1,78 @@
|
||||
#ifndef BOT_H |
||||
#define BOT_H |
||||
|
||||
#include <boost/algorithm/string.hpp> |
||||
#include <sstream> |
||||
#include <algorithm> |
||||
#include <iterator> |
||||
#include <regex> |
||||
#include <string> |
||||
#include <vector> |
||||
#include <libtelegram/libtelegram.h> |
||||
|
||||
/**
 * Abstract storage interface for n-gram data.
 *
 * Implementations decide how the n-grams are persisted; this type only
 * fixes the two operations callers may perform on a store.
 */
class MarkovDB {
public:
  /// Virtual so implementations can be destroyed through a MarkovDB pointer.
  virtual ~MarkovDB() = default;

  /**
   * Record one n-gram.
   * @param words the words making up the n-gram, in order.
   * @return implementation-defined status flag.
   */
  virtual bool add_ngrams(std::vector<std::string> words) = 0;

  /**
   * Produce text that continues the given prompt.
   * @param prompt the words to continue from.
   * @param score  output parameter written by the implementation.
   * @return the generated continuation.
   */
  virtual std::string get_continuation(std::vector<std::string> prompt, int &score) = 0;
};
||||
|
||||
class MarkovChain { |
||||
protected: |
||||
MarkovDB &markov; |
||||
|
||||
int order = 1; |
||||
bool MAKE_LOWER = false; |
||||
|
||||
bool process_string(std::string& message);
|
||||
std::vector<std::string> split_string(std::string message);
|
||||
bool test();
|
||||
|
||||
public: |
||||
bool add_ngrams(std::string message);
|
||||
|
||||
std::string continue_message(std::string message, int &score);
|
||||
|
||||
MarkovChain (int ord, bool lower, MarkovDB &db)
|
||||
: markov(db), MAKE_LOWER(lower), order(ord) {; |
||||
}; |
||||
|
||||
virtual ~MarkovChain() {}; |
||||
}; |
||||
|
||||
class Bot { |
||||
protected: |
||||
MarkovChain markov; |
||||
int replies = 200; |
||||
int SCORE_THRESHOLD = 100; |
||||
public: |
||||
virtual void run() = 0; |
||||
Bot(int replies, int score_threshold, MarkovChain m) |
||||
: markov(m), replies(replies), SCORE_THRESHOLD(score_threshold) {}; |
||||
~Bot() {}; |
||||
}; |
||||
|
||||
class TelegramBotM: public Bot { |
||||
protected: |
||||
std::string api_key = ""; |
||||
telegram::sender sender; |
||||
telegram::listener::poll listener; |
||||
std::string chain_name = "markov"; |
||||
std::string get_token();
|
||||
void add_echo();
|
||||
|
||||
public:
|
||||
TelegramBotM(int order, MarkovChain &m, std::string chain_name) :
|
||||
api_key(get_token()), sender(get_token()), listener(sender),
|
||||
chain_name(chain_name), Bot(200, 100, m) { |
||||
add_echo(); |
||||
};
|
||||
|
||||
void run() override { |
||||
listener.run(); |
||||
} |
||||
|
||||
~TelegramBotM() {}; |
||||
}; |
||||
|
||||
#endif |
@ -0,0 +1,317 @@
@@ -0,0 +1,317 @@
|
||||
#include "bot.h"

#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <SQLiteCpp/SQLiteCpp.h>
||||
|
||||
class SQLiteMarkov: public MarkovDB { |
||||
protected: |
||||
const std::string db_filename; |
||||
const int WORD_SIZE = 400; // chars
|
||||
const int MAX_REPLY_LENGTH = 500; // words
|
||||
const std::string INSERT_STATEMENT; |
||||
const std::string UPDATE_STATEMENT; |
||||
const std::string QUERY_STATEMENT; |
||||
const std::string TABLE_NAME; |
||||
|
||||
int order; |
||||
SQLite::Database data; |
||||
|
||||
std::string get_filepath(std::string fp) { |
||||
std::string filePath(__FILE__); |
||||
return filePath.substr( 0, filePath.length() - std::string("bot.cpp").length())
|
||||
+ fp; |
||||
} |
||||
|
||||
SQLite::Database open_db() { |
||||
try { |
||||
return SQLite::Database (db_filename, SQLite::OPEN_READWRITE); |
||||
} catch (std::exception& e)
|
||||
{ |
||||
std::cout << "exception: " << e.what() << std::endl << std::flush; |
||||
} |
||||
|
||||
std::cout << "Creating new db: '" << db_filename << "'" << std::endl; |
||||
|
||||
return SQLite::Database (db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE); |
||||
} |
||||
|
||||
SQLite::Database Initialise_Data() { |
||||
SQLite::Database db = open_db(); |
||||
|
||||
SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'"); |
||||
try
|
||||
{ |
||||
if (query.executeStep()) { |
||||
std::cout << "database ready." << std::endl; |
||||
return db; |
||||
} |
||||
} |
||||
catch (std::exception& e) |
||||
{ |
||||
std::cout << "exception: " << e.what() << std::endl; |
||||
} |
||||
|
||||
// create table sql statement construction
|
||||
std::string ins = "CREATE TABLE markov (\n"; |
||||
for (int i = 0; i < order + 1; i ++) { |
||||
ins += "word_" + std::to_string(i) + " VARCHAR(" +
|
||||
std::to_string(WORD_SIZE) + ") NOT NULL,\n"; |
||||
} |
||||
ins += "count INTEGER NOT NULL, \n PRIMARY KEY ("; |
||||
for (int i = 0; i < order + 1; i ++) { |
||||
ins += "word_" + std::to_string(i); |
||||
if (i != order) { |
||||
ins += ","; |
||||
} |
||||
} |
||||
ins += " )\n);"; |
||||
|
||||
// run create table instruction
|
||||
SQLite::Transaction transaction(db); |
||||
db.exec(ins); |
||||
transaction.commit(); |
||||
|
||||
return db; |
||||
} |
||||
|
||||
std::string get_update_statement_string() { |
||||
std::string update_template = "UPDATE markov SET count = ? WHERE\n"; |
||||
|
||||
for (int i = 0; i < order + 1; i++) { |
||||
update_template += "word_" + std::to_string(i); |
||||
update_template += " = ? "; |
||||
if (i != order) { |
||||
update_template += "AND "; |
||||
} |
||||
} |
||||
|
||||
return update_template; |
||||
} |
||||
|
||||
std::string get_select_statement_string() { |
||||
std::string update_template = "SELECT * FROM markov WHERE " ; |
||||
|
||||
for (int i = 0; i < order; i++) { |
||||
update_template += "word_" + std::to_string(i); |
||||
update_template += " = ? "; |
||||
if (i != order - 1) { |
||||
update_template += "AND "; |
||||
} |
||||
} |
||||
|
||||
update_template += ";"; |
||||
|
||||
return update_template; |
||||
} |
||||
|
||||
std::string get_insert_statement_string() { |
||||
std::string update_template = "INSERT INTO markov VALUES (" ; |
||||
|
||||
for (int i = 0; i < order + 1; i++) { |
||||
update_template += "?" ; |
||||
if (i != order) { |
||||
update_template += ", "; |
||||
} else { |
||||
update_template += ", ?)"; |
||||
} |
||||
} |
||||
return update_template; |
||||
} |
||||
|
||||
bool update_count(std::vector<std::string> words, int count) { |
||||
std::string update_template = UPDATE_STATEMENT; |
||||
|
||||
|
||||
SQLite::Statement increment(data, update_template); |
||||
|
||||
increment.bind(1, count); |
||||
|
||||
for (int i = 0; i < order + 1; i++) { |
||||
increment.bind(i + 2, words[i]); |
||||
} |
||||
|
||||
return increment.executeStep(); |
||||
} |
||||
|
||||
bool insert_record(std::vector<std::string> words, int count) { |
||||
|
||||
std::string update_template = INSERT_STATEMENT; |
||||
|
||||
SQLite::Statement increment(data, update_template); |
||||
for (int i = 0; i < order + 1; i++) { |
||||
increment.bind(i + 1, words[i]); |
||||
} |
||||
|
||||
increment.bind(order + 2, count); |
||||
|
||||
return increment.executeStep(); |
||||
} |
||||
|
||||
bool add_to_db(std::vector<std::string> words) { |
||||
|
||||
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
|
||||
|
||||
std::string query_template = "SELECT count FROM markov WHERE "; |
||||
for (int i = 0; i < order + 1; i++) { |
||||
query_template += "word_" + std::to_string(i); |
||||
query_template += " = ? "; |
||||
if (i != order) { |
||||
query_template += "AND "; |
||||
} |
||||
} |
||||
|
||||
SQLite::Statement update(data, query_template); |
||||
|
||||
for (int i = 1; i < order + 2; i++) { |
||||
update.bind(i, words[i-1]); |
||||
} |
||||
|
||||
int count = 0; |
||||
if (update.executeStep()) { |
||||
count = update.getColumn(0);
|
||||
count += 1; |
||||
update_count(words, count); |
||||
} else { |
||||
count = 1; |
||||
insert_record(words, count); |
||||
} |
||||
|
||||
return false; |
||||
} |
||||
|
||||
std::string get_next_word(std::vector<std::string> words, int &score) { |
||||
if (words.size() != order) { |
||||
throw "Invalid prompt vector size"; |
||||
return ""; |
||||
} |
||||
|
||||
SQLite::Statement query(data, QUERY_STATEMENT); |
||||
|
||||
for (int i = 0; i < order; i++) { |
||||
query.bind(i+1, words[i]); |
||||
} |
||||
|
||||
int total = 0; |
||||
while (query.executeStep()) { |
||||
int count = query.getColumn(order + 1); |
||||
total += count; |
||||
} |
||||
|
||||
total += 1; |
||||
int threshold = rand() % total; |
||||
|
||||
// count up to threshold
|
||||
// pretty sure the order doesn't matter statistically?
|
||||
// Obviously it will have an impact but I think the distribution should
|
||||
// be the same whether the list is ordered by frequency or not
|
||||
|
||||
query.reset(); |
||||
total = 0; |
||||
|
||||
std::vector<std::string> message; |
||||
|
||||
while (query.executeStep()) { |
||||
int count = query.getColumn(order + 1); |
||||
total += count; |
||||
if (total >= threshold) { |
||||
std::string next_word = query.getColumn(order); |
||||
|
||||
score = total; |
||||
return next_word; |
||||
} |
||||
} |
||||
|
||||
std::cout << "No matches found" << std::endl; |
||||
throw "No next word found"; |
||||
return ""; |
||||
} |
||||
|
||||
std::string get_next_word(std::vector<std::string> words) { |
||||
int count; |
||||
return get_next_word(words, count); |
||||
} |
||||
|
||||
public: |
||||
|
||||
SQLiteMarkov(int ord, std::string dbname, std::string table): order(ord), |
||||
db_filename(get_filepath(dbname)),
|
||||
data(Initialise_Data()), |
||||
TABLE_NAME(table), |
||||
INSERT_STATEMENT(get_insert_statement_string()), |
||||
QUERY_STATEMENT(get_select_statement_string()), |
||||
UPDATE_STATEMENT(get_update_statement_string()) |
||||
{;}; |
||||
|
||||
SQLiteMarkov(int ord): order(ord),
|
||||
db_filename(get_filepath("markov.sqlite")),
|
||||
data(Initialise_Data()), |
||||
TABLE_NAME("markov"), |
||||
INSERT_STATEMENT(get_insert_statement_string()), |
||||
QUERY_STATEMENT(get_select_statement_string()), |
||||
UPDATE_STATEMENT(get_update_statement_string()) |
||||
{;}; |
||||
|
||||
SQLiteMarkov(): order(1),
|
||||
db_filename(get_filepath("markov.sqlite")),
|
||||
data(Initialise_Data()), |
||||
TABLE_NAME("markov"), |
||||
INSERT_STATEMENT(get_insert_statement_string()), |
||||
QUERY_STATEMENT(get_select_statement_string()), |
||||
UPDATE_STATEMENT(get_update_statement_string()) |
||||
{;}; |
||||
|
||||
|
||||
bool add_ngrams(std::vector<std::string> words) override { |
||||
// assertions?
|
||||
if (words.size() > order + 1) { |
||||
return true; |
||||
} |
||||
|
||||
for (int i = 0; i < order + 1; i++) { |
||||
if (words[i].length() > SQLiteMarkov::WORD_SIZE) { |
||||
return true; |
||||
} |
||||
} |
||||
|
||||
return add_to_db(words); |
||||
|
||||
} |
||||
|
||||
std::string get_continuation(std::vector<std::string> prompt, int &score)
|
||||
override { |
||||
std::string new_words = ""; |
||||
std::vector<std::string> words(prompt); |
||||
int count = 0; |
||||
score = 0; |
||||
|
||||
do { |
||||
std::string next_word = "invalid"; |
||||
|
||||
try { |
||||
int total; |
||||
next_word = get_next_word(words, total); |
||||
score += total; |
||||
} catch (std::exception &e) { |
||||
std::cout << "exception: " << e.what() << std::endl;
|
||||
std::cout << "Failed Get Next Word" << std::endl; |
||||
break; |
||||
} |
||||
|
||||
if (next_word == "MESSAGE_END") { |
||||
break; |
||||
} |
||||
|
||||
new_words += " " + next_word; |
||||
std::regex end_punc("[\\.\\?\\!]"); |
||||
|
||||
if (std::regex_search(std::to_string(next_word.back()), end_punc)) { |
||||
break; |
||||
} |
||||
|
||||
count++; |
||||
words.push_back(next_word); |
||||
words.erase(words.begin()); |
||||
} while (count < MAX_REPLY_LENGTH); |
||||
|
||||
return new_words; |
||||
} |
||||
}; |
||||
|
Loading…
Reference in new issue