user
4 years ago
commit
a0d097b025
5 changed files with 505 additions and 0 deletions
@ -0,0 +1,3 @@ |
|||||||
|
[submodule "libtelegram"] |
||||||
|
path = libtelegram |
||||||
|
url = https://github.com/slowriot/libtelegram.git |
@ -0,0 +1,43 @@ |
|||||||
|
cmake_minimum_required (VERSION 3.12) |
||||||
|
|
||||||
|
set(PROJECT_NAME "markov bot") |
||||||
|
project (${PROJECT_NAME}) |
||||||
|
|
||||||
|
set_property (GLOBAL PROPERTY USE_FOLDERS ON) |
||||||
|
|
||||||
|
set (CMAKE_CXX_STANDARD 17) |
||||||
|
set (CMAKE_CXX_STANDARD_REQUIRED ON) |
||||||
|
|
||||||
|
set (THREADS_PREFER_PTHREAD_FLAG ON) |
||||||
|
|
||||||
|
find_package (Threads REQUIRED) |
||||||
|
find_package(Boost 1.66 REQUIRED COMPONENTS system) |
||||||
|
|
||||||
|
|
||||||
|
IF(Boost_FOUND) |
||||||
|
include_directories(${Boost_INCLUDE_DIRS}) |
||||||
|
ELSE() |
||||||
|
message(FATAL "${CMAKE_SOURCE_DIR}/libtelegram/include") |
||||||
|
ENDIF(Boost_FOUND) |
||||||
|
|
||||||
|
find_package(OpenSSL REQUIRED) |
||||||
|
|
||||||
|
set(include_dir "${CMAKE_SOURCE_DIR}/libtelegram/include") |
||||||
|
message(STATUS "include_dir: ${include_dir}") |
||||||
|
include_directories (${include_dir}) |
||||||
|
|
||||||
|
set(link_libs |
||||||
|
Threads::Threads |
||||||
|
OpenSSL::SSL |
||||||
|
Boost::system |
||||||
|
SQLiteCpp |
||||||
|
sqlite3 |
||||||
|
dl) |
||||||
|
|
||||||
|
macro(add_tg_example name) |
||||||
|
set(target_name ${name}) |
||||||
|
add_executable(${target_name} ${name}.cpp) |
||||||
|
target_link_libraries(${target_name} ${link_libs}) |
||||||
|
endmacro(add_tg_example name) |
||||||
|
|
||||||
|
add_tg_example(bot) |
@ -0,0 +1,11 @@ |
|||||||
|
|
||||||
|
|
||||||
|
## Building on Fedora |
||||||
|
|
||||||
|
``` |
||||||
|
|
||||||
|
sudo dnf install boost-devel sqlitecpp-devel openssl-devel |
||||||
|
cmake . |
||||||
|
make |
||||||
|
|
||||||
|
``` |
@ -0,0 +1,447 @@ |
|||||||
|
#include <SQLiteCpp/SQLiteCpp.h> |
||||||
|
#include <libtelegram/libtelegram.h> |
||||||
|
#include <boost/algorithm/string.hpp> |
||||||
|
#include <vector> |
||||||
|
#include <sstream> |
||||||
|
#include <algorithm> |
||||||
|
#include <iterator> |
||||||
|
#include <regex> |
||||||
|
|
||||||
|
std::string get_token() { |
||||||
|
using namespace std; |
||||||
|
fstream ftoken; |
||||||
|
ftoken.open("apikey.txt", ios::in); |
||||||
|
if (!ftoken) { |
||||||
|
perror("No apikey.txt file provided"); |
||||||
|
exit(1); |
||||||
|
} |
||||||
|
|
||||||
|
string apikey; |
||||||
|
ftoken >> apikey; |
||||||
|
return apikey; |
||||||
|
} |
||||||
|
|
||||||
|
class SQLiteMarkov { |
||||||
|
public:
|
||||||
|
int order; |
||||||
|
private: |
||||||
|
std::string db_filename; |
||||||
|
SQLite::Database data; |
||||||
|
|
||||||
|
const int WORD_SIZE = 400; |
||||||
|
const int MAX_REPLY_LENGTH = 500; |
||||||
|
const std::string INSERT_STATEMENT; |
||||||
|
const std::string UPDATE_STATEMENT; |
||||||
|
const std::string QUERY_STATEMENT; |
||||||
|
|
||||||
|
std::string get_filepath(std::string fp) { |
||||||
|
std::string filePath(__FILE__); |
||||||
|
return filePath.substr( 0, filePath.length() - std::string("bot.cpp").length())
|
||||||
|
+ fp; |
||||||
|
} |
||||||
|
|
||||||
|
SQLite::Database open_db() { |
||||||
|
try { |
||||||
|
return SQLite::Database (db_filename, SQLite::OPEN_READWRITE); |
||||||
|
} catch (std::exception& e)
|
||||||
|
{ |
||||||
|
std::cout << "exception: " << e.what() << std::endl; |
||||||
|
} |
||||||
|
|
||||||
|
std::cout << "Creating new db: '" << db_filename << "'" << std::endl; |
||||||
|
|
||||||
|
return SQLite::Database (db_filename, SQLite::OPEN_CREATE|SQLite::OPEN_READWRITE); |
||||||
|
} |
||||||
|
|
||||||
|
SQLite::Database Initialise_Data() { |
||||||
|
SQLite::Database db = open_db(); |
||||||
|
|
||||||
|
SQLite::Statement query(db, "SELECT name from sqlite_master WHERE name = 'markov'"); |
||||||
|
try
|
||||||
|
{ |
||||||
|
if (query.executeStep()) { |
||||||
|
std::cout << "database ready." << std::endl; |
||||||
|
return db; |
||||||
|
} |
||||||
|
} |
||||||
|
catch (std::exception& e) |
||||||
|
{ |
||||||
|
std::cout << "exception: " << e.what() << std::endl; |
||||||
|
} |
||||||
|
|
||||||
|
// create table sql statement construction
|
||||||
|
std::string ins = "CREATE TABLE markov (\n"; |
||||||
|
for (int i = 0; i < order + 1; i ++) { |
||||||
|
ins += "word_" + std::to_string(i) + " VARCHAR(" +
|
||||||
|
std::to_string(WORD_SIZE) + ") NOT NULL,\n"; |
||||||
|
} |
||||||
|
ins += "count INTEGER NOT NULL, \n PRIMARY KEY ("; |
||||||
|
for (int i = 0; i < order + 1; i ++) { |
||||||
|
ins += "word_" + std::to_string(i); |
||||||
|
if (i != order) { |
||||||
|
ins += ","; |
||||||
|
} |
||||||
|
} |
||||||
|
ins += " )\n);"; |
||||||
|
|
||||||
|
// run create table instruction
|
||||||
|
SQLite::Transaction transaction(db); |
||||||
|
db.exec(ins); |
||||||
|
transaction.commit(); |
||||||
|
|
||||||
|
return db; |
||||||
|
} |
||||||
|
|
||||||
|
std::string get_update_statement_string() { |
||||||
|
std::string update_template = "UPDATE markov SET count = ? WHERE\n"; |
||||||
|
|
||||||
|
for (int i = 0; i < order + 1; i++) { |
||||||
|
update_template += "word_" + std::to_string(i); |
||||||
|
update_template += " = ? "; |
||||||
|
if (i != order) { |
||||||
|
update_template += "AND "; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return update_template; |
||||||
|
} |
||||||
|
|
||||||
|
std::string get_select_statement_string() { |
||||||
|
std::string update_template = "SELECT * FROM markov WHERE " ; |
||||||
|
|
||||||
|
for (int i = 0; i < order; i++) { |
||||||
|
update_template += "word_" + std::to_string(i); |
||||||
|
update_template += " = ? "; |
||||||
|
if (i != order - 1) { |
||||||
|
update_template += "AND "; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
update_template += ";"; |
||||||
|
|
||||||
|
return update_template; |
||||||
|
} |
||||||
|
|
||||||
|
std::string get_insert_statement_string() { |
||||||
|
std::string update_template = "INSERT INTO markov VALUES (" ; |
||||||
|
|
||||||
|
for (int i = 0; i < order + 1; i++) { |
||||||
|
update_template += "?" ; |
||||||
|
if (i != order) { |
||||||
|
update_template += ", "; |
||||||
|
} else { |
||||||
|
update_template += ", ?)"; |
||||||
|
} |
||||||
|
} |
||||||
|
return update_template; |
||||||
|
} |
||||||
|
|
||||||
|
bool update_count(std::vector<std::string> words, int count) { |
||||||
|
std::string update_template = UPDATE_STATEMENT; |
||||||
|
|
||||||
|
|
||||||
|
SQLite::Statement increment(data, update_template); |
||||||
|
|
||||||
|
increment.bind(1, count); |
||||||
|
|
||||||
|
for (int i = 0; i < order + 1; i++) { |
||||||
|
std::cout << i << ": " << words[i]; |
||||||
|
increment.bind(i + 2, words[i]); |
||||||
|
} |
||||||
|
|
||||||
|
return increment.executeStep(); |
||||||
|
} |
||||||
|
|
||||||
|
bool insert_record(std::vector<std::string> words, int count) { |
||||||
|
|
||||||
|
std::string update_template = INSERT_STATEMENT; |
||||||
|
|
||||||
|
|
||||||
|
SQLite::Statement increment(data, update_template); |
||||||
|
for (int i = 0; i < order + 1; i++) { |
||||||
|
std::cout << i << ": " << words[i]; |
||||||
|
increment.bind(i + 1, words[i]); |
||||||
|
} |
||||||
|
|
||||||
|
increment.bind(order + 2, count); |
||||||
|
|
||||||
|
return increment.executeStep(); |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
bool add_to_db(std::vector<std::string> words) { |
||||||
|
|
||||||
|
// UPDATE COMPANY SET ADDRESS = 'Texas' WHERE ID = 6;
|
||||||
|
|
||||||
|
std::string query_template = "SELECT count FROM markov WHERE "; |
||||||
|
for (int i = 0; i < order + 1; i++) { |
||||||
|
query_template += "word_" + std::to_string(i); |
||||||
|
query_template += " = ? "; |
||||||
|
if (i != order) { |
||||||
|
query_template += "AND "; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
SQLite::Statement update(data, query_template); |
||||||
|
|
||||||
|
for (int i = 1; i < order + 2; i++) { |
||||||
|
update.bind(i, words[i-1]); |
||||||
|
} |
||||||
|
|
||||||
|
int count = 0; |
||||||
|
if (update.executeStep()) { |
||||||
|
count = update.getColumn(0);
|
||||||
|
count += 1; |
||||||
|
update_count(words, count); |
||||||
|
} else { |
||||||
|
count = 1; |
||||||
|
insert_record(words, count); |
||||||
|
} |
||||||
|
|
||||||
|
return false; |
||||||
|
} |
||||||
|
|
||||||
|
std::string get_next_word(std::vector<std::string> words) { |
||||||
|
if (words.size() != order) { |
||||||
|
throw "Invalid prompt vector size"; |
||||||
|
return ""; |
||||||
|
} |
||||||
|
|
||||||
|
SQLite::Statement query(data, QUERY_STATEMENT); |
||||||
|
|
||||||
|
for (int i = 0; i < order; i++) { |
||||||
|
query.bind(i+1, words[i]); |
||||||
|
} |
||||||
|
|
||||||
|
int total = 0; |
||||||
|
while (query.executeStep()) { |
||||||
|
int count = query.getColumn(order + 1); |
||||||
|
total += count; |
||||||
|
} |
||||||
|
|
||||||
|
total += 1; |
||||||
|
int threshold = rand() % total; |
||||||
|
|
||||||
|
// count up to threshold
|
||||||
|
// pretty sure the order doesn't matter statistically?
|
||||||
|
// Obviously it will have an impact but I think the distribution should
|
||||||
|
// be the same whether the list is ordered by frequency or not
|
||||||
|
|
||||||
|
query.reset(); |
||||||
|
total = 0; |
||||||
|
|
||||||
|
std::vector<std::string> message; |
||||||
|
|
||||||
|
while (query.executeStep()) { |
||||||
|
int count = query.getColumn(order + 1); |
||||||
|
total += count; |
||||||
|
if (total >= threshold) { |
||||||
|
std::string next_word = query.getColumn(order); |
||||||
|
return next_word; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
std::cout << "No matches found" << std::endl; |
||||||
|
throw "No next word found"; |
||||||
|
return ""; |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
public: |
||||||
|
|
||||||
|
SQLiteMarkov(int ord, std::string dbname): order(ord), |
||||||
|
db_filename(get_filepath(dbname)),
|
||||||
|
data(Initialise_Data()), |
||||||
|
INSERT_STATEMENT(get_insert_statement_string()), |
||||||
|
QUERY_STATEMENT(get_select_statement_string()), |
||||||
|
UPDATE_STATEMENT(get_update_statement_string()) |
||||||
|
{;}; |
||||||
|
|
||||||
|
SQLiteMarkov(int ord): order(ord),
|
||||||
|
db_filename(get_filepath("markov.sqlite")),
|
||||||
|
data(Initialise_Data()), |
||||||
|
INSERT_STATEMENT(get_insert_statement_string()), |
||||||
|
QUERY_STATEMENT(get_select_statement_string()), |
||||||
|
UPDATE_STATEMENT(get_update_statement_string()) |
||||||
|
{;}; |
||||||
|
|
||||||
|
SQLiteMarkov(): order(1),
|
||||||
|
db_filename(get_filepath("markov.sqlite")),
|
||||||
|
data(Initialise_Data()), |
||||||
|
INSERT_STATEMENT(get_insert_statement_string()), |
||||||
|
QUERY_STATEMENT(get_select_statement_string()), |
||||||
|
UPDATE_STATEMENT(get_update_statement_string()) |
||||||
|
{;}; |
||||||
|
|
||||||
|
SQLiteMarkov(std::string dbname): order(1),
|
||||||
|
db_filename(get_filepath(dbname)),
|
||||||
|
data(Initialise_Data()), |
||||||
|
INSERT_STATEMENT(get_insert_statement_string()), |
||||||
|
QUERY_STATEMENT(get_select_statement_string()), |
||||||
|
UPDATE_STATEMENT(get_update_statement_string()) |
||||||
|
{;}; |
||||||
|
|
||||||
|
bool add_ngrams(std::vector<std::string> words) { |
||||||
|
// assertions?
|
||||||
|
if (words.size() > order + 1) { |
||||||
|
return true; |
||||||
|
} |
||||||
|
|
||||||
|
for (int i = 0; i < order + 1; i++) { |
||||||
|
if (words[i].length() > SQLiteMarkov::WORD_SIZE) { |
||||||
|
return true; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return add_to_db(words); |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
std::string get_continuation(std::vector<std::string> prompt) { |
||||||
|
std::string new_words = ""; |
||||||
|
std::vector<std::string> words(prompt); |
||||||
|
int count = 0; |
||||||
|
|
||||||
|
do { |
||||||
|
std::string next_word = "invalid"; |
||||||
|
|
||||||
|
try { |
||||||
|
next_word = get_next_word(words); |
||||||
|
} catch (std::exception &e) { |
||||||
|
std::cout << "exception: " << e.what() << std::endl;
|
||||||
|
std::cout << "Failed Get Next Word" << std::endl; |
||||||
|
break; |
||||||
|
} |
||||||
|
|
||||||
|
if (next_word == "MESSAGE_END") { |
||||||
|
break; |
||||||
|
} |
||||||
|
|
||||||
|
new_words += " " + next_word; |
||||||
|
std::regex end_punc("[\\.\\?\\!]"); |
||||||
|
|
||||||
|
if (std::regex_search(std::to_string(next_word.back()), end_punc)) { |
||||||
|
break; |
||||||
|
} |
||||||
|
|
||||||
|
count++; |
||||||
|
words.push_back(next_word); |
||||||
|
words.erase(words.begin()); |
||||||
|
|
||||||
|
} while (count < MAX_REPLY_LENGTH); |
||||||
|
|
||||||
|
return new_words; |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
|
||||||
|
class MarkovHandler { |
||||||
|
|
||||||
|
private: |
||||||
|
SQLiteMarkov markov; |
||||||
|
|
||||||
|
const bool MAKE_LOWER = true; |
||||||
|
|
||||||
|
public: |
||||||
|
|
||||||
|
bool test() { |
||||||
|
std::vector<std::string> words {"so", "long"}; |
||||||
|
add_ngrams("so long and thanks\nfor\tall the-fishes bro man"); |
||||||
|
//markov.add_ngrams(words);
|
||||||
|
return false; |
||||||
|
} |
||||||
|
|
||||||
|
bool process_string(std::string& message) { |
||||||
|
if (MAKE_LOWER) { |
||||||
|
for (int i = 0; i < message.length(); i++) { |
||||||
|
if (std::isupper(message[i])) { |
||||||
|
message[i] = std::tolower(message[i], std::locale()); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
message += " MESSAGE_END"; |
||||||
|
|
||||||
|
return false; |
||||||
|
} |
||||||
|
|
||||||
|
std::vector<std::string> split_string(std::string message) { |
||||||
|
|
||||||
|
std::istringstream ss(message); |
||||||
|
|
||||||
|
std::vector<std::string> words{ |
||||||
|
std::istream_iterator<std::string>{ss},
|
||||||
|
std::istream_iterator<std::string>{} |
||||||
|
}; |
||||||
|
|
||||||
|
return words; |
||||||
|
} |
||||||
|
|
||||||
|
bool add_ngrams(std::string message) { |
||||||
|
process_string(message); |
||||||
|
std::vector<std::string> words (split_string(message)); |
||||||
|
|
||||||
|
for (int i = 0; i < words.size(); i++) { |
||||||
|
if (i >= markov.order) { |
||||||
|
std::vector<std::string> add_words; |
||||||
|
|
||||||
|
int count = 0; |
||||||
|
for (int j = i - markov.order; j <= i; j++, count++) { |
||||||
|
add_words.push_back(words[j]); |
||||||
|
std::cout << j << ": " << add_words[count] << std::endl; |
||||||
|
} |
||||||
|
|
||||||
|
markov.add_ngrams(add_words); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return false; |
||||||
|
} |
||||||
|
|
||||||
|
std::string continue_message(std::string message) { |
||||||
|
std::vector<std::string> words {split_string(message)}; |
||||||
|
words.erase(words.begin(), words.begin() + words.size() - markov.order); |
||||||
|
|
||||||
|
return markov.get_continuation(words); |
||||||
|
}
|
||||||
|
}; |
||||||
|
|
||||||
|
class Bot { |
||||||
|
private: |
||||||
|
std::string apikey; |
||||||
|
telegram::sender sender; |
||||||
|
MarkovHandler markov; |
||||||
|
|
||||||
|
public: |
||||||
|
telegram::listener::poll listener; |
||||||
|
|
||||||
|
Bot() : apikey(get_token()), sender(apikey), listener(sender) { |
||||||
|
echo(); |
||||||
|
}; |
||||||
|
|
||||||
|
Bot(std::string key) : apikey(key), sender(apikey), listener(sender) { |
||||||
|
echo(); |
||||||
|
}; |
||||||
|
|
||||||
|
void echo() { |
||||||
|
listener.set_callback_message([&](telegram::types::message const &message){
|
||||||
|
|
||||||
|
std::string message_text = message.text.value(); |
||||||
|
markov.add_ngrams(message_text); |
||||||
|
|
||||||
|
std::string reply = markov.continue_message(message_text); |
||||||
|
sender.send_message(message.chat.id, reply); |
||||||
|
}); |
||||||
|
} |
||||||
|
|
||||||
|
void run() { |
||||||
|
listener.run(); |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
auto main()->int { |
||||||
|
Bot bot {}; |
||||||
|
bot.run(); |
||||||
|
return EXIT_SUCCESS; |
||||||
|
}; |
Loading…
Reference in new issue