You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
271 lines
6.6 KiB
271 lines
6.6 KiB
#include <libtelegram/libtelegram.h> |
|
#include <boost/algorithm/string.hpp> |
|
#include <vector> |
|
#include <sstream> |
|
#include <algorithm> |
|
#include <iterator> |
|
#include <regex> |
|
#include <iostream> |
|
#include <string> |
|
#include <unistd.h> |
|
#include "bot.h" |
|
#include "sqlite_markov.h" |
|
#include <cstdlib> |
|
|
|
/**
 * Smoke-test hook: pushes one fixed sample sentence (containing a space,
 * a newline and a tab as separators) through add_ngrams().
 *
 * @return Always false.
 */
bool MarkovChain::test() {
    const std::string sample = "so long and thanks\nfor\tall the fishes bro man";

    add_ngrams(sample);

    return false;
}
|
|
|
|
|
bool MarkovChain::process_string(std::string& message) { |
|
if (this->MAKE_LOWER) { |
|
for (int i = 0; i < message.length(); i++) { |
|
if (std::isupper(message[i])) { |
|
message[i] = std::tolower(message[i], std::locale()); |
|
} |
|
} |
|
} |
|
|
|
return false; |
|
} |
|
|
|
|
|
/**
 * Splits a message into whitespace-separated tokens.
 *
 * Tokens are extracted with the stream extraction operator, so runs of
 * spaces, tabs and newlines all act as a single separator and never produce
 * empty tokens.
 *
 * @param message Text to tokenise.
 * @return The tokens in their original order; empty when the message holds
 *         no non-whitespace characters.
 */
std::vector<std::string> MarkovChain::split_string(std::string message) {
    std::vector<std::string> tokens;
    std::istringstream stream(message);

    for (std::string token; stream >> token;) {
        tokens.push_back(token);
    }

    return tokens;
}
|
|
|
std::vector<std::vector<std::string>> MarkovChain::get_words(std::string message) { |
|
this->process_string(message); |
|
|
|
std::vector<std::string> words (split_string(message)); |
|
|
|
std::vector<std::vector<std::string>> import; |
|
|
|
for (int i = 0; i < words.size(); i++) { |
|
if (i >= order) { |
|
std::vector<std::string> add_words; |
|
|
|
int count = 0; |
|
for (int j = i - order; j <= i; j++, count++) { |
|
add_words.push_back(words[j]); |
|
} |
|
|
|
import.push_back(add_words); |
|
} |
|
} |
|
|
|
return import; |
|
} |
|
|
|
|
|
std::string MarkovChain::continue_message(std::string message, int &score) { |
|
std::cout << 1 << std::endl; |
|
std::vector<std::string> words = split_string(message); |
|
|
|
if (words.size() < order) { |
|
std::cout << "Message too short.\n"; |
|
return ""; |
|
} |
|
|
|
for (int i = 0; i < words.size(); i++) { |
|
std::cout << words[i] << std::endl; |
|
} |
|
|
|
std::vector<std::string>::const_iterator first = words.begin() |
|
+ (words.size() - order); |
|
|
|
std::vector<std::string>::const_iterator last = words.end(); |
|
|
|
std::vector<std::string> newwords(first, last); |
|
std::cout << "Using seed: "; |
|
for (auto x: newwords) |
|
std::cout << " " << x; |
|
|
|
std::cout << "\n"; |
|
|
|
std::string ret = markov.get_continuation(newwords, score); |
|
|
|
return ret; |
|
} |
|
|
|
|
|
/**
 * Reads the bot API token from the file "apikey.txt" in the current
 * working directory.
 *
 * Only the first whitespace-delimited word of the file is used. If the file
 * cannot be opened, an error is reported through perror() and the process
 * exits with status 1.
 *
 * @return The token read from the file (empty if the file has no content).
 */
std::string TelegramBotM::get_token() {
    using namespace std;

    fstream token_file("apikey.txt", ios::in);

    if (token_file) {
        string token;
        token_file >> token;
        return token;
    }

    perror("No apikey.txt file provided");
    exit(1);
}
|
|
|
|
|
void TelegramBotM::add_echo() { |
|
listener.set_callback_message([&](telegram::types::message const &message) { |
|
// listener.set_callback_json([&](nlohmann::json const &input) { |
|
std::string message_text = message.text.value(); |
|
|
|
try { |
|
markov.add_ngrams(message_text); |
|
} catch (char const * e) { |
|
std::cout << "Error adding ngram: " << e << std::endl; |
|
return; |
|
} |
|
|
|
//if (replies && rand() % replies != 1) { |
|
//return; |
|
//} |
|
|
|
int score = 0; |
|
std::string reply; |
|
|
|
reply = markov.continue_message(message_text, score); |
|
|
|
if (reply == "") { |
|
return; |
|
} |
|
|
|
std::cerr << "SCORE: " << score << std::endl; |
|
|
|
std::cerr << "Send" << std::endl; |
|
|
|
try { |
|
sender.send_message(message.chat.id, reply); |
|
} catch (std::exception &e) { |
|
std::cerr << "Send error: " << e.what() << std::endl; |
|
|
|
} |
|
}); |
|
} |
|
|
|
|
|
bool MarkovChain::add_ngrams(std::string message) { |
|
|
|
process_string(message); |
|
std::vector<std::vector<std::string>> add_words(get_words(message)); |
|
|
|
if (markov.batch_add_ngrams(add_words)) |
|
throw "Failed to add ngrams"; |
|
|
|
return 0; |
|
} |
|
|
|
|
|
int MarkovChain::import_from_file(std::string filename) { |
|
|
|
std::ifstream infile(filename); |
|
|
|
int i = 0; |
|
std::string line; |
|
std::vector<std::vector<std::string>> imports(0); |
|
|
|
while(std::getline(infile, line)) { |
|
|
|
this->process_string(line); |
|
std::vector<std::vector<std::string>> add_words(get_words(line)); |
|
|
|
imports.insert(imports.end(), |
|
std::make_move_iterator(add_words.begin()), |
|
std::make_move_iterator(add_words.end())); |
|
|
|
i++; |
|
if (i % 21 == 0) { |
|
std::cout << "\rread " << i << " lines." << std::flush; |
|
} |
|
} |
|
|
|
std::cout << "\nImporting...\n" << std::flush; |
|
|
|
if (markov.batch_add_ngrams(imports)) |
|
return 1; |
|
|
|
std::cout << "imported " << imports.size() << " ngrams.\n" << std::flush; |
|
|
|
return 0; |
|
} |
|
|
|
/// Command-line settings, pre-populated with their defaults.
struct options {
    /// Database file name (set by -d).
    std::string filename{"markov.sqlite"};
    /// Table name (set by -n).
    std::string tablename{"markov"};
    /// Markov order (set by -o).
    int order{1};
    /// File to import from (set by -i); empty means no import was requested.
    std::string importname{};
    /// Whether to run the bot (cleared by -i and -s, set by -r).
    bool run_bot{true};
};
|
|
|
struct options parse_args(int argc, char ** argv) { |
|
struct options options; |
|
|
|
int c; |
|
|
|
while ((c = getopt(argc, argv, "d:n:o:i:rs")) != EOF) { |
|
switch (c) { |
|
case 'd': |
|
options.filename = optarg; |
|
break; |
|
case 'n': |
|
options.tablename = optarg; |
|
break; |
|
case 'o': |
|
options.order = std::stoi(optarg); |
|
break; |
|
case 'i': |
|
options.importname = optarg; |
|
options.run_bot = false; |
|
break; |
|
case 'r': |
|
options.run_bot = true; |
|
break; |
|
case 's': |
|
options.run_bot = false; |
|
break; |
|
} |
|
} |
|
|
|
return options; |
|
} |
|
|
|
int main(int argc, char **argv) { |
|
struct options opts = parse_args(argc, argv); |
|
|
|
SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename); |
|
MarkovChain m = MarkovChain(opts.order, true, db); |
|
|
|
if (opts.run_bot) { |
|
TelegramBotM b = TelegramBotM(opts.order, m, opts.tablename); |
|
std::cout << "Running...\n"; |
|
b.run(); |
|
} else if (opts.importname != "") { |
|
std::cout << "Are you sure you want to import from " << opts.importname |
|
<< " to the database " << opts.filename << ", table " |
|
<< opts.tablename << "?" << std::endl << " [Y/n] >"; |
|
|
|
std::string line; |
|
|
|
while(std::getline(std::cin, line)) { |
|
if (line.size() > 0) { |
|
if (std::tolower(line[0]) == 'y') { |
|
if (m.import_from_file(opts.importname)) { |
|
std::cerr << "Failed import." << std::endl; |
|
} |
|
break; |
|
} else if (std::tolower(line[0]) == 'n') { |
|
break; |
|
} |
|
} |
|
} |
|
} |
|
|
|
return 0; |
|
}
|
|
|