#include #include #include #include #include #include #include #include #include #include #include "bot.h" #include "sqlite_markov.h" #include bool MarkovChain::test() { add_ngrams("so long and thanks\nfor\tall the fishes bro man"); return false; } bool MarkovChain::process_string(std::string& message) { if (this->MAKE_LOWER) { for (int i = 0; i < message.length(); i++) { if (std::isupper(message[i])) { message[i] = std::tolower(message[i], std::locale()); } } } return false; } std::vector MarkovChain::split_string(std::string message) { std::istringstream ss(message); std::vector words{ std::istream_iterator{ss}, std::istream_iterator{} }; return words; } std::vector> MarkovChain::get_words(std::string message) { this->process_string(message); std::vector words (split_string(message)); std::vector> import; for (int i = 0; i < words.size(); i++) { if (i >= order) { std::vector add_words; int count = 0; for (int j = i - order; j <= i; j++, count++) { add_words.push_back(words[j]); } import.push_back(add_words); } } return import; } std::string MarkovChain::continue_message(std::string message, int &score) { std::cout << 1 << std::endl; std::vector words = split_string(message); if (words.size() < order) { std::cout << "Message too short.\n"; return ""; } for (int i = 0; i < words.size(); i++) { std::cout << words[i] << std::endl; } std::vector::const_iterator first = words.begin() + (words.size() - order); std::vector::const_iterator last = words.end(); std::vector newwords(first, last); std::cout << "Using seed: "; for (auto x: newwords) std::cout << " " << x; std::cout << "\n"; std::string ret = markov.get_continuation(newwords, score); return ret; } std::string TelegramBotM::get_token() { using namespace std; fstream ftoken; ftoken.open("apikey.txt", ios::in); if (!ftoken) { perror("No apikey.txt file provided"); exit(1); } string apikey; ftoken >> apikey; return apikey; } void TelegramBotM::add_echo() { listener.set_callback_message([&](telegram::types::message const &message) { // listener.set_callback_json([&](nlohmann::json const &input) { std::string message_text = message.text.value(); try { markov.add_ngrams(message_text); } catch (char const * e) { std::cout << "Error adding ngram: " << e << std::endl; return; } //if (replies && rand() % replies != 1) { //return; //} int score = 0; std::string reply; reply = markov.continue_message(message_text, score); if (reply == "") { return; } std::cerr << "SCORE: " << score << std::endl; std::cerr << "Send" << std::endl; try { sender.send_message(message.chat.id, reply); } catch (std::exception &e) { std::cerr << "Send error: " << e.what() << std::endl; } }); } bool MarkovChain::add_ngrams(std::string message) { process_string(message); std::vector> add_words(get_words(message)); if (markov.batch_add_ngrams(add_words)) throw "Failed to add ngrams"; return 0; } int MarkovChain::import_from_file(std::string filename) { std::ifstream infile(filename); int i = 0; std::string line; std::vector> imports(0); while(std::getline(infile, line)) { this->process_string(line); std::vector> add_words(get_words(line)); imports.insert(imports.end(), std::make_move_iterator(add_words.begin()), std::make_move_iterator(add_words.end())); i++; if (i % 21 == 0) { std::cout << "\rread " << i << " lines." << std::flush; } } std::cout << "\nImporting...\n" << std::flush; if (markov.batch_add_ngrams(imports)) return 1; std::cout << "imported " << imports.size() << " ngrams.\n" << std::flush; return 0; } struct options { std::string filename = "markov.sqlite"; std::string tablename = "markov"; int order = 1; std::string importname = ""; bool run_bot = true; }; struct options parse_args(int argc, char ** argv) { struct options options; int c; while ((c = getopt(argc, argv, "d:n:o:i:rs")) != EOF) { switch (c) { case 'd': options.filename = optarg; break; case 'n': options.tablename = optarg; break; case 'o': options.order = std::stoi(optarg); break; case 'i': options.importname = optarg; options.run_bot = false; break; case 'r': options.run_bot = true; break; case 's': options.run_bot = false; break; } } return options; } int main(int argc, char **argv) { struct options opts = parse_args(argc, argv); SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename); MarkovChain m = MarkovChain(opts.order, true, db); if (opts.run_bot) { TelegramBotM b = TelegramBotM(opts.order, m, opts.tablename); std::cout << "Running...\n"; b.run(); } else if (opts.importname != "") { std::cout << "Are you sure you want to import from " << opts.importname << " to the database " << opts.filename << ", table " << opts.tablename << "?" << std::endl << " [Y/n] >"; std::string line; while(std::getline(std::cin, line)) { if (line.size() > 0) { if (std::tolower(line[0]) == 'y') { if (m.import_from_file(opts.importname)) { std::cerr << "Failed import." << std::endl; } break; } else if (std::tolower(line[0]) == 'n') { break; } } } } return 0; }