You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

271 lines
6.6 KiB

#include <libtelegram/libtelegram.h>
#include <boost/algorithm/string.hpp>
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <unistd.h>
#include "bot.h"
#include "sqlite_markov.h"
/// Ad-hoc smoke test: feeds one fixed sample containing a newline and a tab
/// through add_ngrams() (any exception it throws propagates).
/// @return always false.
bool MarkovChain::test() {
    const std::string sample = "so long and thanks\nfor\tall the fishes bro man";
    add_ngrams(sample);
    return false;
}
/**
 * Normalize @p message in place before it is split into words.
 *
 * When MAKE_LOWER is set, every byte that is an upper-case letter in the
 * current C locale is converted to lower case; all other bytes are left
 * unchanged. Each byte is cast to unsigned char first: passing a negative
 * char (e.g. a non-ASCII UTF-8 byte) to the <cctype> functions is undefined
 * behavior.
 *
 * @param message text to normalize; modified in place.
 * @return always false (no error condition is currently reported).
 */
bool MarkovChain::process_string(std::string& message) {
    if (this->MAKE_LOWER) {
        for (char& ch : message) {
            ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
        }
    }
    return false;
}
/**
 * Split @p message into words separated by any run of whitespace, using the
 * same extraction as `stream >> std::string`. Leading, trailing and repeated
 * whitespace never produce empty words.
 *
 * @return the words in their original order (empty for blank input).
 */
std::vector<std::string> MarkovChain::split_string(std::string message) {
    std::vector<std::string> tokens;
    std::istringstream stream(message);
    for (std::string token; stream >> token;) {
        tokens.push_back(token);
    }
    return tokens;
}
/**
 * Break @p message into overlapping windows of (order + 1) consecutive words.
 *
 * The message is first normalized with process_string() and split on
 * whitespace with split_string(). For words w0..wk the result is
 * {w0..w(order)}, {w1..w(order+1)}, ..., one entry per window, in order.
 * A message with `order` words or fewer produces no windows; so does a
 * negative order, which has no meaningful window size.
 *
 * @param message raw text (taken by value; the caller's copy is untouched).
 * @return the list of word windows.
 */
std::vector<std::vector<std::string>> MarkovChain::get_words(std::string message) {
    this->process_string(message);
    const std::vector<std::string> words = split_string(message);
    std::vector<std::vector<std::string>> import;
    if (order < 0) {
        return import;
    }
    const std::size_t window = static_cast<std::size_t>(order) + 1;
    if (words.size() < window) {
        return import;
    }
    import.reserve(words.size() - window + 1);
    for (std::size_t start = 0; start + window <= words.size(); ++start) {
        import.emplace_back(words.begin() + start, words.begin() + start + window);
    }
    return import;
}
std::string MarkovChain::continue_message(std::string message, int &score) {
std::cout << 1 << std::endl;
std::vector<std::string> words = split_string(message);
if (words.size() < order) {
std::cout << "Message too short.\n";
return "";
}
for (int i = 0; i < words.size(); i++) {
std::cout << words[i] << std::endl;
}
std::vector<std::string>::const_iterator first = words.begin()
+ (words.size() - order);
std::vector<std::string>::const_iterator last = words.end();
std::vector<std::string> newwords(first, last);
std::cout << "Using seed: ";
for (auto x: newwords)
std::cout << " " << x;
std::cout << "\n";
std::string ret = markov.get_continuation(newwords, score);
return ret;
}
/**
 * Read the Telegram bot API token from "apikey.txt" in the current working
 * directory.
 *
 * @return the first whitespace-delimited token in the file.
 * Exits the process with status 1 if the file cannot be opened or contains
 * no token, since an empty token cannot work.
 */
std::string TelegramBotM::get_token() {
    std::ifstream ftoken("apikey.txt");
    if (!ftoken) {
        std::perror("No apikey.txt file provided");
        std::exit(1);
    }
    std::string apikey;
    if (!(ftoken >> apikey)) {
        std::cerr << "apikey.txt does not contain an API key" << std::endl;
        std::exit(1);
    }
    return apikey;
}
/**
 * Install the per-message callback on the listener.
 *
 * For each incoming message that carries text, the callback:
 *  1. adds the text to the chain (a `char const *` thrown by add_ngrams()
 *     is logged, and the message is dropped);
 *  2. asks the chain for a continuation; an empty reply ends processing;
 *  3. sends the reply to the originating chat, logging send errors.
 * Messages whose optional `text` field is empty are ignored. Calling
 * `.value()` on them would throw out of the callback.
 */
void TelegramBotM::add_echo() {
    // [&] captures `this` (listener, markov and sender are members), so the
    // callback must not outlive this bot object.
    listener.set_callback_message([&](telegram::types::message const &message) {
        // listener.set_callback_json([&](nlohmann::json const &input) {
        if (!message.text) {
            return;  // non-text update: nothing to learn from or reply to
        }
        const std::string message_text = message.text.value();
        try {
            markov.add_ngrams(message_text);
        } catch (char const * e) {
            std::cout << "Error adding ngram: " << e << std::endl;
            return;
        }
        //if (replies && rand() % replies != 1) {
        //return;
        //}
        int score = 0;
        const std::string reply = markov.continue_message(message_text, score);
        if (reply.empty()) {
            return;
        }
        std::cerr << "SCORE: " << score << std::endl;
        std::cerr << "Send" << std::endl;
        try {
            sender.send_message(message.chat.id, reply);
        } catch (std::exception &e) {
            std::cerr << "Send error: " << e.what() << std::endl;
        }
    });
}
bool MarkovChain::add_ngrams(std::string message) {
process_string(message);
std::vector<std::vector<std::string>> add_words(get_words(message));
if (markov.batch_add_ngrams(add_words))
throw "Failed to add ngrams";
return 0;
}
/**
 * Import every line of @p filename into the store as word windows.
 *
 * All windows are collected in memory first and written with a single
 * batch_add_ngrams() call. Progress is printed to stdout every 21 lines,
 * plus once with the final count.
 *
 * @param filename path of the text file to import.
 * @return 0 on success; 1 if the file cannot be opened or read, or if the
 *         store reports a failure.
 */
int MarkovChain::import_from_file(std::string filename) {
    std::ifstream infile(filename);
    if (!infile) {
        std::cerr << "Cannot open " << filename << " for reading." << std::endl;
        return 1;
    }
    std::size_t lines_read = 0;
    std::string line;
    std::vector<std::vector<std::string>> imports;
    while (std::getline(infile, line)) {
        // get_words() applies process_string() itself.
        std::vector<std::vector<std::string>> add_words(get_words(line));
        imports.insert(imports.end(),
                       std::make_move_iterator(add_words.begin()),
                       std::make_move_iterator(add_words.end()));
        ++lines_read;
        if (lines_read % 21 == 0) {
            std::cout << "\rread " << lines_read << " lines." << std::flush;
        }
    }
    if (infile.bad()) {
        std::cerr << "\nRead error in " << filename << "." << std::endl;
        return 1;
    }
    // The progress line is only refreshed every 21 lines; show the final count.
    std::cout << "\rread " << lines_read << " lines." << std::flush;
    std::cout << "\nImporting...\n" << std::flush;
    if (markov.batch_add_ngrams(imports))
        return 1;
    std::cout << "imported " << imports.size() << " ngrams.\n" << std::flush;
    return 0;
}
/// Command-line configuration produced by parse_args().
struct options {
    std::string filename{"markov.sqlite"};  ///< SQLite database file (-d)
    std::string tablename{"markov"};        ///< table inside the database (-n)
    int order{1};                           ///< chain order (-o)
    std::string importname{};               ///< file to import (-i); empty = none
    bool run_bot{true};                     ///< run the Telegram bot (-r / -s)
};
/// Convert the argument of -o into a chain order; exits with status 1 unless
/// the whole argument is a non-negative integer that fits in an int.
static int parse_order_arg(const char *text) {
    std::size_t consumed = 0;
    int value = -1;
    try {
        value = std::stoi(text, &consumed);
    } catch (const std::logic_error &) {  // invalid_argument / out_of_range
        consumed = 0;
    }
    // Reject partial parses such as "3x" as well as negative orders.
    if (consumed == 0 || text[consumed] != '\0' || value < 0) {
        std::cerr << "Invalid order '" << text
                  << "': expected a non-negative integer." << std::endl;
        std::exit(1);
    }
    return value;
}

/**
 * Parse command-line flags into an `options` value.
 *
 *   -d FILE  database file        -n NAME  table name
 *   -o N     chain order          -i FILE  import FILE (implies no bot)
 *   -r       run the bot          -s       do not run the bot
 *
 * Later flags override earlier ones. Unrecognized options are reported on
 * stderr by getopt() and otherwise ignored; an invalid -o value terminates
 * the process with status 1.
 */
struct options parse_args(int argc, char ** argv) {
    struct options options;
    int c;
    while ((c = getopt(argc, argv, "d:n:o:i:rs")) != -1) {
        switch (c) {
            case 'd':
                options.filename = optarg;
                break;
            case 'n':
                options.tablename = optarg;
                break;
            case 'o':
                options.order = parse_order_arg(optarg);
                break;
            case 'i':
                options.importname = optarg;
                options.run_bot = false;
                break;
            case 'r':
                options.run_bot = true;
                break;
            case 's':
                options.run_bot = false;
                break;
            default:
                break;  // '?' — getopt() has already printed a diagnostic
        }
    }
    return options;
}
/**
 * Entry point: open the Markov database, then either run the Telegram bot
 * or, after interactive confirmation on stdin, import a text file.
 *
 * @return 0 on success or when the user declines the import; 1 if a
 *         confirmed import fails.
 */
int main(int argc, char **argv) {
    struct options opts = parse_args(argc, argv);
    SQLiteMarkov db(opts.order, opts.filename, opts.tablename);
    MarkovChain m(opts.order, true, db);
    if (opts.run_bot) {
        TelegramBotM b(opts.order, m, opts.tablename);
        std::cout << "Running...\n";
        b.run();
        return 0;
    }
    if (opts.importname.empty()) {
        std::cerr << "Nothing to do: pass -i FILE to import or -r to run the bot."
                  << std::endl;
        return 0;
    }
    // No default answer: blank lines are ignored, so the prompt uses [y/n].
    std::cout << "Are you sure you want to import from " << opts.importname
              << " to the database " << opts.filename << ", table "
              << opts.tablename << "?" << std::endl << " [y/n] >";
    std::string line;
    while (std::getline(std::cin, line)) {
        if (line.empty()) {
            continue;
        }
        // Cast avoids undefined behavior for non-ASCII (negative) chars.
        const int answer = std::tolower(static_cast<unsigned char>(line[0]));
        if (answer == 'y') {
            if (m.import_from_file(opts.importname)) {
                std::cerr << "Failed import." << std::endl;
                return 1;
            }
            return 0;
        }
        if (answer == 'n') {
            return 0;
        }
    }
    return 0;
}