Browse Source

transaction import

master
alistair 4 years ago
parent
commit
27e9baadfb
  1. 62
      bot.cpp
  2. 7
      bot.h
  3. 51
      sqlite_markov.h

62
bot.cpp

@ -27,11 +27,10 @@ bool MarkovChain::process_string(std::string& message) { @@ -27,11 +27,10 @@ bool MarkovChain::process_string(std::string& message) {
}
}
message += " MESSAGE_END";
return false;
}
std::vector<std::string> MarkovChain::split_string(std::string message) {
std::istringstream ss(message);
@ -43,11 +42,12 @@ std::vector<std::string> MarkovChain::split_string(std::string message) { @@ -43,11 +42,12 @@ std::vector<std::string> MarkovChain::split_string(std::string message) {
return words;
}
bool MarkovChain::add_ngrams(std::string message) {
std::vector<std::vector<std::string>> MarkovChain::get_words(std::string message) {
this->process_string(message);
std::vector<std::string> words (split_string(message));
std::vector<std::vector<std::string>> import;
for (int i = 0; i < words.size(); i++) {
if (i >= order) {
@ -58,11 +58,11 @@ bool MarkovChain::add_ngrams(std::string message) { @@ -58,11 +58,11 @@ bool MarkovChain::add_ngrams(std::string message) {
add_words.push_back(words[j]);
}
markov.add_ngrams(add_words);
import.push_back(add_words);
}
}
return false;
return import;
}
@ -88,6 +88,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) { @@ -88,6 +88,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) {
std::cout << "Using seed: ";
for (auto x: newwords)
std::cout << " " << x;
std::cout << "\n";
std::string ret = markov.get_continuation(newwords, score);
@ -95,6 +96,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) { @@ -95,6 +96,7 @@ std::string MarkovChain::continue_message(std::string message, int &score) {
return ret;
}
std::string TelegramBotM::get_token() {
using namespace std;
fstream ftoken;
@ -117,8 +119,9 @@ void TelegramBotM::add_echo() { @@ -117,8 +119,9 @@ void TelegramBotM::add_echo() {
try {
markov.add_ngrams(message_text);
} catch (std::exception &e) {
std::cout << "Error adding ngram: " << e.what() << std::endl;
} catch (char const * e) {
std::cout << "Error adding ngram: " << e << std::endl;
return;
}
//if (replies && rand() % replies != 1) {
@ -129,7 +132,6 @@ void TelegramBotM::add_echo() { @@ -129,7 +132,6 @@ void TelegramBotM::add_echo() {
std::string reply;
reply = markov.continue_message(message_text, score);
std::cout << "Reply generated: " << reply << std::endl;
if (reply == "") {
return;
@ -148,24 +150,56 @@ void TelegramBotM::add_echo() { @@ -148,24 +150,56 @@ void TelegramBotM::add_echo() {
});
}
int import_from_file(std::string filename, MarkovChain &m) {
bool MarkovChain::add_ngrams(std::string message) {
process_string(message);
std::vector<std::vector<std::string>> add_words(get_words(message));
if (markov.batch_add_ngrams(add_words))
throw "Failed to add ngrams";
return 0;
}
/**
 * Read a text file line by line, turn every line into n-grams and insert
 * all of them into the database with a single batch call.
 *
 * Progress is printed to stdout every 21 lines read.
 *
 * @param filename path of the text file to import
 * @return 0 on success, 1 if the file cannot be opened or the batch
 *         insert reports failure
 */
int MarkovChain::import_from_file(std::string filename) {
    std::ifstream infile(filename);
    if (!infile) {
        // Without this check an unreadable file would be reported as a
        // successful import of zero n-grams.
        std::cerr << "Could not open import file: " << filename << std::endl;
        return 1;
    }

    int lines_read = 0;
    std::string line;
    std::vector<std::vector<std::string>> imports;
    while (std::getline(infile, line)) {
        // get_words() applies process_string() to its own copy of the
        // line, so the line is passed through unprocessed; processing it
        // here too would append the end marker twice.
        std::vector<std::vector<std::string>> add_words(get_words(line));
        imports.insert(imports.end(),
                       std::make_move_iterator(add_words.begin()),
                       std::make_move_iterator(add_words.end()));
        ++lines_read;
        if (lines_read % 21 == 0) {
            std::cout << "\rread " << lines_read << " lines." << std::flush;
        }
    }

    std::cout << "\nImporting...\n" << std::flush;
    // A true result from batch_add_ngrams() signals failure.
    if (markov.batch_add_ngrams(imports))
        return 1;
    std::cout << "imported " << imports.size() << " ngrams.\n" << std::flush;
    return 0;
}
/**
 * Program settings with their default values.
 */
struct options {
    // Database file name handed to the SQLiteMarkov constructor.
    std::string filename = "markov.sqlite";
    // Table name handed to the SQLiteMarkov constructor.
    std::string tablename = "markov";
    // Markov chain order. Declared exactly once: a second declaration of
    // the same member is a redeclaration error.
    int order = 1;
    // File passed to import_from_file(); empty by default.
    std::string importname = "";
    // NOTE(review): presumably controls whether the bot loop is started
    // after argument handling - confirm against main().
    bool run_bot = true;
};
@ -203,7 +237,6 @@ struct options parse_args(int argc, char ** argv) { @@ -203,7 +237,6 @@ struct options parse_args(int argc, char ** argv) {
}
int main(int argc, char **argv) {
struct options opts = parse_args(argc, argv);
SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename);
@ -223,8 +256,7 @@ int main(int argc, char **argv) { @@ -223,8 +256,7 @@ int main(int argc, char **argv) {
while(std::getline(std::cin, line)) {
if (line.size() > 0) {
if (std::tolower(line[0]) == 'y') {
std::cout << "Importing...\n";
if (import_from_file(opts.importname, m)) {
if (m.import_from_file(opts.importname)) {
std::cerr << "Failed import." << std::endl;
}
break;

7
bot.h

@ -15,6 +15,7 @@ class MarkovDB { @@ -15,6 +15,7 @@ class MarkovDB {
virtual bool add_ngrams(std::vector<std::string> words) = 0;
virtual std::string get_continuation(std::vector<std::string> prompt, int &score) = 0;
virtual ~MarkovDB() {}
virtual bool batch_add_ngrams(std::vector<std::vector<std::string>> batch) = 0;
};
class MarkovChain {
@ -36,10 +37,15 @@ class MarkovChain { @@ -36,10 +37,15 @@ class MarkovChain {
MarkovChain (int ord, bool lower, MarkovDB &db)
: markov(db), MAKE_LOWER(lower), order(ord) {;
};
int import_from_file(std::string filename);
std::vector<std::vector<std::string>> get_words(std::string message);
virtual ~MarkovChain() {};
};
class Bot {
protected:
MarkovChain markov;
@ -52,6 +58,7 @@ class Bot { @@ -52,6 +58,7 @@ class Bot {
~Bot() {};
};
class TelegramBotM: public Bot {
protected:
std::string api_key = "";

51
sqlite_markov.h

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
#include "SQLiteCpp/Database.h"
#include <SQLiteCpp/SQLiteCpp.h>
#include <memory>
#include "SQLiteCpp/Transaction.h"
#include "bot.h"
class SQLiteMarkov: public MarkovDB {
@ -13,8 +14,7 @@ class SQLiteMarkov: public MarkovDB { @@ -13,8 +14,7 @@ class SQLiteMarkov: public MarkovDB {
const std::string TABLE_NAME = "markov";
const std::string INSERT_STATEMENT = get_insert_statement_string();
const std::string UPDATE_STATEMENT = get_update_statement_string();
const std::string QUERY_STATEMENT = get_select_statement_string();
const std::string UPDATE_STATEMENT = get_update_statement_string(); const std::string QUERY_STATEMENT = get_select_statement_string();
std::shared_ptr<SQLite::Database> data;
@ -24,6 +24,7 @@ class SQLiteMarkov: public MarkovDB { @@ -24,6 +24,7 @@ class SQLiteMarkov: public MarkovDB {
+ fp;
}
std::unique_ptr<SQLite::Database> open_db() {
try {
return std::unique_ptr<SQLite::Database>(new
@ -72,8 +73,7 @@ class SQLiteMarkov: public MarkovDB { @@ -72,8 +73,7 @@ class SQLiteMarkov: public MarkovDB {
}
ins += " )\n);";
// run create table instruction
SQLite::Transaction transaction(*data);
SQLite::Transaction transaction (*data);
data->exec(ins);
transaction.commit();
@ -228,7 +228,6 @@ class SQLiteMarkov: public MarkovDB { @@ -228,7 +228,6 @@ class SQLiteMarkov: public MarkovDB {
}
std::cout << "No matches found" << std::endl;
throw "No next word found";
return "";
}
@ -238,7 +237,6 @@ class SQLiteMarkov: public MarkovDB { @@ -238,7 +237,6 @@ class SQLiteMarkov: public MarkovDB {
}
public:
SQLiteMarkov(int ord, std::string dbname, std::string table):
order(ord),
db_filename(dbname),
@ -263,7 +261,10 @@ class SQLiteMarkov: public MarkovDB { @@ -263,7 +261,10 @@ class SQLiteMarkov: public MarkovDB {
Initialise_Data();
}
// Nothing to release by hand: members clean up after themselves.
~SQLiteMarkov() override = default;
bool add_ngrams(std::vector<std::string> words) override {
// assertions?
if (words.size() > order + 1) {
return true;
@ -278,15 +279,36 @@ class SQLiteMarkov: public MarkovDB { @@ -278,15 +279,36 @@ class SQLiteMarkov: public MarkovDB {
return add_to_db(words);
}
/**
 * Insert a batch of n-grams inside one SQLite transaction.
 *
 * The batch is all-or-nothing: the transaction is committed only when
 * every n-gram was added. On the first failing n-gram the function
 * returns without committing, and the Transaction object rolls the
 * pending changes back when it goes out of scope.
 *
 * @param batch n-grams to insert, each one a list of words
 * @return false on success, true if any n-gram could not be added
 *         (same convention as add_ngrams)
 */
bool batch_add_ngrams(std::vector<std::vector<std::string>> batch)
override {
    SQLite::Transaction transaction(*data);
    // Iterate by const reference to avoid an extra copy of every n-gram.
    for (const auto &ngram : batch) {
        if (add_ngrams(ngram)) {
            // Report the failure instead of silently committing a
            // partial batch; callers test this return value.
            return true;
        }
    }
    transaction.commit();
    return false;
}
std::string get_continuation(std::vector<std::string> prompt, int &score)
override {
std::string new_words = "";
std::stringstream ss;
for (int i=0; i < prompt.size(); i++) {
ss << " " << prompt[i];
}
std::string new_words = ss.str();
std::vector<std::string> words(prompt);
int count = 0;
score = 0;
do {
std::string next_word = "invalid";
std::string next_word = "throwaway";
try {
int total;
next_word = get_next_word(words, total);
@ -297,14 +319,12 @@ class SQLiteMarkov: public MarkovDB { @@ -297,14 +319,12 @@ class SQLiteMarkov: public MarkovDB {
break;
}
if (next_word == "MESSAGE_END") {
break;
}
new_words += " " + next_word;
std::regex end_punc("[\\.\\?\\!]");
if (std::regex_search(std::to_string(next_word.back()), end_punc)) {
if (next_word.size() == 0
|| std::regex_search(std::to_string(next_word.back()),
end_punc)) {
break;
}
@ -314,6 +334,9 @@ class SQLiteMarkov: public MarkovDB { @@ -314,6 +334,9 @@ class SQLiteMarkov: public MarkovDB {
} while (count < MAX_REPLY_LENGTH);
return new_words;
if (count > 0)
return new_words;
else
return "";
}
};

Loading…
Cancel
Save