Browse Source

fixed

master
alistair 4 years ago
parent
commit
25fd993d0a
  1. 168
      bot.cpp
  2. 128
      sqlite_markov.h

168
bot.cpp

@ -5,29 +5,19 @@ @@ -5,29 +5,19 @@
#include <algorithm>
#include <iterator>
#include <regex>
#include <iostream>
#include <string>
#include <unistd.h>
#include "bot.h"
#include "sqlite_markov.h"
std::string TelegramBotM::get_token() {
using namespace std;
fstream ftoken;
ftoken.open("apikey.txt", ios::in);
if (!ftoken) {
perror("No apikey.txt file provided");
exit(1);
}
string apikey;
ftoken >> apikey;
return apikey;
}
#include <cstdlib>
bool MarkovChain::test() {
add_ngrams("so long and thanks\nfor\tall the fishes bro man");
return false;
}
bool MarkovChain::process_string(std::string& message) {
if (this->MAKE_LOWER) {
for (int i = 0; i < message.length(); i++) {
@ -58,7 +48,6 @@ bool MarkovChain::add_ngrams(std::string message) { @@ -58,7 +48,6 @@ bool MarkovChain::add_ngrams(std::string message) {
std::vector<std::string> words (split_string(message));
std::cout << message << std::endl;
for (int i = 0; i < words.size(); i++) {
if (i >= order) {
@ -67,8 +56,8 @@ bool MarkovChain::add_ngrams(std::string message) { @@ -67,8 +56,8 @@ bool MarkovChain::add_ngrams(std::string message) {
int count = 0;
for (int j = i - order; j <= i; j++, count++) {
add_words.push_back(words[j]);
std::cout << j << ": " << add_words[count] << std::endl;
}
markov.add_ngrams(add_words);
}
}
@ -76,13 +65,49 @@ bool MarkovChain::add_ngrams(std::string message) { @@ -76,13 +65,49 @@ bool MarkovChain::add_ngrams(std::string message) {
return false;
}
std::string MarkovChain::continue_message(std::string message, int &score) {
std::vector<std::string> words {split_string(message)};
words.erase(words.begin(), words.begin() + words.size() - order);
std::string ret = markov.get_continuation(words, score);
std::cout << 1 << std::endl;
std::vector<std::string> words = split_string(message);
if (words.size() < order) {
std::cout << "Message too short.\n";
return "";
}
for (int i = 0; i < words.size(); i++) {
std::cout << words[i] << std::endl;
}
std::vector<std::string>::const_iterator first = words.begin()
+ (words.size() - order);
std::vector<std::string>::const_iterator last = words.end();
std::vector<std::string> newwords(first, last);
std::cout << "Using seed: ";
for (auto x: newwords)
std::cout << " " << x;
std::cout << "\n";
std::string ret = markov.get_continuation(newwords, score);
return ret;
}
std::string TelegramBotM::get_token() {
using namespace std;
fstream ftoken;
ftoken.open("apikey.txt", ios::in);
if (!ftoken) {
perror("No apikey.txt file provided");
exit(1);
}
string apikey;
ftoken >> apikey;
return apikey;
}
void TelegramBotM::add_echo() {
@ -102,11 +127,12 @@ void TelegramBotM::add_echo() { @@ -102,11 +127,12 @@ void TelegramBotM::add_echo() {
int score = 0;
std::string reply;
try {
// reply = markov.continue_message(message_text, score);
reply = message_text;
} catch (std::exception &e) {
std::cout << "Error getting message" << std::endl;
reply = markov.continue_message(message_text, score);
std::cout << "Reply generated: " << reply << std::endl;
if (reply == "") {
return;
}
std::cerr << "SCORE: " << score << std::endl;
@ -122,14 +148,92 @@ void TelegramBotM::add_echo() { @@ -122,14 +148,92 @@ void TelegramBotM::add_echo() {
});
}
int import_from_file(std::string filename, MarkovChain &m) {
std::ifstream infile(filename);
int i = 0;
std::string line;
while(std::getline(infile, line)) {
m.add_ngrams(line);
std::cout << "\radded " << ++i << " lines.";
}
return 0;
}
struct options {
std::string filename = "markov.sqlite";
std::string tablename = "markov";
int order = 2;
std::string importname = "";
bool run_bot = true;
};
struct options parse_args(int argc, char ** argv) {
struct options options;
int c;
while ((c = getopt(argc, argv, "d:n:o:i:rs")) != EOF) {
switch (c) {
case 'd':
options.filename = optarg;
break;
case 'n':
options.tablename = optarg;
break;
case 'o':
options.order = std::stoi(optarg);
break;
case 'i':
options.importname = optarg;
options.run_bot = false;
break;
case 'r':
options.run_bot = true;
break;
case 's':
options.run_bot = false;
break;
}
}
int main() {
int order = 1;
SQLiteMarkov db = SQLiteMarkov(order, "markov.sqlite", "markov");
MarkovChain m = MarkovChain(order, true, db);
TelegramBotM b = TelegramBotM(order, m, "markov");
std::cout << "Running...\n";
b.run();
return options;
}
int main(int argc, char **argv) {
struct options opts = parse_args(argc, argv);
SQLiteMarkov db = SQLiteMarkov(opts.order, opts.filename, opts.tablename);
MarkovChain m = MarkovChain(opts.order, true, db);
if (opts.run_bot) {
TelegramBotM b = TelegramBotM(opts.order, m, opts.tablename);
std::cout << "Running...\n";
b.run();
} else if (opts.importname != "") {
std::cout << "Are you sure you want to import from " << opts.importname
<< " to the database " << opts.filename << ", table "
<< opts.tablename << "?" << std::endl << " [Y/n] >";
std::string line;
while(std::getline(std::cin, line)) {
if (line.size() > 0) {
if (std::tolower(line[0]) == 'y') {
std::cout << "Importing...\n";
if (import_from_file(opts.importname, m)) {
std::cerr << "Failed import." << std::endl;
}
break;
} else if (std::tolower(line[0]) == 'n') {
break;
}
}
}
}
return 0;
}

128
sqlite_markov.h

@ -1,18 +1,22 @@ @@ -1,18 +1,22 @@
#include "bot.h"
#include "SQLiteCpp/Database.h"
#include <SQLiteCpp/SQLiteCpp.h>
#include <memory>
#include "bot.h"
class SQLiteMarkov: public MarkovDB {
protected:
const std::string db_filename;
const std::string db_filename = get_filepath("markov.sqlite");
const int WORD_SIZE = 400; // chars
const int MAX_REPLY_LENGTH = 500; // words
const std::string INSERT_STATEMENT;
const std::string UPDATE_STATEMENT;
const std::string QUERY_STATEMENT;
const std::string TABLE_NAME;
int order = 2;
const std::string TABLE_NAME = "markov";
int order;
SQLite::Database data;
const std::string INSERT_STATEMENT = get_insert_statement_string();
const std::string UPDATE_STATEMENT = get_update_statement_string();
const std::string QUERY_STATEMENT = get_select_statement_string();
std::shared_ptr<SQLite::Database> data;
std::string get_filepath(std::string fp) {
std::string filePath(__FILE__);
@ -20,9 +24,10 @@ class SQLiteMarkov: public MarkovDB { @@ -20,9 +24,10 @@ class SQLiteMarkov: public MarkovDB {
+ fp;
}
SQLite::Database open_db() {
std::unique_ptr<SQLite::Database> open_db() {
try {
return SQLite::Database (db_filename, SQLite::OPEN_READWRITE);
return std::unique_ptr<SQLite::Database>(new
SQLite::Database (db_filename, SQLite::OPEN_READWRITE));
} catch (std::exception& e)
{
std::cout << "exception: " << e.what() << std::endl << std::flush;
@ -30,27 +35,30 @@ class SQLiteMarkov: public MarkovDB { @@ -30,27 +35,30 @@ class SQLiteMarkov: public MarkovDB {
std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
return SQLite::Database (db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE);
return std::unique_ptr<SQLite::Database>(new SQLite::Database
(db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE));
}
SQLite::Database Initialise_Data() {
SQLite::Database db = open_db();
SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'");
int Initialise_Data() {
SQLite::Statement query(*data, "SELECT name from sqlite_master WHERE name = '"
+ TABLE_NAME + "'");
try
{
if (query.executeStep()) {
std::cout << "database ready." << std::endl;
return db;
return 0;
}
}
catch (std::exception& e)
{
std::cout << "exception: " << e.what() << std::endl;
exit(1);
}
// create table sql statement construction
std::string ins = "CREATE TABLE markov (\n";
std::string ins = "CREATE TABLE " + TABLE_NAME + " (\n";
for (int i = 0; i < order + 1; i ++) {
ins += "word_" + std::to_string(i) + " VARCHAR(" +
std::to_string(WORD_SIZE) + ") NOT NULL,\n";
@ -65,15 +73,15 @@ class SQLiteMarkov: public MarkovDB { @@ -65,15 +73,15 @@ class SQLiteMarkov: public MarkovDB {
ins += " )\n);";
// run create table instruction
SQLite::Transaction transaction(db);
db.exec(ins);
SQLite::Transaction transaction(*data);
data->exec(ins);
transaction.commit();
return db;
return 0;
}
std::string get_update_statement_string() {
std::string update_template = "UPDATE markov SET count = ? WHERE\n";
std::string update_template = "UPDATE " + TABLE_NAME + " SET count = ? WHERE\n";
for (int i = 0; i < order + 1; i++) {
update_template += "word_" + std::to_string(i);
@ -87,7 +95,7 @@ class SQLiteMarkov: public MarkovDB { @@ -87,7 +95,7 @@ class SQLiteMarkov: public MarkovDB {
}
std::string get_select_statement_string() {
std::string update_template = "SELECT * FROM markov WHERE " ;
std::string update_template = "SELECT * FROM " + TABLE_NAME + " WHERE " ;
for (int i = 0; i < order; i++) {
update_template += "word_" + std::to_string(i);
@ -103,7 +111,8 @@ class SQLiteMarkov: public MarkovDB { @@ -103,7 +111,8 @@ class SQLiteMarkov: public MarkovDB {
}
std::string get_insert_statement_string() {
std::string update_template = "INSERT INTO markov VALUES (" ;
std::string update_template = "INSERT INTO " + TABLE_NAME +
" VALUES (" ;
for (int i = 0; i < order + 1; i++) {
update_template += "?" ;
@ -119,8 +128,7 @@ class SQLiteMarkov: public MarkovDB { @@ -119,8 +128,7 @@ class SQLiteMarkov: public MarkovDB {
bool update_count(std::vector<std::string> words, int count) {
std::string update_template = UPDATE_STATEMENT;
SQLite::Statement increment(data, update_template);
SQLite::Statement increment(*data, update_template);
increment.bind(1, count);
@ -135,7 +143,7 @@ class SQLiteMarkov: public MarkovDB { @@ -135,7 +143,7 @@ class SQLiteMarkov: public MarkovDB {
std::string update_template = INSERT_STATEMENT;
SQLite::Statement increment(data, update_template);
SQLite::Statement increment(*data, update_template);
for (int i = 0; i < order + 1; i++) {
increment.bind(i + 1, words[i]);
}
@ -149,21 +157,22 @@ class SQLiteMarkov: public MarkovDB { @@ -149,21 +157,22 @@ class SQLiteMarkov: public MarkovDB {
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
std::string query_template = "SELECT count FROM markov WHERE ";
for (int i = 0; i < order + 1; i++) {
std::string query_template = "SELECT count FROM " + TABLE_NAME
+ " WHERE ";
for (int i = 0; i <= order; i++) {
query_template += "word_" + std::to_string(i);
query_template += " = ? ";
if (i != order) {
query_template += "AND ";
}
}
SQLite::Statement update(data, query_template);
SQLite::Statement update(*data, query_template);
for (int i = 1; i < order + 2; i++) {
update.bind(i, words[i-1]);
}
int count = 0;
if (update.executeStep()) {
count = update.getColumn(0);
@ -179,11 +188,10 @@ class SQLiteMarkov: public MarkovDB { @@ -179,11 +188,10 @@ class SQLiteMarkov: public MarkovDB {
std::string get_next_word(std::vector<std::string> words, int &score) {
if (words.size() != order) {
throw "Invalid prompt vector size";
return "";
}
SQLite::Statement query(data, QUERY_STATEMENT);
SQLite::Statement query(*data, QUERY_STATEMENT);
for (int i = 0; i < order; i++) {
query.bind(i+1, words[i]);
@ -231,33 +239,29 @@ class SQLiteMarkov: public MarkovDB { @@ -231,33 +239,29 @@ class SQLiteMarkov: public MarkovDB {
public:
SQLiteMarkov(int ord, std::string dbname, std::string table): order(ord),
db_filename(get_filepath(dbname)),
data(Initialise_Data()),
TABLE_NAME(table),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(int ord): order(ord),
db_filename(get_filepath("markov.sqlite")),
data(Initialise_Data()),
TABLE_NAME("markov"),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(): order(1),
db_filename(get_filepath("markov.sqlite")),
data(Initialise_Data()),
TABLE_NAME("markov"),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(int ord, std::string dbname, std::string table):
order(ord),
db_filename(dbname),
data(open_db()),
TABLE_NAME(table)
{
Initialise_Data();
}
SQLiteMarkov(int ord): order(ord), data(open_db())
{
Initialise_Data();
}
SQLiteMarkov(): data(open_db())
{
Initialise_Data();}
SQLiteMarkov(int order, std::string table_name, SQLite::Database *db):
data(db), order(order), TABLE_NAME(table_name)
{
Initialise_Data();
}
bool add_ngrams(std::vector<std::string> words) override {
// assertions?
@ -272,7 +276,6 @@ class SQLiteMarkov: public MarkovDB { @@ -272,7 +276,6 @@ class SQLiteMarkov: public MarkovDB {
}
return add_to_db(words);
}
std::string get_continuation(std::vector<std::string> prompt, int &score)
@ -284,13 +287,12 @@ class SQLiteMarkov: public MarkovDB { @@ -284,13 +287,12 @@ class SQLiteMarkov: public MarkovDB {
do {
std::string next_word = "invalid";
try {
int total;
next_word = get_next_word(words, total);
score += total;
} catch (std::exception &e) {
std::cout << "exception: " << e.what() << std::endl;
} catch (char * &e) {
std::cout << "Exception: " << e << std::endl;
std::cout << "Failed Get Next Word" << std::endl;
break;
}
@ -309,9 +311,9 @@ class SQLiteMarkov: public MarkovDB { @@ -309,9 +311,9 @@ class SQLiteMarkov: public MarkovDB {
count++;
words.push_back(next_word);
words.erase(words.begin());
} while (count < MAX_REPLY_LENGTH);
return new_words;
}
};

Loading…
Cancel
Save