commit a0d097b02535eb57bafd33bd1b7d788e41e4fd35
Author: user
Date:   Fri Feb 28 02:29:13 2020 +1000

    works

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..5886c39
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "libtelegram"]
+	path = libtelegram
+	url = https://github.com/slowriot/libtelegram.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..bccfa10
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,45 @@
+cmake_minimum_required (VERSION 3.12)
+
+set(PROJECT_NAME "markov bot")
+project (${PROJECT_NAME})
+
+set_property (GLOBAL PROPERTY USE_FOLDERS ON)
+
+set (CMAKE_CXX_STANDARD 17)
+set (CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set (THREADS_PREFER_PTHREAD_FLAG ON)
+
+find_package (Threads REQUIRED)
+find_package(Boost 1.66 REQUIRED COMPONENTS system)
+
+
+IF(Boost_FOUND)
+    include_directories(${Boost_INCLUDE_DIRS})
+ELSE()
+    message(FATAL_ERROR "Boost 1.66 or newer was not found")
+ENDIF(Boost_FOUND)
+
+find_package(OpenSSL REQUIRED)
+
+# Headers of the libtelegram submodule.
+set(include_dir "${CMAKE_SOURCE_DIR}/libtelegram/include")
+message(STATUS "include_dir: ${include_dir}")
+include_directories (${include_dir})
+
+set(link_libs
+    Threads::Threads
+    OpenSSL::SSL
+    Boost::system
+    SQLiteCpp
+    sqlite3
+    dl)
+
+# Builds <name>.cpp into an executable <name> linked against ${link_libs}.
+macro(add_tg_example name)
+    set(target_name ${name})
+    add_executable(${target_name} ${name}.cpp)
+    target_link_libraries(${target_name} ${link_libs})
+endmacro(add_tg_example name)
+
+add_tg_example(bot)
diff --git a/Readme.md b/Readme.md
new file mode 100644
index 0000000..68e157f
--- /dev/null
+++ b/Readme.md
@@ -0,0 +1,17 @@
+
+
+## Building on Fedora
+
+```
+
+git submodule update --init
+sudo dnf install boost-devel sqlitecpp-devel openssl-devel
+cmake .
+make
+
+```
+
+## Running
+
+Put the Telegram bot API token in a file named `apikey.txt` in the directory
+the bot is started from, then run `./bot`.
diff --git a/bot.cpp b/bot.cpp
new file mode 100644
index 0000000..e313677
--- /dev/null
+++ b/bot.cpp
@@ -0,0 +1,519 @@
+// NOTE(review): the header names of the original #include lines were lost when
+// this patch was extracted. The list below is reconstructed from the names the
+// code uses -- confirm it against the repository before applying.
+#include <cstddef>
+#include <cstdio>
+#include <cstdlib>
+#include <exception>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <locale>
+#include <random>
+#include <regex>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <SQLiteCpp/SQLiteCpp.h>
+#include <libtelegram/libtelegram.h>
+
+/// Reads the Telegram bot API token from "apikey.txt" in the working directory.
+///
+/// Prints a diagnostic and exits with status 1 when the file cannot be opened
+/// or contains no token.
+///
+/// @return The first whitespace-delimited word of the file.
+std::string get_token() {
+    std::ifstream ftoken("apikey.txt");
+    if (!ftoken) {
+        perror("No apikey.txt file provided");
+        exit(1);
+    }
+
+    std::string apikey;
+    if (!(ftoken >> apikey)) {
+        std::cerr << "apikey.txt does not contain an API key" << std::endl;
+        exit(1);
+    }
+    return apikey;
+}
+
+/// Word-level Markov chain of a fixed order, persisted in an SQLite database.
+///
+/// The table "markov" has the columns word_0 .. word_<order> plus "count": each
+/// row records how often word_<order> was seen right after the preceding
+/// `order` words.
+class SQLiteMarkov {
+    public:
+        /// Number of preceding words a continuation is keyed on.
+        int order;
+
+    private:
+        // Static on purpose: Initialise_Data() reads WORD_SIZE while `data` is
+        // being constructed, and a non-static member declared after `data`
+        // would not be initialised yet at that point.
+        static constexpr int WORD_SIZE = 400;
+        static constexpr int MAX_REPLY_LENGTH = 500;
+
+        std::string db_filename;
+        SQLite::Database data;
+
+        const std::string INSERT_STATEMENT;
+        const std::string UPDATE_STATEMENT;
+        const std::string QUERY_STATEMENT;
+
+        // Seeded from std::random_device so replies differ between runs.
+        std::mt19937 rng{std::random_device{}()};
+
+        /// Rejects orders for which the generated SQL would be malformed.
+        static int checked_order(int ord) {
+            if (ord < 1) {
+                throw std::invalid_argument("Markov order must be at least 1");
+            }
+            return ord;
+        }
+
+        /// Resolves `fp` against the directory bot.cpp was compiled from
+        /// (__FILE__ with the trailing "bot.cpp" removed).
+        static std::string get_filepath(const std::string& fp) {
+            const std::string filePath(__FILE__);
+            return filePath.substr(0, filePath.length() - std::string("bot.cpp").length())
+                + fp;
+        }
+
+        /// Opens the database read-write, creating the file if opening fails.
+        SQLite::Database open_db() {
+            try {
+                return SQLite::Database(db_filename, SQLite::OPEN_READWRITE);
+            } catch (const std::exception& e) {
+                std::cout << "exception: " << e.what() << std::endl;
+            }
+
+            std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
+
+            return SQLite::Database(db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE);
+        }
+
+        /// Opens the database and creates the "markov" table when it is missing.
+        SQLite::Database Initialise_Data() {
+            SQLite::Database db = open_db();
+
+            bool table_exists = false;
+            try {
+                // Scoped so the statement is finalised before the table is
+                // created or the connection is handed back to the caller.
+                SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'");
+                table_exists = query.executeStep();
+            } catch (const std::exception& e) {
+                std::cout << "exception: " << e.what() << std::endl;
+            }
+
+            if (table_exists) {
+                std::cout << "database ready." << std::endl;
+                return db;
+            }
+
+            // create table sql statement construction
+            std::string ins = "CREATE TABLE markov (\n";
+            for (int i = 0; i < order + 1; i++) {
+                ins += "word_" + std::to_string(i) + " VARCHAR("
+                    + std::to_string(WORD_SIZE) + ") NOT NULL,\n";
+            }
+            ins += "count INTEGER NOT NULL, \n PRIMARY KEY (";
+            for (int i = 0; i < order + 1; i++) {
+                ins += "word_" + std::to_string(i);
+                if (i != order) {
+                    ins += ",";
+                }
+            }
+            ins += " )\n);";
+
+            // run create table instruction
+            SQLite::Transaction transaction(db);
+            db.exec(ins);
+            transaction.commit();
+
+            return db;
+        }
+
+        /// Builds "word_0 = ? AND ... word_<n-1> = ? " for the first n word columns.
+        static std::string match_clause(int n) {
+            std::string clause;
+            for (int i = 0; i < n; i++) {
+                clause += "word_" + std::to_string(i) + " = ? ";
+                if (i != n - 1) {
+                    clause += "AND ";
+                }
+            }
+            return clause;
+        }
+
+        /// UPDATE setting count (parameter 1) on the row matching all order + 1 words.
+        std::string get_update_statement_string() const {
+            return "UPDATE markov SET count = ? WHERE\n" + match_clause(order + 1);
+        }
+
+        /// SELECT of every row whose first `order` words match the parameters.
+        std::string get_select_statement_string() const {
+            return "SELECT * FROM markov WHERE " + match_clause(order) + ";";
+        }
+
+        /// SELECT of the count stored for one complete n-gram.
+        std::string get_count_statement_string() const {
+            return "SELECT count FROM markov WHERE " + match_clause(order + 1);
+        }
+
+        /// INSERT with one placeholder per word column followed by the count.
+        std::string get_insert_statement_string() const {
+            std::string sql = "INSERT INTO markov VALUES (";
+            for (int i = 0; i < order + 2; i++) {
+                sql += (i == 0) ? "?" : ", ?";
+            }
+            sql += ")";
+            return sql;
+        }
+
+        /// Stores `count` for the existing row identified by `words`.
+        /// @return true when a row was changed.
+        bool update_count(const std::vector<std::string>& words, int count) {
+            SQLite::Statement update(data, UPDATE_STATEMENT);
+
+            update.bind(1, count);
+            for (int i = 0; i < order + 1; i++) {
+                update.bind(i + 2, words[i]);
+            }
+
+            // exec() is SQLiteCpp's call for statements that return no rows.
+            return update.exec() > 0;
+        }
+
+        /// Inserts a new row for `words` with the given `count`.
+        /// @return true when a row was inserted.
+        bool insert_record(const std::vector<std::string>& words, int count) {
+            SQLite::Statement insert(data, INSERT_STATEMENT);
+
+            for (int i = 0; i < order + 1; i++) {
+                insert.bind(i + 1, words[i]);
+            }
+            insert.bind(order + 2, count);
+
+            return insert.exec() > 0;
+        }
+
+        /// Increments the count of the n-gram `words`, inserting it with a
+        /// count of 1 when it has not been seen before.
+        /// @return Always false (add_ngrams() returns true only for rejected input).
+        bool add_to_db(const std::vector<std::string>& words) {
+            SQLite::Statement lookup(data, get_count_statement_string());
+            for (int i = 0; i < order + 1; i++) {
+                lookup.bind(i + 1, words[i]);
+            }
+
+            const bool known = lookup.executeStep();
+            int count = 0;
+            if (known) {
+                count = lookup.getColumn(0);
+            }
+            // Release the read cursor before writing to the same table.
+            lookup.reset();
+
+            if (known) {
+                update_count(words, count + 1);
+            } else {
+                insert_record(words, 1);
+            }
+
+            return false;
+        }
+
+        /// Picks a continuation for `words`, weighted by the stored counts.
+        /// @param words Exactly `order` words of context.
+        /// @return The chosen next word.
+        /// @throws std::invalid_argument if `words` has the wrong size.
+        /// @throws std::runtime_error if no continuation is stored.
+        std::string get_next_word(const std::vector<std::string>& words) {
+            if (words.size() != static_cast<std::size_t>(order)) {
+                throw std::invalid_argument("Invalid prompt vector size");
+            }
+
+            SQLite::Statement query(data, QUERY_STATEMENT);
+            for (int i = 0; i < order; i++) {
+                query.bind(i + 1, words[i]);
+            }
+
+            // Columns are word_0 .. word_<order>, count: the candidate word is
+            // column `order` and its count is column `order + 1`.
+            std::vector<std::pair<std::string, int>> candidates;
+            long long total = 0;
+            while (query.executeStep()) {
+                std::string next_word = query.getColumn(order);
+                const int count = query.getColumn(order + 1);
+                total += count;
+                candidates.emplace_back(std::move(next_word), count);
+            }
+
+            if (total <= 0) {
+                std::cout << "No matches found" << std::endl;
+                throw std::runtime_error("No next word found");
+            }
+
+            // Draw from [0, total) and return the first candidate whose running
+            // total exceeds the draw, so each word is chosen with probability
+            // count / total.
+            std::uniform_int_distribution<long long> draw(0, total - 1);
+            const long long threshold = draw(rng);
+
+            long long running = 0;
+            for (const auto& candidate : candidates) {
+                running += candidate.second;
+                if (running > threshold) {
+                    return candidate.first;
+                }
+            }
+
+            // Not reachable as long as every stored count is positive.
+            throw std::runtime_error("No next word found");
+        }
+
+    public:
+
+        /// @param ord    Chain order; must be at least 1.
+        /// @param dbname Database file name, resolved next to bot.cpp.
+        /// @throws std::invalid_argument if `ord` is smaller than 1.
+        SQLiteMarkov(int ord, std::string dbname): order(checked_order(ord)),
+                db_filename(get_filepath(dbname)),
+                data(Initialise_Data()),
+                INSERT_STATEMENT(get_insert_statement_string()),
+                UPDATE_STATEMENT(get_update_statement_string()),
+                QUERY_STATEMENT(get_select_statement_string())
+        {}
+
+        /// Order `ord` chain stored in "markov.sqlite".
+        SQLiteMarkov(int ord): SQLiteMarkov(ord, "markov.sqlite") {}
+
+        /// Order 1 chain stored in "markov.sqlite".
+        SQLiteMarkov(): SQLiteMarkov(1, "markov.sqlite") {}
+
+        /// Order 1 chain stored in `dbname`.
+        SQLiteMarkov(std::string dbname): SQLiteMarkov(1, dbname) {}
+
+        /// Records one n-gram of exactly order + 1 words.
+        /// @return true if the n-gram was rejected (wrong number of words or a
+        ///         word longer than WORD_SIZE), false once it has been stored.
+        bool add_ngrams(std::vector<std::string> words) {
+            // A short vector must be rejected too: add_to_db() binds
+            // words[0 .. order] without further checks.
+            if (words.size() != static_cast<std::size_t>(order) + 1) {
+                return true;
+            }
+
+            for (const std::string& word : words) {
+                if (word.length() > static_cast<std::size_t>(WORD_SIZE)) {
+                    return true;
+                }
+            }
+
+            return add_to_db(words);
+        }
+
+        /// Generates a reply by repeatedly sampling the word that follows the
+        /// last `order` words.
+        ///
+        /// Generation stops at the "MESSAGE_END" marker, after a word ending
+        /// in '.', '?' or '!', when no continuation is known, or after
+        /// MAX_REPLY_LENGTH words.
+        ///
+        /// @param prompt Exactly `order` words of starting context.
+        /// @return The generated words, each preceded by a space; empty when
+        ///         nothing could be generated.
+        std::string get_continuation(std::vector<std::string> prompt) {
+            std::string new_words;
+            std::vector<std::string> words(std::move(prompt));
+            int count = 0;
+
+            do {
+                std::string next_word;
+
+                try {
+                    next_word = get_next_word(words);
+                } catch (const std::exception& e) {
+                    std::cout << "exception: " << e.what() << std::endl;
+                    std::cout << "Failed Get Next Word" << std::endl;
+                    break;
+                }
+
+                if (next_word == "MESSAGE_END") {
+                    break;
+                }
+
+                new_words += " " + next_word;
+
+                // Compare the character itself: std::to_string(char) yields its
+                // numeric code ('.' becomes "46"), which never matches punctuation.
+                if (!next_word.empty()
+                        && std::string(".?!").find(next_word.back()) != std::string::npos) {
+                    break;
+                }
+
+                count++;
+                words.push_back(next_word);
+                words.erase(words.begin());
+
+            } while (count < MAX_REPLY_LENGTH);
+
+            return new_words;
+        }
+};
+
+
+/// Turns raw chat messages into n-grams for SQLiteMarkov and generates replies.
+class MarkovHandler {
+
+    private:
+        SQLiteMarkov markov;
+
+        static constexpr bool MAKE_LOWER = true;
+
+        /// Lower-cases `text` in place using the global locale.
+        static void to_lower(std::string& text) {
+            const std::locale loc;
+            for (char& c : text) {
+                c = std::tolower(c, loc);
+            }
+        }
+
+    public:
+
+        /// Feeds a fixed sample sentence into the chain.
+        /// @return Always false.
+        bool test() {
+            add_ngrams("so long and thanks\nfor\tall the-fishes bro man");
+            return false;
+        }
+
+        /// Normalises `message` in place for storage: lower-cases it when
+        /// MAKE_LOWER is set and appends the " MESSAGE_END" marker.
+        /// @return Always false.
+        bool process_string(std::string& message) {
+            if (MAKE_LOWER) {
+                to_lower(message);
+            }
+
+            message += " MESSAGE_END";
+
+            return false;
+        }
+
+        /// Splits `message` into whitespace-separated words.
+        std::vector<std::string> split_string(std::string message) {
+
+            std::istringstream ss(message);
+
+            std::vector<std::string> words{
+                std::istream_iterator<std::string>{ss},
+                std::istream_iterator<std::string>{}
+            };
+
+            return words;
+        }
+
+        /// Learns from `message`: every run of order + 1 consecutive words
+        /// (including the trailing end marker) is recorded in the chain.
+        /// @return Always false.
+        bool add_ngrams(std::string message) {
+            process_string(message);
+            const std::vector<std::string> words(split_string(message));
+            const std::size_t order = static_cast<std::size_t>(markov.order);
+
+            for (std::size_t i = order; i < words.size(); i++) {
+                std::vector<std::string> add_words;
+
+                for (std::size_t j = i - order; j <= i; j++) {
+                    add_words.push_back(words[j]);
+                    std::cout << j << ": " << words[j] << std::endl;
+                }
+
+                markov.add_ngrams(add_words);
+            }
+
+            return false;
+        }
+
+        /// Generates a reply that continues the last `order` words of `message`.
+        /// @return The generated text, or an empty string when `message` has
+        ///         fewer than `order` words or no continuation is known.
+        std::string continue_message(std::string message) {
+            // Stored n-grams are lower-cased by process_string(), so the
+            // prompt has to be normalised the same way to match them.
+            if (MAKE_LOWER) {
+                to_lower(message);
+            }
+
+            std::vector<std::string> words(split_string(message));
+            if (words.size() < static_cast<std::size_t>(markov.order)) {
+                return "";
+            }
+
+            words.erase(words.begin(), words.end() - markov.order);
+
+            return markov.get_continuation(words);
+        }
+};
+
+/// Telegram bot that learns from every text message and answers with a
+/// generated continuation.
+class Bot {
+    private:
+        std::string apikey;
+        telegram::sender sender;
+        MarkovHandler markov;
+
+    public:
+        telegram::listener::poll listener;
+
+        /// Reads the API key from apikey.txt (see get_token()).
+        Bot() : Bot(get_token()) {}
+
+        /// @param key Telegram bot API key.
+        Bot(std::string key) : apikey(std::move(key)), sender(apikey), listener(sender) {
+            echo();
+        }
+
+        /// Installs the message callback: learn from the text, then reply.
+        void echo() {
+            listener.set_callback_message([this](telegram::types::message const &message){
+                // Not every message carries text; value() would throw for those.
+                if (!message.text) {
+                    return;
+                }
+
+                try {
+                    const std::string& message_text = *message.text;
+                    markov.add_ngrams(message_text);
+
+                    const std::string reply = markov.continue_message(message_text);
+                    // Nothing was generated, so there is nothing to send.
+                    if (!reply.empty()) {
+                        sender.send_message(message.chat.id, reply);
+                    }
+                } catch (const std::exception& e) {
+                    // Keep the listener alive when a single message fails.
+                    std::cerr << "exception: " << e.what() << std::endl;
+                }
+            });
+        }
+
+        /// Runs the polling listener.
+        void run() {
+            listener.run();
+        }
+};
+
+auto main()->int {
+    Bot bot {};
+    bot.run();
+    return EXIT_SUCCESS;
+}
diff --git a/libtelegram b/libtelegram
new file mode 160000
index 0000000..a0c52ba
--- /dev/null
+++ b/libtelegram
@@ -0,0 +1 @@
+Subproject commit a0c52ba4aac5519a3031dba506ad4a64cd8fca83