diff --git a/bot.cpp b/bot.cpp index f4bd08c..499eed9 100644 --- a/bot.cpp +++ b/bot.cpp @@ -27,11 +27,10 @@ bool MarkovChain::process_string(std::string& message) { } } - message += " MESSAGE_END"; - return false; } + std::vector MarkovChain::split_string(std::string message) { std::istringstream ss(message); @@ -43,11 +42,12 @@ std::vector MarkovChain::split_string(std::string message) { return words; } -bool MarkovChain::add_ngrams(std::string message) { +std::vector> MarkovChain::get_words(std::string message) { this->process_string(message); std::vector words (split_string(message)); + std::vector> import; for (int i = 0; i < words.size(); i++) { if (i >= order) { @@ -58,11 +58,11 @@ bool MarkovChain::add_ngrams(std::string message) { add_words.push_back(words[j]); } - markov.add_ngrams(add_words); + import.push_back(add_words); } } - return false; + return import; } @@ -88,6 +88,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) { std::cout << "Using seed: "; for (auto x: newwords) std::cout << " " << x; + std::cout << "\n"; std::string ret = markov.get_continuation(newwords, score); @@ -95,6 +96,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) { return ret; } + std::string TelegramBotM::get_token() { using namespace std; fstream ftoken; @@ -117,8 +119,9 @@ void TelegramBotM::add_echo() { try { markov.add_ngrams(message_text); - } catch (std::exception &e) { - std::cout << "Error adding ngram: " << e.what() << std::endl; + } catch (char const * e) { + std::cout << "Error adding ngram: " << e << std::endl; + return; } //if (replies && rand() % replies != 1) { @@ -129,7 +132,6 @@ void TelegramBotM::add_echo() { std::string reply; reply = markov.continue_message(message_text, score); - std::cout << "Reply generated: " << reply << std::endl; if (reply == "") { return; @@ -148,24 +150,56 @@ void TelegramBotM::add_echo() { }); } -int import_from_file(std::string filename, MarkovChain &m) { + +bool MarkovChain::add_ngrams(std::string message) { + + process_string(message); + std::vector> add_words(get_words(message)); + + if (markov.batch_add_ngrams(add_words)) + throw "Failed to add ngrams"; + + return 0; +} + + +int MarkovChain::import_from_file(std::string filename) { std::ifstream infile(filename); int i = 0; std::string line; + std::vector> imports(0); + while(std::getline(infile, line)) { - m.add_ngrams(line); - std::cout << "\radded " << ++i << " lines."; + + this->process_string(line); + std::vector> add_words(get_words(line)); + + imports.insert(imports.end(), + std::make_move_iterator(add_words.begin()), + std::make_move_iterator(add_words.end())); + + i++; + if (i % 21 == 0) { + std::cout << "\rread " << i << " lines." << std::flush; + } } + std::cout << "\nImporting...\n" << std::flush; + + if (markov.batch_add_ngrams(imports)) + return 1; + + std::cout << "imported " << imports.size() << " ngrams.\n" << std::flush; + return 0; } struct options { std::string filename = "markov.sqlite"; std::string tablename = "markov"; - int order = 2; + int order = 1; std::string importname = ""; bool run_bot = true; }; @@ -203,7 +237,6 @@ struct options parse_args(int argc, char ** argv) { } int main(int argc, char **argv) { - struct options opts = parse_args(argc, argv); SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename); @@ -223,8 +256,7 @@ int main(int argc, char **argv) { while(std::getline(std::cin, line)) { if (line.size() > 0) { if (std::tolower(line[0]) == 'y') { - std::cout << "Importing...\n"; - if (import_from_file(opts.importname, m)) { + if (m.import_from_file(opts.importname)) { std::cerr << "Failed import." << std::endl; } break; diff --git a/bot.h b/bot.h index 8986151..a1dcc4c 100644 --- a/bot.h +++ b/bot.h @@ -15,6 +15,7 @@ class MarkovDB { virtual bool add_ngrams(std::vector words) = 0; virtual std::string get_continuation(std::vector prompt, int &score) = 0; virtual ~MarkovDB() {} + virtual bool batch_add_ngrams(std::vector> batch) = 0; }; class MarkovChain { @@ -36,10 +37,15 @@ class MarkovChain { MarkovChain (int ord, bool lower, MarkovDB &db) : markov(db), MAKE_LOWER(lower), order(ord) {; }; + + int import_from_file(std::string filename); + + std::vector> get_words(std::string message); virtual ~MarkovChain() {}; }; + class Bot { protected: MarkovChain markov; @@ -52,6 +58,7 @@ class Bot { ~Bot() {}; }; + class TelegramBotM: public Bot { protected: std::string api_key = ""; diff --git a/sqlite_markov.h b/sqlite_markov.h index 4e94312..3759af1 100644 --- a/sqlite_markov.h +++ b/sqlite_markov.h @@ -1,6 +1,7 @@ #include "SQLiteCpp/Database.h" #include #include +#include "SQLiteCpp/Transaction.h" #include "bot.h" class SQLiteMarkov: public MarkovDB { @@ -13,8 +14,7 @@ class SQLiteMarkov: public MarkovDB { const std::string TABLE_NAME = "markov"; const std::string INSERT_STATEMENT = get_insert_statement_string(); - const std::string UPDATE_STATEMENT = get_update_statement_string(); - const std::string QUERY_STATEMENT = get_select_statement_string(); + const std::string UPDATE_STATEMENT = get_update_statement_string(); const std::string QUERY_STATEMENT = get_select_statement_string(); std::shared_ptr data; @@ -24,6 +24,7 @@ class SQLiteMarkov: public MarkovDB { + fp; } + std::unique_ptr open_db() { try { return std::unique_ptr(new @@ -72,8 +73,7 @@ class SQLiteMarkov: public MarkovDB { } ins += " )\n);"; - // run create table instruction - SQLite::Transaction transaction(*data); + SQLite::Transaction transaction (*data); data->exec(ins); transaction.commit(); @@ -228,7 +228,6 @@ class SQLiteMarkov: public MarkovDB { } std::cout << "No matches found" << std::endl; - throw "No next word found"; return ""; } @@ -238,7 +237,6 @@ class SQLiteMarkov: public MarkovDB { } public: - SQLiteMarkov(int ord, std::string dbname, std::string table): order(ord), db_filename(dbname), @@ -263,7 +261,10 @@ class SQLiteMarkov: public MarkovDB { Initialise_Data(); } + virtual ~SQLiteMarkov() {}; + bool add_ngrams(std::vector words) override { + // assertions? if (words.size() > order + 1) { return true; @@ -278,15 +279,36 @@ class SQLiteMarkov: public MarkovDB { return add_to_db(words); } + bool batch_add_ngrams(std::vector> batch) + override { + + SQLite::Transaction transaction(*data); + + for (auto ngram: batch) { + add_ngrams(ngram); + } + + transaction.commit(); + + return 0; + } + std::string get_continuation(std::vector prompt, int &score) override { - std::string new_words = ""; + std::stringstream ss; + for (int i=0; i < prompt.size(); i++) { + ss << " " << prompt[i]; + } + + std::string new_words = ss.str(); std::vector words(prompt); + int count = 0; score = 0; do { - std::string next_word = "invalid"; + std::string next_word = "throwaway"; + try { int total; next_word = get_next_word(words, total); @@ -297,14 +319,12 @@ class SQLiteMarkov: public MarkovDB { break; } - if (next_word == "MESSAGE_END") { - break; - } - new_words += " " + next_word; std::regex end_punc("[\\.\\?\\!]"); - if (std::regex_search(std::to_string(next_word.back()), end_punc)) { + if (next_word.size() == 0 + || std::regex_search(std::to_string(next_word.back()), + end_punc)) { break; } @@ -314,6 +334,9 @@ class SQLiteMarkov: public MarkovDB { } while (count < MAX_REPLY_LENGTH); - return new_words; + if (count > 0) + return new_words; + else + return ""; } };