|
|
|
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <SQLiteCpp/SQLiteCpp.h>
#include "SQLiteCpp/Database.h"
#include "SQLiteCpp/Transaction.h"

#include "bot.h"
|
|
|
|
|
|
|
|
class SQLiteMarkov: public MarkovDB {
|
|
|
|
protected:
|
|
|
|
const std::string db_filename = get_filepath("markov.sqlite");
|
|
|
|
const int WORD_SIZE = 400; // chars
|
|
|
|
const int MAX_REPLY_LENGTH = 500; // words
|
|
|
|
int order = 2;
|
|
|
|
|
|
|
|
const std::string TABLE_NAME = "markov";
|
|
|
|
|
|
|
|
const std::string INSERT_STATEMENT = get_insert_statement_string();
|
|
|
|
const std::string UPDATE_STATEMENT = get_update_statement_string(); const std::string QUERY_STATEMENT = get_select_statement_string();
|
|
|
|
|
|
|
|
std::shared_ptr<SQLite::Database> data;
|
|
|
|
|
|
|
|
std::string get_filepath(std::string fp) {
|
|
|
|
std::string filePath(__FILE__);
|
|
|
|
return filePath.substr( 0, filePath.length() - std::string("bot.cpp").length())
|
|
|
|
+ fp;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SQLite::Database> open_db() {
|
|
|
|
try {
|
|
|
|
return std::unique_ptr<SQLite::Database>(new
|
|
|
|
SQLite::Database (db_filename, SQLite::OPEN_READWRITE));
|
|
|
|
} catch (std::exception& e)
|
|
|
|
{
|
|
|
|
std::cout << "exception: " << e.what() << std::endl << std::flush;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
|
|
|
|
|
|
|
|
return std::unique_ptr<SQLite::Database>(new SQLite::Database
|
|
|
|
(db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int Initialise_Data() {
|
|
|
|
SQLite::Statement query(*data, "SELECT name from sqlite_master WHERE name = '"
|
|
|
|
+ TABLE_NAME + "'");
|
|
|
|
|
|
|
|
try
|
|
|
|
{
|
|
|
|
if (query.executeStep()) {
|
|
|
|
std::cout << "database ready." << std::endl;
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
catch (std::exception& e)
|
|
|
|
{
|
|
|
|
std::cout << "exception: " << e.what() << std::endl;
|
|
|
|
exit(1);
|
|
|
|
}
|
|
|
|
|
|
|
|
// create table sql statement construction
|
|
|
|
std::string ins = "CREATE TABLE " + TABLE_NAME + " (\n";
|
|
|
|
for (int i = 0; i < order + 1; i ++) {
|
|
|
|
ins += "word_" + std::to_string(i) + " VARCHAR(" +
|
|
|
|
std::to_string(WORD_SIZE) + ") NOT NULL,\n";
|
|
|
|
}
|
|
|
|
ins += "count INTEGER NOT NULL, \n PRIMARY KEY (";
|
|
|
|
for (int i = 0; i < order + 1; i ++) {
|
|
|
|
ins += "word_" + std::to_string(i);
|
|
|
|
if (i != order) {
|
|
|
|
ins += ",";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ins += " )\n);";
|
|
|
|
|
|
|
|
SQLite::Transaction transaction (*data);
|
|
|
|
data->exec(ins);
|
|
|
|
transaction.commit();
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_update_statement_string() {
|
|
|
|
std::string update_template = "UPDATE " + TABLE_NAME + " SET count = ? WHERE\n";
|
|
|
|
|
|
|
|
for (int i = 0; i < order + 1; i++) {
|
|
|
|
update_template += "word_" + std::to_string(i);
|
|
|
|
update_template += " = ? ";
|
|
|
|
if (i != order) {
|
|
|
|
update_template += "AND ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return update_template;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_select_statement_string() {
|
|
|
|
std::string update_template = "SELECT * FROM " + TABLE_NAME + " WHERE " ;
|
|
|
|
|
|
|
|
for (int i = 0; i < order; i++) {
|
|
|
|
update_template += "word_" + std::to_string(i);
|
|
|
|
update_template += " = ? ";
|
|
|
|
if (i != order - 1) {
|
|
|
|
update_template += "AND ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
update_template += ";";
|
|
|
|
|
|
|
|
return update_template;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_insert_statement_string() {
|
|
|
|
std::string update_template = "INSERT INTO " + TABLE_NAME +
|
|
|
|
" VALUES (" ;
|
|
|
|
|
|
|
|
for (int i = 0; i < order + 1; i++) {
|
|
|
|
update_template += "?" ;
|
|
|
|
if (i != order) {
|
|
|
|
update_template += ", ";
|
|
|
|
} else {
|
|
|
|
update_template += ", ?)";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return update_template;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool update_count(std::vector<std::string> words, int count) {
|
|
|
|
std::string update_template = UPDATE_STATEMENT;
|
|
|
|
|
|
|
|
SQLite::Statement increment(*data, update_template);
|
|
|
|
|
|
|
|
increment.bind(1, count);
|
|
|
|
|
|
|
|
for (int i = 0; i < order + 1; i++) {
|
|
|
|
increment.bind(i + 2, words[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
return increment.executeStep();
|
|
|
|
}
|
|
|
|
|
|
|
|
bool insert_record(std::vector<std::string> words, int count) {
|
|
|
|
|
|
|
|
std::string update_template = INSERT_STATEMENT;
|
|
|
|
|
|
|
|
SQLite::Statement increment(*data, update_template);
|
|
|
|
for (int i = 0; i < order + 1; i++) {
|
|
|
|
increment.bind(i + 1, words[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
increment.bind(order + 2, count);
|
|
|
|
|
|
|
|
return increment.executeStep();
|
|
|
|
}
|
|
|
|
|
|
|
|
bool add_to_db(std::vector<std::string> words) {
|
|
|
|
|
|
|
|
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
|
|
|
|
|
|
|
|
std::string query_template = "SELECT count FROM " + TABLE_NAME
|
|
|
|
+ " WHERE ";
|
|
|
|
|
|
|
|
for (int i = 0; i <= order; i++) {
|
|
|
|
query_template += "word_" + std::to_string(i);
|
|
|
|
query_template += " = ? ";
|
|
|
|
if (i != order) {
|
|
|
|
query_template += "AND ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
SQLite::Statement update(*data, query_template);
|
|
|
|
|
|
|
|
for (int i = 1; i < order + 2; i++) {
|
|
|
|
update.bind(i, words[i-1]);
|
|
|
|
}
|
|
|
|
int count = 0;
|
|
|
|
if (update.executeStep()) {
|
|
|
|
count = update.getColumn(0);
|
|
|
|
count += 1;
|
|
|
|
update_count(words, count);
|
|
|
|
} else {
|
|
|
|
count = 1;
|
|
|
|
insert_record(words, count);
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_next_word(std::vector<std::string> words, int &score) {
|
|
|
|
if (words.size() != order) {
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
|
|
|
|
SQLite::Statement query(*data, QUERY_STATEMENT);
|
|
|
|
|
|
|
|
for (int i = 0; i < order; i++) {
|
|
|
|
query.bind(i+1, words[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
int total = 0;
|
|
|
|
while (query.executeStep()) {
|
|
|
|
int count = query.getColumn(order + 1);
|
|
|
|
total += count;
|
|
|
|
}
|
|
|
|
|
|
|
|
total += 1;
|
|
|
|
int threshold = rand() % total;
|
|
|
|
|
|
|
|
// count up to threshold
|
|
|
|
// pretty sure the order doesn't matter statistically?
|
|
|
|
// Obviously it will have an impact but I think the distribution should
|
|
|
|
// be the same whether the list is ordered by frequency or not
|
|
|
|
|
|
|
|
query.reset();
|
|
|
|
total = 0;
|
|
|
|
|
|
|
|
std::vector<std::string> message;
|
|
|
|
|
|
|
|
while (query.executeStep()) {
|
|
|
|
int count = query.getColumn(order + 1);
|
|
|
|
total += count;
|
|
|
|
if (total >= threshold) {
|
|
|
|
std::string next_word = query.getColumn(order);
|
|
|
|
|
|
|
|
score = total;
|
|
|
|
return next_word;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
std::cout << "No matches found" << std::endl;
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_next_word(std::vector<std::string> words) {
|
|
|
|
int count;
|
|
|
|
return get_next_word(words, count);
|
|
|
|
}
|
|
|
|
|
|
|
|
public:
|
|
|
|
SQLiteMarkov(int ord, std::string dbname, std::string table):
|
|
|
|
order(ord),
|
|
|
|
db_filename(dbname),
|
|
|
|
data(open_db()),
|
|
|
|
TABLE_NAME(table)
|
|
|
|
{
|
|
|
|
Initialise_Data();
|
|
|
|
}
|
|
|
|
|
|
|
|
SQLiteMarkov(int ord): order(ord), data(open_db())
|
|
|
|
{
|
|
|
|
Initialise_Data();
|
|
|
|
}
|
|
|
|
|
|
|
|
SQLiteMarkov(): data(open_db())
|
|
|
|
{
|
|
|
|
Initialise_Data();}
|
|
|
|
|
|
|
|
SQLiteMarkov(int order, std::string table_name, SQLite::Database *db):
|
|
|
|
data(db), order(order), TABLE_NAME(table_name)
|
|
|
|
{
|
|
|
|
Initialise_Data();
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual ~SQLiteMarkov() {};
|
|
|
|
|
|
|
|
bool add_ngrams(std::vector<std::string> words) override {
|
|
|
|
|
|
|
|
// assertions?
|
|
|
|
if (words.size() > order + 1) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int i = 0; i < order + 1; i++) {
|
|
|
|
if (words[i].length() > SQLiteMarkov::WORD_SIZE) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return add_to_db(words);
|
|
|
|
}
|
|
|
|
|
|
|
|
bool batch_add_ngrams(std::vector<std::vector<std::string>> batch)
|
|
|
|
override {
|
|
|
|
|
|
|
|
SQLite::Transaction transaction(*data);
|
|
|
|
|
|
|
|
for (auto ngram: batch) {
|
|
|
|
add_ngrams(ngram);
|
|
|
|
}
|
|
|
|
|
|
|
|
transaction.commit();
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string get_continuation(std::vector<std::string> prompt, int &score)
|
|
|
|
override {
|
|
|
|
std::stringstream ss;
|
|
|
|
for (int i=0; i < prompt.size(); i++) {
|
|
|
|
ss << " " << prompt[i];
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string new_words = ss.str();
|
|
|
|
std::vector<std::string> words(prompt);
|
|
|
|
|
|
|
|
int count = 0;
|
|
|
|
score = 0;
|
|
|
|
|
|
|
|
do {
|
|
|
|
std::string next_word = "throwaway";
|
|
|
|
|
|
|
|
try {
|
|
|
|
int total;
|
|
|
|
next_word = get_next_word(words, total);
|
|
|
|
score += total;
|
|
|
|
} catch (char * &e) {
|
|
|
|
std::cout << "Exception: " << e << std::endl;
|
|
|
|
std::cout << "Failed Get Next Word" << std::endl;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
new_words += " " + next_word;
|
|
|
|
std::regex end_punc("[\\.\\?\\!]");
|
|
|
|
|
|
|
|
if (next_word.size() == 0
|
|
|
|
|| std::regex_search(std::to_string(next_word.back()),
|
|
|
|
end_punc)) {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
count++;
|
|
|
|
words.push_back(next_word);
|
|
|
|
words.erase(words.begin());
|
|
|
|
|
|
|
|
} while (count < MAX_REPLY_LENGTH);
|
|
|
|
|
|
|
|
if (count > 0)
|
|
|
|
return new_words;
|
|
|
|
else
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
};
|