Browse Source

reorgansie

master
alistair 4 years ago
parent
commit
9f4433b0cf
  1. 3
      .gitmodules
  2. 11
      CMakeLists.txt
  3. 1
      SQLiteCpp
  4. 496
      bot.cpp
  5. 78
      bot.h
  6. 317
      sqlite_markov.h

3
.gitmodules vendored

@ -1,3 +1,6 @@ @@ -1,3 +1,6 @@
[submodule "libtelegram"]
path = libtelegram
url = https://github.com/slowriot/libtelegram.git
[submodule "SQLiteCpp"]
path = SQLiteCpp
url = https://github.com/SRombauts/SQLiteCpp.git

11
CMakeLists.txt

@ -1,4 +1,4 @@ @@ -1,4 +1,4 @@
cmake_minimum_required (VERSION 3.12)
cmake_minimum_required(VERSION 3.15)
set(PROJECT_NAME "markov bot")
project (${PROJECT_NAME})
@ -7,13 +7,14 @@ set_property (GLOBAL PROPERTY USE_FOLDERS ON) @@ -7,13 +7,14 @@ set_property (GLOBAL PROPERTY USE_FOLDERS ON)
set (CMAKE_CXX_STANDARD 17)
set (CMAKE_CXX_STANDARD_REQUIRED ON)
set (CMAKE_BUILD_TYPE Debug)
set (CMAKE_CXX_FLAGS_DEBUG_INIT "-Wall -g -Werror=format-security -Werror=implicit-function-declaration -fsanitize=address")
set (THREADS_PREFER_PTHREAD_FLAG ON)
find_package (Threads REQUIRED)
find_package(Boost 1.66 REQUIRED COMPONENTS system)
IF(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS})
ELSE()
@ -22,6 +23,12 @@ ENDIF(Boost_FOUND) @@ -22,6 +23,12 @@ ENDIF(Boost_FOUND)
find_package(OpenSSL REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/SQLiteCpp)
include_directories(
${CMAKE_CURRENT_LIST_DIR}/SQLiteCpp/include
)
set(include_dir "${CMAKE_SOURCE_DIR}/libtelegram/include")
message(STATUS "include_dir: ${include_dir}")
include_directories (${include_dir})

1
SQLiteCpp

@ -0,0 +1 @@ @@ -0,0 +1 @@
Subproject commit 01cf6f2d2adf2a7b71a420f2c91446364a9ada1e

496
bot.cpp

@ -1,4 +1,3 @@ @@ -1,4 +1,3 @@
#include <SQLiteCpp/SQLiteCpp.h>
#include <libtelegram/libtelegram.h>
#include <boost/algorithm/string.hpp>
#include <vector>
@ -6,8 +5,10 @@ @@ -6,8 +5,10 @@
#include <algorithm>
#include <iterator>
#include <regex>
#include "bot.h"
#include "sqlite_markov.h"
std::string get_token() {
std::string TelegramBotM::get_token() {
using namespace std;
fstream ftoken;
ftoken.open("apikey.txt", ios::in);
@ -21,453 +22,114 @@ std::string get_token() { @@ -21,453 +22,114 @@ std::string get_token() {
return apikey;
}
class SQLiteMarkov {
public:
int order;
private:
std::string db_filename;
SQLite::Database data;
const int WORD_SIZE = 400;
const int MAX_REPLY_LENGTH = 500;
const std::string INSERT_STATEMENT;
const std::string UPDATE_STATEMENT;
const std::string QUERY_STATEMENT;
std::string get_filepath(std::string fp) {
std::string filePath(__FILE__);
return filePath.substr( 0, filePath.length() - std::string("bot.cpp").length())
+ fp;
}
SQLite::Database open_db() {
try {
return SQLite::Database (db_filename, SQLite::OPEN_READWRITE);
} catch (std::exception& e)
{
std::cout << "exception: " << e.what() << std::endl;
}
std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
return SQLite::Database (db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE);
}
SQLite::Database Initialise_Data() {
SQLite::Database db = open_db();
SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'");
try
{
if (query.executeStep()) {
std::cout << "database ready." << std::endl;
return db;
}
}
catch (std::exception& e)
{
std::cout << "exception: " << e.what() << std::endl;
}
// create table sql statement construction
std::string ins = "CREATE TABLE markov (\n";
for (int i = 0; i < order + 1; i ++) {
ins += "word_" + std::to_string(i) + " VARCHAR(" +
std::to_string(WORD_SIZE) + ") NOT NULL,\n";
}
ins += "count INTEGER NOT NULL, \n PRIMARY KEY (";
for (int i = 0; i < order + 1; i ++) {
ins += "word_" + std::to_string(i);
if (i != order) {
ins += ",";
}
}
ins += " )\n);";
// run create table instruction
SQLite::Transaction transaction(db);
db.exec(ins);
transaction.commit();
return db;
}
std::string get_update_statement_string() {
std::string update_template = "UPDATE markov SET count = ? WHERE\n";
for (int i = 0; i < order + 1; i++) {
update_template += "word_" + std::to_string(i);
update_template += " = ? ";
if (i != order) {
update_template += "AND ";
}
}
return update_template;
}
std::string get_select_statement_string() {
std::string update_template = "SELECT * FROM markov WHERE " ;
for (int i = 0; i < order; i++) {
update_template += "word_" + std::to_string(i);
update_template += " = ? ";
if (i != order - 1) {
update_template += "AND ";
}
}
update_template += ";";
return update_template;
}
std::string get_insert_statement_string() {
std::string update_template = "INSERT INTO markov VALUES (" ;
for (int i = 0; i < order + 1; i++) {
update_template += "?" ;
if (i != order) {
update_template += ", ";
} else {
update_template += ", ?)";
}
}
return update_template;
}
bool update_count(std::vector<std::string> words, int count) {
std::string update_template = UPDATE_STATEMENT;
SQLite::Statement increment(data, update_template);
increment.bind(1, count);
for (int i = 0; i < order + 1; i++) {
increment.bind(i + 2, words[i]);
bool MarkovChain::test() {
add_ngrams("so long and thanks\nfor\tall the fishes bro man");
return false;
}
bool MarkovChain::process_string(std::string& message) {
if (this->MAKE_LOWER) {
for (int i = 0; i < message.length(); i++) {
if (std::isupper(message[i])) {
message[i] = std::tolower(message[i], std::locale());
}
return increment.executeStep();
}
bool insert_record(std::vector<std::string> words, int count) {
std::string update_template = INSERT_STATEMENT;
SQLite::Statement increment(data, update_template);
for (int i = 0; i < order + 1; i++) {
increment.bind(i + 1, words[i]);
}
increment.bind(order + 2, count);
return increment.executeStep();
}
bool add_to_db(std::vector<std::string> words) {
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
std::string query_template = "SELECT count FROM markov WHERE ";
for (int i = 0; i < order + 1; i++) {
query_template += "word_" + std::to_string(i);
query_template += " = ? ";
if (i != order) {
query_template += "AND ";
}
}
SQLite::Statement update(data, query_template);
for (int i = 1; i < order + 2; i++) {
update.bind(i, words[i-1]);
}
int count = 0;
if (update.executeStep()) {
count = update.getColumn(0);
count += 1;
update_count(words, count);
} else {
count = 1;
insert_record(words, count);
}
return false;
}
std::string get_next_word(std::vector<std::string> words, int &score) {
if (words.size() != order) {
throw "Invalid prompt vector size";
return "";
}
message += " MESSAGE_END";
SQLite::Statement query(data, QUERY_STATEMENT);
return false;
}
for (int i = 0; i < order; i++) {
query.bind(i+1, words[i]);
}
std::vector<std::string> MarkovChain::split_string(std::string message) {
std::istringstream ss(message);
int total = 0;
while (query.executeStep()) {
int count = query.getColumn(order + 1);
total += count;
}
std::vector<std::string> words{
std::istream_iterator<std::string>{ss},
std::istream_iterator<std::string>{}
};
total += 1;
int threshold = rand() % total;
return words;
}
// count up to threshold
// pretty sure the order doesn't matter statistically?
// Obviously it will have an impact but I think the distribution should
// be the same whether the list is ordered by frequency or not
bool MarkovChain::add_ngrams(std::string message) {
this->process_string(message);
query.reset();
total = 0;
std::vector<std::string> words (split_string(message));
std::vector<std::string> message;
std::cout << message << std::endl;
while (query.executeStep()) {
int count = query.getColumn(order + 1);
total += count;
if (total >= threshold) {
std::string next_word = query.getColumn(order);
for (int i = 0; i < words.size(); i++) {
if (i >= order) {
std::vector<std::string> add_words;
score = total;
return next_word;
int count = 0;
for (int j = i - order; j <= i; j++, count++) {
add_words.push_back(words[j]);
std::cout << j << ": " << add_words[count] << std::endl;
}
markov.add_ngrams(add_words);
}
std::cout << "No matches found" << std::endl;
throw "No next word found";
return "";
}
std::string get_next_word(std::vector<std::string> words) {
int count;
return get_next_word(words, count);
}
public:
SQLiteMarkov(int ord, std::string dbname): order(ord),
db_filename(get_filepath(dbname)),
data(Initialise_Data()),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(int ord): order(ord),
db_filename(get_filepath("markov.sqlite")),
data(Initialise_Data()),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(): order(1),
db_filename(get_filepath("markov.sqlite")),
data(Initialise_Data()),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
return false;
}
SQLiteMarkov(std::string dbname): order(1),
db_filename(get_filepath(dbname)),
data(Initialise_Data()),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
std::string MarkovChain::continue_message(std::string message, int &score) {
std::vector<std::string> words {split_string(message)};
words.erase(words.begin(), words.begin() + words.size() - order);
std::string ret = markov.get_continuation(words, score);
return ret;
}
bool add_ngrams(std::vector<std::string> words) {
// assertions?
if (words.size() > order + 1) {
return true;
}
for (int i = 0; i < order + 1; i++) {
if (words[i].length() > SQLiteMarkov::WORD_SIZE) {
return true;
}
}
return add_to_db(words);
void TelegramBotM::add_echo() {
listener.set_callback_message([&](telegram::types::message const &message) {
// listener.set_callback_json([&](nlohmann::json const &input) {
std::string message_text = message.text.value();
try {
markov.add_ngrams(message_text);
} catch (std::exception &e) {
std::cout << "Error adding ngram: " << e.what() << std::endl;
}
std::string get_continuation(std::vector<std::string> prompt, int &score) {
std::string new_words = "";
std::vector<std::string> words(prompt);
int count = 0;
score = 0;
do {
std::string next_word = "invalid";
try {
int total;
next_word = get_next_word(words, total);
score += total;
} catch (std::exception &e) {
std::cout << "exception: " << e.what() << std::endl;
std::cout << "Failed Get Next Word" << std::endl;
break;
}
if (next_word == "MESSAGE_END") {
break;
}
new_words += " " + next_word;
std::regex end_punc("[\\.\\?\\!]");
if (std::regex_search(std::to_string(next_word.back()), end_punc)) {
break;
}
count++;
words.push_back(next_word);
words.erase(words.begin());
} while (count < MAX_REPLY_LENGTH);
return new_words;
}
};
class MarkovHandler {
private:
SQLiteMarkov markov;
const bool MAKE_LOWER = true;
public:
bool test() {
std::vector<std::string> words {"so", "long"};
add_ngrams("so long and thanks\nfor\tall the-fishes bro man");
//markov.add_ngrams(words);
return false;
}
//if (replies && rand() % replies != 1) {
//return;
//}
bool process_string(std::string& message) {
if (MAKE_LOWER) {
for (int i = 0; i < message.length(); i++) {
if (std::isupper(message[i])) {
message[i] = std::tolower(message[i], std::locale());
}
}
}
message += " MESSAGE_END";
return false;
}
std::vector<std::string> split_string(std::string message) {
std::istringstream ss(message);
std::vector<std::string> words{
std::istream_iterator<std::string>{ss},
std::istream_iterator<std::string>{}
};
return words;
int score = 0;
std::string reply;
try {
// reply = markov.continue_message(message_text, score);
reply = message_text;
} catch (std::exception &e) {
std::cout << "Error getting message" << std::endl;
}
bool add_ngrams(std::string message) {
process_string(message);
std::vector<std::string> words (split_string(message));
for (int i = 0; i < words.size(); i++) {
if (i >= markov.order) {
std::vector<std::string> add_words;
int count = 0;
for (int j = i - markov.order; j <= i; j++, count++) {
add_words.push_back(words[j]);
std::cout << j << ": " << add_words[count] << std::endl;
}
std::cerr << "SCORE: " << score << std::endl;
markov.add_ngrams(add_words);
}
}
std::cerr << "Send" << std::endl;
return false;
}
try {
sender.send_message(message.chat.id, reply);
} catch (std::exception &e) {
std::cerr << "Send error: " << e.what() << std::endl;
std::string continue_message(std::string message, int &score) {
std::vector<std::string> words {split_string(message)};
words.erase(words.begin(), words.begin() + words.size() - markov.order);
std::string ret = markov.get_continuation(words, score);
return ret;
}
};
class Bot {
private:
std::string apikey;
telegram::sender sender;
MarkovHandler markov;
const int replies = 200;
int SCORE_THRESHOLD = 100;
public:
telegram::listener::poll listener;
Bot() : apikey(get_token()), sender(apikey), listener(sender) {
echo();
};
Bot(std::string key) : apikey(key), sender(apikey), listener(sender) {
echo();
};
void echo() {
listener.set_callback_message([&](telegram::types::message const &message){
std::string message_text = message.text.value();
markov.add_ngrams(message_text);
if (replies && rand() % replies != 1) {
return;
}
int score = 0;
std::string reply;
try {
reply = markov.continue_message(message_text, score);
} catch (std::exception &e) {
std::cout << "Error getting message" << std::endl;
return;
}
std::cout << "SCORE: " << score << std::endl;
});
}
if (score > SCORE_THRESHOLD) {
sender.send_message(message.chat.id, reply);
}
});
}
int main() {
int order = 1;
SQLiteMarkov db = SQLiteMarkov(order, "markov.sqlite", "markov");
MarkovChain m = MarkovChain(order, true, db);
TelegramBotM b = TelegramBotM(order, m, "markov");
std::cout << "Running...\n";
b.run();
void run() {
listener.run();
}
};
auto main()->int {
Bot bot {};
bot.run();
return EXIT_SUCCESS;
};
return 0;
}

78
bot.h

@ -0,0 +1,78 @@ @@ -0,0 +1,78 @@
#ifndef BOT_H
#define BOT_H
#include <boost/algorithm/string.hpp>
#include <sstream>
#include <algorithm>
#include <iterator>
#include <regex>
#include <string>
#include <vector>
#include <libtelegram/libtelegram.h>
class MarkovDB {
public:
virtual bool add_ngrams(std::vector<std::string> words) = 0;
virtual std::string get_continuation(std::vector<std::string> prompt, int &score) = 0;
virtual ~MarkovDB() {}
};
class MarkovChain {
protected:
MarkovDB &markov;
int order = 1;
bool MAKE_LOWER = false;
bool process_string(std::string& message);
std::vector<std::string> split_string(std::string message);
bool test();
public:
bool add_ngrams(std::string message);
std::string continue_message(std::string message, int &score);
MarkovChain (int ord, bool lower, MarkovDB &db)
: markov(db), MAKE_LOWER(lower), order(ord) {;
};
virtual ~MarkovChain() {};
};
class Bot {
protected:
MarkovChain markov;
int replies = 200;
int SCORE_THRESHOLD = 100;
public:
virtual void run() = 0;
Bot(int replies, int score_threshold, MarkovChain m)
: markov(m), replies(replies), SCORE_THRESHOLD(score_threshold) {};
~Bot() {};
};
class TelegramBotM: public Bot {
protected:
std::string api_key = "";
telegram::sender sender;
telegram::listener::poll listener;
std::string chain_name = "markov";
std::string get_token();
void add_echo();
public:
TelegramBotM(int order, MarkovChain &m, std::string chain_name) :
api_key(get_token()), sender(get_token()), listener(sender),
chain_name(chain_name), Bot(200, 100, m) {
add_echo();
};
void run() override {
listener.run();
}
~TelegramBotM() {};
};
#endif

317
sqlite_markov.h

@ -0,0 +1,317 @@ @@ -0,0 +1,317 @@
#include "bot.h"
#include <SQLiteCpp/SQLiteCpp.h>
class SQLiteMarkov: public MarkovDB {
protected:
const std::string db_filename;
const int WORD_SIZE = 400; // chars
const int MAX_REPLY_LENGTH = 500; // words
const std::string INSERT_STATEMENT;
const std::string UPDATE_STATEMENT;
const std::string QUERY_STATEMENT;
const std::string TABLE_NAME;
int order;
SQLite::Database data;
std::string get_filepath(std::string fp) {
std::string filePath(__FILE__);
return filePath.substr( 0, filePath.length() - std::string("bot.cpp").length())
+ fp;
}
SQLite::Database open_db() {
try {
return SQLite::Database (db_filename, SQLite::OPEN_READWRITE);
} catch (std::exception& e)
{
std::cout << "exception: " << e.what() << std::endl << std::flush;
}
std::cout << "Creating new db: '" << db_filename << "'" << std::endl;
return SQLite::Database (db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE);
}
SQLite::Database Initialise_Data() {
SQLite::Database db = open_db();
SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'");
try
{
if (query.executeStep()) {
std::cout << "database ready." << std::endl;
return db;
}
}
catch (std::exception& e)
{
std::cout << "exception: " << e.what() << std::endl;
}
// create table sql statement construction
std::string ins = "CREATE TABLE markov (\n";
for (int i = 0; i < order + 1; i ++) {
ins += "word_" + std::to_string(i) + " VARCHAR(" +
std::to_string(WORD_SIZE) + ") NOT NULL,\n";
}
ins += "count INTEGER NOT NULL, \n PRIMARY KEY (";
for (int i = 0; i < order + 1; i ++) {
ins += "word_" + std::to_string(i);
if (i != order) {
ins += ",";
}
}
ins += " )\n);";
// run create table instruction
SQLite::Transaction transaction(db);
db.exec(ins);
transaction.commit();
return db;
}
std::string get_update_statement_string() {
std::string update_template = "UPDATE markov SET count = ? WHERE\n";
for (int i = 0; i < order + 1; i++) {
update_template += "word_" + std::to_string(i);
update_template += " = ? ";
if (i != order) {
update_template += "AND ";
}
}
return update_template;
}
std::string get_select_statement_string() {
std::string update_template = "SELECT * FROM markov WHERE " ;
for (int i = 0; i < order; i++) {
update_template += "word_" + std::to_string(i);
update_template += " = ? ";
if (i != order - 1) {
update_template += "AND ";
}
}
update_template += ";";
return update_template;
}
std::string get_insert_statement_string() {
std::string update_template = "INSERT INTO markov VALUES (" ;
for (int i = 0; i < order + 1; i++) {
update_template += "?" ;
if (i != order) {
update_template += ", ";
} else {
update_template += ", ?)";
}
}
return update_template;
}
bool update_count(std::vector<std::string> words, int count) {
std::string update_template = UPDATE_STATEMENT;
SQLite::Statement increment(data, update_template);
increment.bind(1, count);
for (int i = 0; i < order + 1; i++) {
increment.bind(i + 2, words[i]);
}
return increment.executeStep();
}
bool insert_record(std::vector<std::string> words, int count) {
std::string update_template = INSERT_STATEMENT;
SQLite::Statement increment(data, update_template);
for (int i = 0; i < order + 1; i++) {
increment.bind(i + 1, words[i]);
}
increment.bind(order + 2, count);
return increment.executeStep();
}
bool add_to_db(std::vector<std::string> words) {
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
std::string query_template = "SELECT count FROM markov WHERE ";
for (int i = 0; i < order + 1; i++) {
query_template += "word_" + std::to_string(i);
query_template += " = ? ";
if (i != order) {
query_template += "AND ";
}
}
SQLite::Statement update(data, query_template);
for (int i = 1; i < order + 2; i++) {
update.bind(i, words[i-1]);
}
int count = 0;
if (update.executeStep()) {
count = update.getColumn(0);
count += 1;
update_count(words, count);
} else {
count = 1;
insert_record(words, count);
}
return false;
}
std::string get_next_word(std::vector<std::string> words, int &score) {
if (words.size() != order) {
throw "Invalid prompt vector size";
return "";
}
SQLite::Statement query(data, QUERY_STATEMENT);
for (int i = 0; i < order; i++) {
query.bind(i+1, words[i]);
}
int total = 0;
while (query.executeStep()) {
int count = query.getColumn(order + 1);
total += count;
}
total += 1;
int threshold = rand() % total;
// count up to threshold
// pretty sure the order doesn't matter statistically?
// Obviously it will have an impact but I think the distribution should
// be the same whether the list is ordered by frequency or not
query.reset();
total = 0;
std::vector<std::string> message;
while (query.executeStep()) {
int count = query.getColumn(order + 1);
total += count;
if (total >= threshold) {
std::string next_word = query.getColumn(order);
score = total;
return next_word;
}
}
std::cout << "No matches found" << std::endl;
throw "No next word found";
return "";
}
std::string get_next_word(std::vector<std::string> words) {
int count;
return get_next_word(words, count);
}
public:
SQLiteMarkov(int ord, std::string dbname, std::string table): order(ord),
db_filename(get_filepath(dbname)),
data(Initialise_Data()),
TABLE_NAME(table),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(int ord): order(ord),
db_filename(get_filepath("markov.sqlite")),
data(Initialise_Data()),
TABLE_NAME("markov"),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
SQLiteMarkov(): order(1),
db_filename(get_filepath("markov.sqlite")),
data(Initialise_Data()),
TABLE_NAME("markov"),
INSERT_STATEMENT(get_insert_statement_string()),
QUERY_STATEMENT(get_select_statement_string()),
UPDATE_STATEMENT(get_update_statement_string())
{;};
bool add_ngrams(std::vector<std::string> words) override {
// assertions?
if (words.size() > order + 1) {
return true;
}
for (int i = 0; i < order + 1; i++) {
if (words[i].length() > SQLiteMarkov::WORD_SIZE) {
return true;
}
}
return add_to_db(words);
}
std::string get_continuation(std::vector<std::string> prompt, int &score)
override {
std::string new_words = "";
std::vector<std::string> words(prompt);
int count = 0;
score = 0;
do {
std::string next_word = "invalid";
try {
int total;
next_word = get_next_word(words, total);
score += total;
} catch (std::exception &e) {
std::cout << "exception: " << e.what() << std::endl;
std::cout << "Failed Get Next Word" << std::endl;
break;
}
if (next_word == "MESSAGE_END") {
break;
}
new_words += " " + next_word;
std::regex end_punc("[\\.\\?\\!]");
if (std::regex_search(std::to_string(next_word.back()), end_punc)) {
break;
}
count++;
words.push_back(next_word);
words.erase(words.begin());
} while (count < MAX_REPLY_LENGTH);
return new_words;
}
};
Loading…
Cancel
Save