Browse Source

transaction import

master
alistair 4 years ago
parent
commit
27e9baadfb
  1. 62
      bot.cpp
  2. 7
      bot.h
  3. 51
      sqlite_markov.h

62
bot.cpp

@ -27,11 +27,10 @@ bool MarkovChain::process_string(std::string& message) {
} }
} }
message += " MESSAGE_END";
return false; return false;
} }
std::vector<std::string> MarkovChain::split_string(std::string message) { std::vector<std::string> MarkovChain::split_string(std::string message) {
std::istringstream ss(message); std::istringstream ss(message);
@ -43,11 +42,12 @@ std::vector<std::string> MarkovChain::split_string(std::string message) {
return words; return words;
} }
bool MarkovChain::add_ngrams(std::string message) { std::vector<std::vector<std::string>> MarkovChain::get_words(std::string message) {
this->process_string(message); this->process_string(message);
std::vector<std::string> words (split_string(message)); std::vector<std::string> words (split_string(message));
std::vector<std::vector<std::string>> import;
for (int i = 0; i < words.size(); i++) { for (int i = 0; i < words.size(); i++) {
if (i >= order) { if (i >= order) {
@ -58,11 +58,11 @@ bool MarkovChain::add_ngrams(std::string message) {
add_words.push_back(words[j]); add_words.push_back(words[j]);
} }
markov.add_ngrams(add_words); import.push_back(add_words);
} }
} }
return false; return import;
} }
@ -88,6 +88,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) {
std::cout << "Using seed: "; std::cout << "Using seed: ";
for (auto x: newwords) for (auto x: newwords)
std::cout << " " << x; std::cout << " " << x;
std::cout << "\n"; std::cout << "\n";
std::string ret = markov.get_continuation(newwords, score); std::string ret = markov.get_continuation(newwords, score);
@ -95,6 +96,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) {
return ret; return ret;
} }
std::string TelegramBotM::get_token() { std::string TelegramBotM::get_token() {
using namespace std; using namespace std;
fstream ftoken; fstream ftoken;
@ -117,8 +119,9 @@ void TelegramBotM::add_echo() {
try { try {
markov.add_ngrams(message_text); markov.add_ngrams(message_text);
} catch (std::exception &e) { } catch (char const * e) {
std::cout << "Error adding ngram: " << e.what() << std::endl; std::cout << "Error adding ngram: " << e << std::endl;
return;
} }
//if (replies && rand() % replies != 1) { //if (replies && rand() % replies != 1) {
@ -129,7 +132,6 @@ void TelegramBotM::add_echo() {
std::string reply; std::string reply;
reply = markov.continue_message(message_text, score); reply = markov.continue_message(message_text, score);
std::cout << "Reply generated: " << reply << std::endl;
if (reply == "") { if (reply == "") {
return; return;
@ -148,24 +150,56 @@ void TelegramBotM::add_echo() {
}); });
} }
int import_from_file(std::string filename, MarkovChain &m) {
bool MarkovChain::add_ngrams(std::string message) {
process_string(message);
std::vector<std::vector<std::string>> add_words(get_words(message));
if (markov.batch_add_ngrams(add_words))
throw "Failed to add ngrams";
return 0;
}
int MarkovChain::import_from_file(std::string filename) {
std::ifstream infile(filename); std::ifstream infile(filename);
int i = 0; int i = 0;
std::string line; std::string line;
std::vector<std::vector<std::string>> imports(0);
while(std::getline(infile, line)) { while(std::getline(infile, line)) {
m.add_ngrams(line);
std::cout << "\radded " << ++i << " lines."; this->process_string(line);
std::vector<std::vector<std::string>> add_words(get_words(line));
imports.insert(imports.end(),
std::make_move_iterator(add_words.begin()),
std::make_move_iterator(add_words.end()));
i++;
if (i % 21 == 0) {
std::cout << "\rread " << i << " lines." << std::flush;
}
} }
std::cout << "\nImporting...\n" << std::flush;
if (markov.batch_add_ngrams(imports))
return 1;
std::cout << "imported " << imports.size() << " ngrams.\n" << std::flush;
return 0; return 0;
} }
struct options { struct options {
std::string filename = "markov.sqlite"; std::string filename = "markov.sqlite";
std::string tablename = "markov"; std::string tablename = "markov";
int order = 2; int order = 1;
std::string importname = ""; std::string importname = "";
bool run_bot = true; bool run_bot = true;
}; };
@ -203,7 +237,6 @@ struct options parse_args(int argc, char ** argv) {
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
struct options opts = parse_args(argc, argv); struct options opts = parse_args(argc, argv);
SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename); SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename);
@ -223,8 +256,7 @@ int main(int argc, char **argv) {
while(std::getline(std::cin, line)) { while(std::getline(std::cin, line)) {
if (line.size() > 0) { if (line.size() > 0) {
if (std::tolower(line[0]) == 'y') { if (std::tolower(line[0]) == 'y') {
std::cout << "Importing...\n"; if (m.import_from_file(opts.importname)) {
if (import_from_file(opts.importname, m)) {
std::cerr << "Failed import." << std::endl; std::cerr << "Failed import." << std::endl;
} }
break; break;

7
bot.h

@ -15,6 +15,7 @@ class MarkovDB {
virtual bool add_ngrams(std::vector<std::string> words) = 0; virtual bool add_ngrams(std::vector<std::string> words) = 0;
virtual std::string get_continuation(std::vector<std::string> prompt, int &score) = 0; virtual std::string get_continuation(std::vector<std::string> prompt, int &score) = 0;
virtual ~MarkovDB() {} virtual ~MarkovDB() {}
virtual bool batch_add_ngrams(std::vector<std::vector<std::string>> batch) = 0;
}; };
class MarkovChain { class MarkovChain {
@ -36,10 +37,15 @@ class MarkovChain {
MarkovChain (int ord, bool lower, MarkovDB &db) MarkovChain (int ord, bool lower, MarkovDB &db)
: markov(db), MAKE_LOWER(lower), order(ord) {; : markov(db), MAKE_LOWER(lower), order(ord) {;
}; };
int import_from_file(std::string filename);
std::vector<std::vector<std::string>> get_words(std::string message);
virtual ~MarkovChain() {}; virtual ~MarkovChain() {};
}; };
class Bot { class Bot {
protected: protected:
MarkovChain markov; MarkovChain markov;
@ -52,6 +58,7 @@ class Bot {
~Bot() {}; ~Bot() {};
}; };
class TelegramBotM: public Bot { class TelegramBotM: public Bot {
protected: protected:
std::string api_key = ""; std::string api_key = "";

51
sqlite_markov.h

@ -1,6 +1,7 @@
#include "SQLiteCpp/Database.h" #include "SQLiteCpp/Database.h"
#include <SQLiteCpp/SQLiteCpp.h> #include <SQLiteCpp/SQLiteCpp.h>
#include <memory> #include <memory>
#include "SQLiteCpp/Transaction.h"
#include "bot.h" #include "bot.h"
class SQLiteMarkov: public MarkovDB { class SQLiteMarkov: public MarkovDB {
@ -13,8 +14,7 @@ class SQLiteMarkov: public MarkovDB {
const std::string TABLE_NAME = "markov"; const std::string TABLE_NAME = "markov";
const std::string INSERT_STATEMENT = get_insert_statement_string(); const std::string INSERT_STATEMENT = get_insert_statement_string();
const std::string UPDATE_STATEMENT = get_update_statement_string(); const std::string UPDATE_STATEMENT = get_update_statement_string(); const std::string QUERY_STATEMENT = get_select_statement_string();
const std::string QUERY_STATEMENT = get_select_statement_string();
std::shared_ptr<SQLite::Database> data; std::shared_ptr<SQLite::Database> data;
@ -24,6 +24,7 @@ class SQLiteMarkov: public MarkovDB {
+ fp; + fp;
} }
std::unique_ptr<SQLite::Database> open_db() { std::unique_ptr<SQLite::Database> open_db() {
try { try {
return std::unique_ptr<SQLite::Database>(new return std::unique_ptr<SQLite::Database>(new
@ -72,8 +73,7 @@ class SQLiteMarkov: public MarkovDB {
} }
ins += " )\n);"; ins += " )\n);";
// run create table instruction SQLite::Transaction transaction (*data);
SQLite::Transaction transaction(*data);
data->exec(ins); data->exec(ins);
transaction.commit(); transaction.commit();
@ -228,7 +228,6 @@ class SQLiteMarkov: public MarkovDB {
} }
std::cout << "No matches found" << std::endl; std::cout << "No matches found" << std::endl;
throw "No next word found";
return ""; return "";
} }
@ -238,7 +237,6 @@ class SQLiteMarkov: public MarkovDB {
} }
public: public:
SQLiteMarkov(int ord, std::string dbname, std::string table): SQLiteMarkov(int ord, std::string dbname, std::string table):
order(ord), order(ord),
db_filename(dbname), db_filename(dbname),
@ -263,7 +261,10 @@ class SQLiteMarkov: public MarkovDB {
Initialise_Data(); Initialise_Data();
} }
virtual ~SQLiteMarkov() {};
bool add_ngrams(std::vector<std::string> words) override { bool add_ngrams(std::vector<std::string> words) override {
// assertions? // assertions?
if (words.size() > order + 1) { if (words.size() > order + 1) {
return true; return true;
@ -278,15 +279,36 @@ class SQLiteMarkov: public MarkovDB {
return add_to_db(words); return add_to_db(words);
} }
bool batch_add_ngrams(std::vector<std::vector<std::string>> batch)
override {
SQLite::Transaction transaction(*data);
for (auto ngram: batch) {
add_ngrams(ngram);
}
transaction.commit();
return 0;
}
std::string get_continuation(std::vector<std::string> prompt, int &score) std::string get_continuation(std::vector<std::string> prompt, int &score)
override { override {
std::string new_words = ""; std::stringstream ss;
for (int i=0; i < prompt.size(); i++) {
ss << " " << prompt[i];
}
std::string new_words = ss.str();
std::vector<std::string> words(prompt); std::vector<std::string> words(prompt);
int count = 0; int count = 0;
score = 0; score = 0;
do { do {
std::string next_word = "invalid"; std::string next_word = "throwaway";
try { try {
int total; int total;
next_word = get_next_word(words, total); next_word = get_next_word(words, total);
@ -297,14 +319,12 @@ class SQLiteMarkov: public MarkovDB {
break; break;
} }
if (next_word == "MESSAGE_END") {
break;
}
new_words += " " + next_word; new_words += " " + next_word;
std::regex end_punc("[\\.\\?\\!]"); std::regex end_punc("[\\.\\?\\!]");
if (std::regex_search(std::to_string(next_word.back()), end_punc)) { if (next_word.size() == 0
|| std::regex_search(std::to_string(next_word.back()),
end_punc)) {
break; break;
} }
@ -314,6 +334,9 @@ class SQLiteMarkov: public MarkovDB {
} while (count < MAX_REPLY_LENGTH); } while (count < MAX_REPLY_LENGTH);
return new_words; if (count > 0)
return new_words;
else
return "";
} }
}; };

Loading…
Cancel
Save