Browse Source

reply confidence threshold

master
user 4 years ago
parent
commit
8daeeb471b
  1. 52
      bot.cpp

52
bot.cpp

@ -145,7 +145,6 @@ class SQLiteMarkov { @@ -145,7 +145,6 @@ class SQLiteMarkov {
increment.bind(1, count);
for (int i = 0; i < order + 1; i++) {
std::cout << i << ": " << words[i];
increment.bind(i + 2, words[i]);
}
@ -159,7 +158,6 @@ class SQLiteMarkov { @@ -159,7 +158,6 @@ class SQLiteMarkov {
SQLite::Statement increment(data, update_template);
for (int i = 0; i < order + 1; i++) {
std::cout << i << ": " << words[i];
increment.bind(i + 1, words[i]);
}
@ -167,7 +165,6 @@ class SQLiteMarkov { @@ -167,7 +165,6 @@ class SQLiteMarkov {
return increment.executeStep();
}
bool add_to_db(std::vector<std::string> words) {
@ -202,7 +199,7 @@ class SQLiteMarkov { @@ -202,7 +199,7 @@ class SQLiteMarkov {
return false;
}
std::string get_next_word(std::vector<std::string> words) {
std::string get_next_word(std::vector<std::string> words, int &score) {
if (words.size() != order) {
throw "Invalid prompt vector size";
return "";
@ -238,6 +235,8 @@ class SQLiteMarkov { @@ -238,6 +235,8 @@ class SQLiteMarkov {
total += count;
if (total >= threshold) {
std::string next_word = query.getColumn(order);
score = total;
return next_word;
}
}
@ -247,6 +246,11 @@ class SQLiteMarkov { @@ -247,6 +246,11 @@ class SQLiteMarkov {
return "";
}
std::string get_next_word(std::vector<std::string> words) {
int count;
return get_next_word(words, count);
}
public:
@ -298,16 +302,19 @@ class SQLiteMarkov { @@ -298,16 +302,19 @@ class SQLiteMarkov {
}
std::string get_continuation(std::vector<std::string> prompt) {
std::string get_continuation(std::vector<std::string> prompt, int &score) {
std::string new_words = "";
std::vector<std::string> words(prompt);
int count = 0;
score = 0;
do {
std::string next_word = "invalid";
try {
next_word = get_next_word(words);
int total;
next_word = get_next_word(words, total);
score += total;
} catch (std::exception &e) {
std::cout << "exception: " << e.what() << std::endl;
std::cout << "Failed Get Next Word" << std::endl;
@ -328,7 +335,6 @@ class SQLiteMarkov { @@ -328,7 +335,6 @@ class SQLiteMarkov {
count++;
words.push_back(next_word);
words.erase(words.begin());
} while (count < MAX_REPLY_LENGTH);
return new_words;
@ -399,12 +405,12 @@ class MarkovHandler { @@ -399,12 +405,12 @@ class MarkovHandler {
return false;
}
std::string continue_message(std::string message) {
std::string continue_message(std::string message, int &score) {
std::vector<std::string> words {split_string(message)};
words.erase(words.begin(), words.begin() + words.size() - markov.order);
return markov.get_continuation(words);
}
std::string ret = markov.get_continuation(words, score);
return ret;
}
};
class Bot {
@ -412,6 +418,8 @@ class Bot { @@ -412,6 +418,8 @@ class Bot {
std::string apikey;
telegram::sender sender;
MarkovHandler markov;
const int replies = 200;
int SCORE_THRESHOLD = 100;
public:
telegram::listener::poll listener;
@ -425,13 +433,31 @@ class Bot { @@ -425,13 +433,31 @@ class Bot {
};
void echo() {
listener.set_callback_message([&](telegram::types::message const &message){
std::string message_text = message.text.value();
markov.add_ngrams(message_text);
if (replies && rand() % replies != 1) {
return;
}
std::string reply = markov.continue_message(message_text);
sender.send_message(message.chat.id, reply);
int score = 0;
std::string reply;
try {
reply = markov.continue_message(message_text, score);
} catch (std::exception &e) {
std::cout << "Error getting message" << std::endl;
return;
}
std::cout << "SCORE: " << score << std::endl;
if (score > SCORE_THRESHOLD) {
sender.send_message(message.chat.id, reply);
}
});
}

Loading…
Cancel
Save