#include "SQLiteCpp/Database.h"
#include "SQLiteCpp/Transaction.h"

#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "bot.h"

/// Markov-chain store backed by a single SQLite table.
///
/// Each row holds an (order + 1)-gram in columns word_0 .. word_<order> plus an
/// occurrence `count`; the word columns together form the primary key (see
/// Initialise_Data). The first `order` words are the lookup prefix and
/// word_<order> is the continuation that get_next_word samples.
class SQLiteMarkov: public MarkovDB {
protected:
    // NOTE: members are initialised in declaration order. The statement strings
    // read `order` and `TABLE_NAME`, and `data` (opened last) reads `db_filename`,
    // so this ordering must be preserved.
    const std::string db_filename = get_filepath("markov.sqlite");
    const int WORD_SIZE = 400; // chars
    const int MAX_REPLY_LENGTH = 500; // words
    int order = 2;
    const std::string TABLE_NAME = "markov";
    const std::string INSERT_STATEMENT = get_insert_statement_string();
    const std::string UPDATE_STATEMENT = get_update_statement_string();
    const std::string QUERY_STATEMENT = get_select_statement_string();
    const std::string COUNT_STATEMENT = get_count_statement_string();
    std::shared_ptr<SQLite::Database> data;

    /// Resolves `fp` against the directory of this source file (per __FILE__ at
    /// compile time).
    static std::string get_filepath(const std::string& fp) {
        const std::string file_path(__FILE__);
        // Keep everything up to and including the last path separator. With no
        // separator, find_last_of returns npos and npos + 1 wraps to 0, so the
        // result is just `fp`.
        return file_path.substr(0, file_path.find_last_of("/\\") + 1) + fp;
    }

    /// Opens db_filename read/write; if that throws, logs the error and retries
    /// with OPEN_CREATE so a fresh database file is created.
    std::unique_ptr<SQLite::Database> open_db() {
        try {
            return std::unique_ptr<SQLite::Database>(
                new SQLite::Database(db_filename, SQLite::OPEN_READWRITE));
        } catch (const std::exception& e) {
            std::cout << "exception: " << e.what() << std::endl;
        }
        std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
        return std::unique_ptr<SQLite::Database>(
            new SQLite::Database(db_filename, SQLite::OPEN_CREATE | SQLite::OPEN_READWRITE));
    }

    /// Ensures TABLE_NAME exists, creating it when missing. Terminates the
    /// process (exit code 1) if the existence check throws.
    /// @return always 0.
    int Initialise_Data() {
        SQLite::Statement query(*data, "SELECT name from sqlite_master WHERE name = ?");
        query.bind(1, TABLE_NAME);
        try {
            if (query.executeStep()) {
                std::cout << "database ready." << std::endl;
                return 0;
            }
        } catch (const std::exception& e) {
            std::cout << "exception: " << e.what() << std::endl;
            std::exit(1);
        }

        // Identifiers cannot be bound as parameters, so the DDL is built as text.
        std::string ins = "CREATE TABLE " + TABLE_NAME + " (\n";
        for (int i = 0; i <= order; i++) {
            ins += "word_" + std::to_string(i) + " VARCHAR(" + std::to_string(WORD_SIZE) + ") NOT NULL,\n";
        }
        ins += "count INTEGER NOT NULL, \n PRIMARY KEY (";
        for (int i = 0; i <= order; i++) {
            if (i > 0) {
                ins += ",";
            }
            ins += "word_" + std::to_string(i);
        }
        ins += " )\n);";

        SQLite::Transaction transaction(*data);
        data->exec(ins);
        transaction.commit();
        return 0;
    }

    /// Builds "word_0 = ? AND word_1 = ? AND ..." over the first `n` word columns.
    static std::string word_conditions(int n) {
        std::string conditions;
        for (int i = 0; i < n; i++) {
            if (i > 0) {
                conditions += " AND ";
            }
            conditions += "word_" + std::to_string(i) + " = ?";
        }
        return conditions;
    }

    /// UPDATE that sets `count` (parameter 1) on the row matching all
    /// order + 1 word columns (parameters 2 .. order + 2).
    std::string get_update_statement_string() {
        return "UPDATE " + TABLE_NAME + " SET count = ? WHERE\n" + word_conditions(order + 1);
    }

    /// SELECT of every row whose first `order` word columns match parameters
    /// 1 .. order, i.e. all continuations of a prefix. With order == 0 there is
    /// no prefix, so no WHERE clause is emitted.
    std::string get_select_statement_string() {
        std::string statement = "SELECT * FROM " + TABLE_NAME;
        if (order > 0) {
            statement += " WHERE " + word_conditions(order);
        }
        return statement + ";";
    }

    /// SELECT of `count` for the row matching all order + 1 word columns.
    std::string get_count_statement_string() {
        return "SELECT count FROM " + TABLE_NAME + " WHERE " + word_conditions(order + 1);
    }

    /// INSERT with order + 2 placeholders: the order + 1 words, then `count`.
    std::string get_insert_statement_string() {
        std::string statement = "INSERT INTO " + TABLE_NAME + " VALUES (";
        for (int i = 0; i < order + 2; i++) {
            if (i > 0) {
                statement += ", ";
            }
            statement += "?";
        }
        return statement + ")";
    }

    /// Sets `count` on the row identified by words[0 .. order].
    /// @return executeStep()'s result, which is false for an UPDATE (no result row).
    bool update_count(const std::vector<std::string>& words, int count) {
        SQLite::Statement statement(*data, UPDATE_STATEMENT);
        statement.bind(1, count);
        for (int i = 0; i <= order; i++) {
            statement.bind(i + 2, words[i]);
        }
        return statement.executeStep();
    }

    /// Inserts words[0 .. order] with the given `count`.
    /// @return executeStep()'s result, which is false for an INSERT (no result row).
    bool insert_record(const std::vector<std::string>& words, int count) {
        SQLite::Statement statement(*data, INSERT_STATEMENT);
        for (int i = 0; i <= order; i++) {
            statement.bind(i + 1, words[i]);
        }
        statement.bind(order + 2, count);
        return statement.executeStep();
    }

    /// Increments the stored count of an (order + 1)-gram, inserting it with
    /// count 1 when it is not yet present. `words` must hold exactly order + 1
    /// entries (add_ngrams enforces this).
    /// @return always false.
    bool add_to_db(const std::vector<std::string>& words) {
        SQLite::Statement lookup(*data, COUNT_STATEMENT);
        for (int i = 0; i <= order; i++) {
            lookup.bind(i + 1, words[i]);
        }
        if (lookup.executeStep()) {
            const int count = lookup.getColumn(0).getInt() + 1;
            update_count(words, count);
        } else {
            insert_record(words, 1);
        }
        return false;
    }

    /// Samples a continuation of the `order`-word prefix `words`, choosing each
    /// stored continuation with probability count / (sum of counts).
    /// @param score set to the cumulative count at the chosen row, or 0 when
    ///              nothing is returned.
    /// @return the sampled word, or "" if `words` has the wrong length or the
    ///         prefix has no continuations.
    std::string get_next_word(const std::vector<std::string>& words, int &score) {
        score = 0;
        if (words.size() != static_cast<std::size_t>(order)) {
            return "";
        }
        SQLite::Statement query(*data, QUERY_STATEMENT);
        for (int i = 0; i < order; i++) {
            query.bind(i + 1, words[i]);
        }

        // First pass: total frequency of all continuations of this prefix.
        int total = 0;
        while (query.executeStep()) {
            total += query.getColumn(order + 1).getInt();
        }
        if (total <= 0) {
            std::cout << "No matches found" << std::endl;
            return "";
        }

        // Second pass: pick a point uniformly in [0, total) and return the row
        // whose cumulative-count interval contains it. Row order does not affect
        // the distribution. NOTE(review): rand() is capped at RAND_MAX, so
        // totals above that are not sampled uniformly.
        const int threshold = std::rand() % total;
        query.reset();
        int cumulative = 0;
        while (query.executeStep()) {
            cumulative += query.getColumn(order + 1).getInt();
            if (threshold < cumulative) {
                const std::string next_word = query.getColumn(order);
                score = cumulative;
                return next_word;
            }
        }
        // Only reachable if the rows changed between the two passes.
        std::cout << "No matches found" << std::endl;
        return "";
    }

    /// Convenience overload of get_next_word that discards the score.
    std::string get_next_word(const std::vector<std::string>& words) {
        int score = 0;
        return get_next_word(words, score);
    }

    /// True if `c` terminates a sentence ('.', '?' or '!').
    static bool is_sentence_end(char c) {
        return c == '.' || c == '?' || c == '!';
    }

public:
    /// Opens (or creates) `dbname` and stores the chain in `table`.
    SQLiteMarkov(int ord, std::string dbname, std::string table)
        : db_filename(std::move(dbname)), order(ord), TABLE_NAME(std::move(table)), data(open_db()) {
        Initialise_Data();
    }

    /// Uses the default database file and table with the given order.
    SQLiteMarkov(int ord): order(ord), data(open_db()) { Initialise_Data(); }

    /// Uses the default database file, table and order.
    SQLiteMarkov(): data(open_db()) { Initialise_Data(); }

    /// Uses an already-open database. Takes ownership of `db`: it is deleted
    /// when the last shared_ptr copy of `data` is released.
    SQLiteMarkov(int ord, std::string table_name, SQLite::Database *db)
        : order(ord), TABLE_NAME(std::move(table_name)), data(db) {
        Initialise_Data();
    }

    virtual ~SQLiteMarkov() = default;

    /// Records one (order + 1)-gram.
    /// @return true when the n-gram is rejected (wrong number of words, or a
    ///         word longer than WORD_SIZE chars); otherwise add_to_db's result
    ///         (false).
    bool add_ngrams(std::vector<std::string> words) override {
        // Anything other than exactly order + 1 words would be read out of
        // bounds below and in add_to_db.
        if (words.size() != static_cast<std::size_t>(order) + 1) {
            return true;
        }
        for (const std::string& word : words) {
            if (word.length() > static_cast<std::size_t>(WORD_SIZE)) {
                return true;
            }
        }
        return add_to_db(words);
    }

    /// Records every n-gram in `batch` inside a single transaction. Individual
    /// rejections from add_ngrams are ignored.
    /// @return always false.
    bool batch_add_ngrams(std::vector<std::vector<std::string>> batch) override {
        SQLite::Transaction transaction(*data);
        for (const auto& ngram : batch) {
            add_ngrams(ngram);
        }
        transaction.commit();
        return false;
    }

    /// Extends `prompt` word by word until a sampled word ends in '.', '?' or
    /// '!', no continuation is found, or MAX_REPLY_LENGTH words were added.
    /// @param score accumulated per-word scores from get_next_word.
    /// @return " " + prompt words + generated words (space-separated), or ""
    ///         if no word could be generated.
    std::string get_continuation(std::vector<std::string> prompt, int &score) override {
        std::string new_words;
        for (const std::string& word : prompt) {
            new_words += " " + word;
        }
        std::vector<std::string> words(prompt);
        int count = 0;
        score = 0;
        do {
            std::string next_word;
            try {
                int word_score = 0;
                next_word = get_next_word(words, word_score);
                score += word_score;
            } catch (const std::exception& e) {
                // SQLiteCpp reports failures as SQLite::Exception (a std::runtime_error).
                std::cout << "Exception: " << e.what() << std::endl;
                std::cout << "Failed Get Next Word" << std::endl;
                break;
            }
            if (next_word.empty()) {
                break;
            }
            new_words += " " + next_word;
            count++;
            if (is_sentence_end(next_word.back())) {
                break;
            }
            // Slide the prefix window forward by one word.
            words.push_back(next_word);
            words.erase(words.begin());
        } while (count < MAX_REPLY_LENGTH);
        return count > 0 ? new_words : std::string();
    }
};