diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d5e60da --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +build/ +.env diff --git a/.gitmodules b/.gitmodules index dc7de8c..66b256c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ -[submodule "tgbot-cpp"] - path = tgbot-cpp - url = https://github.com/reo7sp/tgbot-cpp/ -[submodule "spdlog"] - path = spdlog +[submodule "lib/tgbot-cpp"] + path = lib/tgbot-cpp + url = https://github.com/reo7sp/tgbot-cpp.git +[submodule "lib/spdlog"] + path = lib/spdlog url = https://github.com/gabime/spdlog.git diff --git a/CMakeLists.txt b/CMakeLists.txt index ea863fb..14a009c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,22 +37,23 @@ find_package(Threads REQUIRED) find_package(OpenSSL REQUIRED) find_package(Boost COMPONENTS system REQUIRED) include_directories(/usr/local/include ${OPENSSL_INCLUDE_DIR} - ${Boost_INCLUDE_DIR} ${CMAKE_SOURCE_DIR}/spdlog/include - ${CMAKE_SOURCE_DIR}/tgbot-cpp/include) + ${Boost_INCLUDE_DIR} ${CMAKE_SOURCE_DIR}/lib ${CMAKE_SOURCE_DIR}/lib/spdlog/include + ${CMAKE_SOURCE_DIR}/lib/tgbot-cpp/include) if (CURL_FOUND) include_directories(${CURL_INCLUDE_DIRS}) add_definitions(-DHAVE_CURL) endif() -add_executable(telegram_bog telegram_bot.cpp sqlite3.c) +file(GLOB_RECURSE src_files CONFIGURE_DEPENDS "src/*.h" "src/*.cpp") +add_executable(telegram_bog ${src_files} lib/sqlite3.c) #set(OPENSSL_USE_STATIC_LIBS TRUE) target_link_libraries(telegram_bog PRIVATE - ${CMAKE_SOURCE_DIR}/tgbot-cpp/build/libTgBot.a ${CMAKE_THREAD_LIBS_INIT} + ${CMAKE_SOURCE_DIR}/lib/tgbot-cpp/build/libTgBot.a ${CMAKE_THREAD_LIBS_INIT} ${OPENSSL_LIBRARIES} ${Boost_LIBRARIES} ${CURL_LIBRARIES}) target_link_libraries(telegram_bog PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(telegram_bog PRIVATE cpr::cpr) -target_link_libraries(telegram_bog PRIVATE ${CMAKE_SOURCE_DIR}/spdlog/build/libspdlog.a) +target_link_libraries(telegram_bog PRIVATE ${CMAKE_SOURCE_DIR}/lib/spdlog/build/libspdlog.a) diff --git a/Base64.hpp b/lib/Base64.hpp similarity index 99% rename from Base64.hpp rename to lib/Base64.hpp index d1c717b..d501e21 100644 --- a/Base64.hpp +++ b/lib/Base64.hpp @@ -13,6 +13,7 @@ #include #include "CpuFeatures.hpp" +#define _mm256_set_m128i(v0, v1) _mm256_insertf128_si256(_mm256_castsi128_si256(v1), (v0), 1) namespace base64 { enum class Codepath { @@ -662,4 +663,4 @@ namespace base64 { return buf; } -} \ No newline at end of file +} diff --git a/CpuFeatures.hpp b/lib/CpuFeatures.hpp similarity index 100% rename from CpuFeatures.hpp rename to lib/CpuFeatures.hpp diff --git a/spdlog b/lib/spdlog similarity index 100% rename from spdlog rename to lib/spdlog diff --git a/sqlite3.c b/lib/sqlite3.c similarity index 100% rename from sqlite3.c rename to lib/sqlite3.c diff --git a/sqlite3.h b/lib/sqlite3.h similarity index 100% rename from sqlite3.h rename to lib/sqlite3.h diff --git a/tgbot-cpp b/lib/tgbot-cpp similarity index 100% rename from tgbot-cpp rename to lib/tgbot-cpp diff --git a/src/songdb.cpp b/src/songdb.cpp new file mode 100644 index 0000000..a1284a5 --- /dev/null +++ b/src/songdb.cpp @@ -0,0 +1,602 @@ +#include "songdb.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "sqlite3.h" + +int songdb::callback(void *valmap, int argc, char **argv, char **azColName) { + std::map *values = (std::map *)valmap; + + for (int i = 0; i < argc; i++) { + values->insert({azColName[i], argv[i]}); + } + return 0; +} + +bool songdb::setup_tables() { + char *errmsg; + int err; + + /* + * songlists: [list id (autoincrement int), (int64) telegram group id, (string) list name ] + * + * songs : [songid int autoincrement, song name NOT NULL, song artist NULL, Song spotify ID (string) ] + * + * votes: [song id foreign key, list id foreign key, telegram user id, int rating ] + * + */ + + std::string create_songlist_table = + "CREATE TABLE IF NOT EXISTS songlists (" + "id INTEGER PRIMARY KEY AUTOINCREMENT, " + "groupid INT NOT NULL, " + "name VARCHAR(200) NOT NULL " + ");"; + + std::string create_votes_table = + "CREATE TABLE IF NOT EXISTS votes (" + "song INT NOT NULL, " + "list INT NOT NULL, " + "user INT NOT NULL, " + "rating NOT NULL, " + " FOREIGN KEY (song) REFERENCES tracks (id)," + " FOREIGN KEY (list) REFERENCES songlists (id)," + " PRIMARY KEY (song, list, user)" + ");"; + + std::string create_tracks_table = + "CREATE TABLE IF NOT EXISTS tracks (" + "id INTEGER PRIMARY KEY AUTOINCREMENT, " + "name VARCHAR(200) NOT NULL, " + "artist VARCHAR(200), " + "spotifyid VARCHAR(200) " + ");"; + + err = sqlite3_exec(db, create_songlist_table.c_str(), NULL, 0, &errmsg); + + if (err != SQLITE_OK) { + std::cout << "SQLite Error " << create_songlist_table << "\n" << errmsg << std::endl; + exit (1); + } + + err = sqlite3_exec(db, create_tracks_table.c_str(), NULL, 0, &errmsg); + + if (err != SQLITE_OK) { + std::cout << "SQLite Error: " << create_tracks_table << "\n" << errmsg << std::endl; + exit (1); + } + + + err = sqlite3_exec(db, create_votes_table.c_str(), NULL, 0, &errmsg); + + if (err != SQLITE_OK) { + std::cout << "SQLite Error: " << create_votes_table << "\n" << errmsg << std::endl; + exit (1); + } + + if (errmsg) + sqlite3_free(errmsg); + return false; +} + +int songdb::check_error(int rc) { + if (rc != SQLITE_OK) { + spdlog::error("SQLite: {}", sqlite3_errmsg(db)); + spdlog::dump_backtrace(); + exit(1); + } + return 0; +} + +void songdb::create_new_list(int64_t group_id) { + spdlog::debug("create_new_list {}", group_id); + + sqlite3_stmt *statement; + int err; + std::string insert_query = "INSERT INTO songlists (groupid, name) VALUES (? , \"default\");"; + err = sqlite3_prepare_v2(db, insert_query.c_str(), insert_query.length(), &statement, NULL); + check_error(err); + err = sqlite3_bind_int64(statement, 1, group_id); + check_error(err); + err = sqlite3_step(statement); + + if (err != SQLITE_DONE) { + check_error(err); + } + + sqlite3_finalize(statement); + +} + +/* Retrieve or create the song list id for a group. + * + * As is, song lists are unique to telegram groups + */ +int64_t songdb::get_song_list_id(int64_t group_id) { + spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); + + std::string list_query = "SELECT * FROM songlists WHERE groupid = ?;"; + /* A constraint of this implementation is that every group can only have + * one list + */ + int err; + + + sqlite3_stmt *statement; + sqlite3_prepare_v2(db, list_query.c_str(), list_query.length(), &statement, NULL); + err = sqlite3_bind_int64(statement, 1, group_id); + check_error(err); + err = sqlite3_step(statement); + int list_id; + + if (err == SQLITE_ROW && sqlite3_column_count(statement)) { + + list_id = sqlite3_column_int(statement, 0); + + err = sqlite3_step(statement); + if (err != SQLITE_DONE) { + // should only ever be 1 entry + spdlog::error("Should only be one list per group?"); + spdlog::dump_backtrace(); + exit(1); + } + + + } else if (err == SQLITE_DONE) { + // create the list + sqlite3_finalize(statement); + create_new_list(group_id); + + err = sqlite3_prepare_v2(db, list_query.c_str(), list_query.length(), &statement, NULL); + check_error(err); + err = sqlite3_bind_int64(statement, 1, group_id); + check_error(err); + + err = sqlite3_step(statement); + + if (err == SQLITE_ROW){ + list_id = sqlite3_column_int64(statement, 0); + } + else if (err != SQLITE_DONE) { + + check_error(err); + } + } else { + check_error(err); + } + + + return list_id; + +} + +std::optional songdb::get_song(int64_t id) { + std::string check_exist = "SELECT * FROM tracks WHERE id = ?;"; + + int err; + sqlite3_stmt *statement; + err = sqlite3_prepare_v2(db, check_exist.c_str(), check_exist.length(), &statement, NULL); + check_error(err); + err = sqlite3_bind_int64(statement, 1, id); + check_error(err); + + err = sqlite3_step(statement); + if (err == SQLITE_ROW) { + + int64_t id = sqlite3_column_int64(statement, 0); + std::string name {(char *)sqlite3_column_text(statement, 1)}; + std::string artist {(char *)sqlite3_column_text(statement, 2)}; + char * spotid = (char *)sqlite3_column_text(statement, 3); + + if (spotid) { + track_entry e {id, name, artist, std::string(spotid)}; + sqlite3_finalize(statement); + return e; + } + + track_entry e {id, name, artist}; + sqlite3_finalize(statement); + return e; + } + sqlite3_finalize(statement); + return {}; +} + +std::optional songdb::get_song(std::string name, std::string artist) { + return get_song(name, artist, {}); +} + +std::optional songdb::get_song(std::string name, std::string artist, std::optional spotify_id) { + + spdlog::debug("get_song"); + + std::string check_exist_spot = "SELECT * FROM tracks WHERE spotifyid = ?;"; + std::string check_exist_nospot = "SELECT * FROM tracks WHERE name = ? AND artist = ?;"; + + int err; + sqlite3_stmt *statement; + + /* check whether the song exists */ + if (spotify_id) { + err = sqlite3_prepare_v2(db, check_exist_spot.c_str(), check_exist_spot.length(), &statement, NULL); + check_error(err); + + err = sqlite3_bind_text(statement, 1, spotify_id->c_str(), spotify_id->length(), NULL); + check_error(err); + } else { + err = sqlite3_prepare_v2(db, check_exist_nospot.c_str(), check_exist_nospot.length(), &statement, NULL); + check_error(err); + + err = sqlite3_bind_text(statement, 1, name.c_str(), name.length(), NULL); + check_error(err); + + err = sqlite3_bind_text(statement, 2, artist.c_str(), artist.length(), NULL); + check_error(err); + } + + err = sqlite3_step(statement); + if (err == SQLITE_ROW) { + + int64_t id = sqlite3_column_int64(statement, 0); + std::string name {(char *)sqlite3_column_text(statement, 1)}; + std::string artist {(char *)sqlite3_column_text(statement, 2)}; + char * spotid = (char *)sqlite3_column_text(statement, 3); + + if (spotid) { + track_entry e {id, name, artist, std::string(spotid)}; + sqlite3_finalize(statement); + return e; + } + track_entry e {id, name, artist}; + sqlite3_finalize(statement); + return e; + spdlog::info("Entry exists."); + } else if (err != SQLITE_DONE) { + check_error(err); + sqlite3_finalize(statement); + return {}; + } + + sqlite3_finalize(statement); + return {}; +} + +std::optional songdb::insert_song(std::string name, std::string artist, std::optional spotify_id) { + + spdlog::debug("insert_song"); + + std::string ins_query_spot = "INSERT INTO tracks (name, artist, spotifyid) VALUES(?, ?, ?);"; + std::string ins_query_nospot = "INSERT INTO tracks (name, artist) VALUES(?, ?);"; + +// int list_id = get_song_list_id(group_id); + sqlite3_stmt *statement; + int err; + + auto e = get_song(name, artist, spotify_id); + if (e) { + return e; + } + + if (spotify_id) { + err = sqlite3_prepare_v2(db, ins_query_spot.c_str(), ins_query_spot.length(), &statement, NULL); + } else { + err = sqlite3_prepare_v2(db, ins_query_nospot.c_str(), ins_query_nospot.length(), &statement, NULL); + } + check_error(err); + + err = sqlite3_bind_text(statement, 1, name.c_str(), name.length(), NULL); + check_error(err); + err = sqlite3_bind_text(statement, 2, artist.c_str(), artist.length(), NULL); + check_error(err); + + if (spotify_id) { + err = sqlite3_bind_text(statement, 3, spotify_id->c_str(), spotify_id->length(), NULL); + check_error(err); + } + + err = sqlite3_step(statement); + if (err != SQLITE_DONE) { + sqlite3_finalize(statement); + spdlog::warn("Failed insertion Sqlite: {}", sqlite3_errstr(err)); + spdlog::dump_backtrace(); + sqlite3_finalize(statement); + return {}; + } + + sqlite3_finalize(statement); + return get_song(name, artist, spotify_id); +} + +bool songdb::insert_vote(int64_t user, int64_t group, int value, int64_t songid) { + spdlog::debug("insert_vote"); + auto song = get_song(songid); + if (!song) { + spdlog::error("Failed to add vote, couldnt find song id: {}", songid); + return true; + } + + int64_t list = get_song_list_id(group); + + + /* find existing vote */ + std::string ins_query = "INSERT OR REPLACE INTO votes (song, list, user, rating) " + "VALUES (?, ?, ?, ?);"; + + sqlite3_stmt *statement; + int err; + err = sqlite3_prepare_v2(db, ins_query.c_str(), ins_query.length(), &statement, NULL); + check_error(err); + err = sqlite3_bind_int64(statement, 1, song->id); + check_error(err); + err = sqlite3_bind_int64(statement, 2, list); + check_error(err); + err = sqlite3_bind_int64(statement, 3, user); + check_error(err); + err= sqlite3_bind_int64(statement, 4, value); + check_error(err); + + err = sqlite3_step(statement); + if (err != SQLITE_DONE) { + check_error(err); + sqlite3_finalize(statement); + return true; + } + + sqlite3_finalize(statement); + return false; +} + +std::vector songdb::get_votes_list(int64_t song_list) { + std::string query = "SELECT * FROM votes WHERE list = ?"; + + sqlite3_stmt *statement; + int err; + err = sqlite3_prepare_v2(db, query.c_str(), query.length(), &statement, NULL); + check_error(err); + err = sqlite3_bind_int64(statement, 1, song_list); + check_error(err); + + + std::vector votes; + + err = sqlite3_step(statement); + while (err == SQLITE_ROW) { + int song = sqlite3_column_int(statement, 0); + int list = sqlite3_column_int(statement, 1); + int user = sqlite3_column_int(statement, 2); + int value= sqlite3_column_int(statement, 3); + + votes.push_back({song, list,user,value}); + err = sqlite3_step(statement); + } + + if (err != SQLITE_DONE) { + check_error(err); + } + + sqlite3_finalize(statement); + return votes; +}; + +songdb::base_weight_vector songdb::get_base_weights (int64_t song_list) { + spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); + std::vector list = get_votes_list(song_list); + std::set chat_members; + + // {song, {user, vote}} + std::map> vote_info {}; + + for (auto v : list) { + chat_members.insert(v.user); + + vote_info[v.song] = {{v.user, v.value}}; + } + + for (auto v : vote_info) { + // Insert zero values for members who have not voted + for (auto m : chat_members) { + if (!v.second.count(m)) { + v.second.insert({m, 0.0}); + } + } + + // normalise weightings + double total = 0; + for (auto m : v.second) { + total += m.second; + } + + for (auto m : v.second) { + v.second[m.first] = m.second / total; + } + } + + if (vote_info.size() == 0) { + return {}; + } + + /* turn it into a nice easy to use vector + * Relying on the fact maps are sorted and things will always be in the + * same order. + */ + base_weight_vector v; + + for (auto song : vote_info) { + v.song_order.push_back(song.first); + } + + auto a = vote_info.begin(); + for (auto user : a->second) { + v.person_order.push_back(user.first); + } + + for (auto song: vote_info) { + std::vector user_votes; + for (auto user : song.second) { + user_votes.push_back(user.second); + } + + assert(user_votes.size() == v.person_order.size()); + + v.weights.push_back(user_votes); + } + + return v; +} + +double songdb::dot_product(const std::vector &a, const std::vector &b) { + assert(a.size() == b.size()); + + double dot = 0; + + for (int i = 0; i < a.size(); i++) { + dot += a.at(i) * b.at(i); + } + + double dot2 = sqrt(dot); + return dot2; +} + +double songdb::weight_badness_inner_product(const std::vector ¤t_badness, const std::vector &song_goodness) { + return dot_product(current_badness,song_goodness); +} + +std::vector songdb::update_badness(std::vector old_badness, std::vector song_goodness) { + auto new_badness = old_badness; + for (int i = 0; i < old_badness.size(); i++) { + new_badness[i] = new_badness[i] - song_goodness[i]; + } + + return new_badness; +} + +/** + * The returned base weight vector has the songs in the sorted order, + * the weights field is the badness vector used for the next song in the + * list. + * + * /param num: the number songs to return + */ +songdb::base_weight_vector songdb::get_top_songs(songdb::base_weight_vector input, std::vector starting_badness, int num) { + spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); + + if (num > input.song_order.size()) { + num = input.song_order.size(); + } + + base_weight_vector result {}; + auto current_badness = starting_badness; + result.person_order = input.person_order; + + struct score { + int64_t song; + double score; + std::vector base_weight; + }; + + + // create scores vector + std::vector scores {}; + for (int i = 0; i < input.song_order.size(); i++) { + + scores.push_back({input.song_order.at(i), + 0, input.weights.at(i)}); + } + + for (int i = 0; i < num; i++) { + // Compute scores based on badness + for (int j = 0; j < scores.size(); j++) { + scores[j].score = weight_badness_inner_product(current_badness, scores[j].base_weight); + } + + // sort scores + std::sort(scores.rbegin(), scores.rend(), + [](const score &a, const score &b) + { + return a.score > b.score; + } + ); + + // chose the song with the best score + auto chosen = scores.at(scores.size() - 1); + result.song_order.push_back(chosen.song); + + // update badness vector + current_badness = update_badness(current_badness, chosen.base_weight); + result.weights.push_back(current_badness); + + // run algorithm again on the subset not containing the chosen + // score, with the updated badness vector + scores.pop_back(); + } + + return result; +} + +std::string songdb::get_top_5_songs(int64_t telegram_group) { + spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); + int64_t song_list = get_song_list_id(telegram_group); + auto songs = get_base_weights(song_list); + + if (songs.weights.size() == 0) { + return {}; + } + + std::vector starting_badness {}; + for (int i = 0; i < songs.person_order.size(); i++) { + starting_badness.push_back(1); + } + + auto chosen = get_top_songs(songs, starting_badness, 5); + + std::string slist = "Top 5 Songs:\n\n"; + for (int i = 0; i < chosen.song_order.size(); i++) { + int64_t songid = chosen.song_order.at(i); + auto song = get_song(songid); + if (!song) { + throw std::runtime_error("Invalid state achieved."); + } + + slist += song->name; + slist += " "; + slist += song->artist; + slist += "\n"; + + } + return slist; +} + +std::vector songdb::generate_track_list(int64_t song_list) { + auto base_weights = get_base_weights(song_list); + + for (int i = 0; i < base_weights.song_order.size(); i++) { + spdlog::info("song {} nppl {}", base_weights.song_order[i], base_weights.weights.size()); + } + + std::vector retlist {}; + return retlist; + } + +songdb::songdb(std::string filepath): filepath(filepath) { + int err = sqlite3_open(filepath.c_str(), &db); + if (err) { + std::cout << "Failed to open database: " << sqlite3_errmsg(db); + exit(1); + } + + setup_tables(); +} + +songdb::~songdb () { + sqlite3_close(db); +} diff --git a/src/songdb.h b/src/songdb.h new file mode 100644 index 0000000..b96be78 --- /dev/null +++ b/src/songdb.h @@ -0,0 +1,70 @@ +#pragma once + +#include "sqlite3.h" +#include +#include +#include +#include +#include + +class songdb { + std::string filepath; + sqlite3 *db; + + public: + + struct runtime_vals { }; + runtime_vals runtime_data; + + protected: + + enum error_codes { + NOT_FOUND, + ALREADY_ADDED + }; + + static int callback(void *valmap, int argc, char **argv, char **azColName); + bool setup_tables(); + int check_error(int rc); + + public: + + struct track_entry { + int64_t id; + std::string name; + std::string artist; + std::optional spotify_id; + + }; + + struct vote { + int song; + int list; + int64_t user; + double value; + }; + + struct base_weight_vector { + std::vector person_order; + std::vector song_order; + std::vector> weights; + }; + + void create_new_list(int64_t group_id); + int64_t get_song_list_id(int64_t group_id); + std::optional get_song(int64_t id); + std::optional get_song(std::string name, std::string artist); + std::optional get_song(std::string name, std::string artist, std::optional spotify_id); + std::optional insert_song(std::string name, std::string artist, std::optional spotify_id); + bool insert_vote(int64_t user, int64_t group, int value, int64_t songid); + std::vector get_votes_list(int64_t song_list); + base_weight_vector get_base_weights (int64_t song_list); + double dot_product(const std::vector &a, const std::vector &b); + double weight_badness_inner_product(const std::vector ¤t_badness, const std::vector &song_goodness); + std::vector update_badness(std::vector old_badness, std::vector song_goodness); + base_weight_vector get_top_songs(base_weight_vector input, std::vector starting_badness, int num); + std::string get_top_5_songs(int64_t telegram_group); + std::vector generate_track_list(int64_t song_list); + songdb(std::string filepath); + ~songdb(); +}; diff --git a/src/spotify.cpp b/src/spotify.cpp new file mode 100644 index 0000000..591d0cd --- /dev/null +++ b/src/spotify.cpp @@ -0,0 +1,114 @@ +#include "spotify.h" +#include "cpr/cprtypes.h" +#include "cpr/parameters.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "Base64.hpp" +#include + +using json = nlohmann::json; + +void spotify::verify_logged_in() { + /* + * + * Would be a good idea to just requrst a page to chek. but the /me + * endpoint is for logged in users i guess? and this iis jut an app. + * + cpr::Response r = cpr::Get(cpr::Url{API_NAME_BASE + "v1/me"},cpr::Header{{"Authorization", "Bearer " + access_token}}); + std::istringstream isj {r.text}; + json auth_response; + try { + isj >> auth_response; + } catch (json::exception & e) { + std::cout << e.id << r.text; + } + + if (auth_response.count("display_name")) { + std::cout << "Logged in as " << auth_response["display_name"] << std::endl << auth_response["href"] << std::endl; + } else { + std::cout << "Failed to log in" << std::endl << auth_response.dump(4) << std::endl; + } + */ +} + +spotify::spotify(std::string access_token) : access_token(access_token) { + auth_header = cpr::Header{{"Authorization", "Bearer " + access_token}}; + verify_logged_in(); +} + +spotify::spotify(std::string client_id, std::string client_secret) { + auto ascii_token = client_id + ":" + client_secret; + + size_t buf_length = base64::get_encoded_length(ascii_token.length()); + auto buf = std::make_unique(buf_length); + std::string auth_token; + + base64::encode((uint8_t *)(ascii_token.c_str()), ascii_token.length(), buf.get(), buf_length); + auth_token = std::string {(char *)buf.get(), buf_length}; + + cpr::Response r = cpr::Post(cpr::Url{"https://accounts.spotify.com/api/token"},cpr::Header{{"Authorization", "Basic " + auth_token}}, cpr::Parameters{{"grant_type", "client_credentials"}}); + + std::istringstream isj {r.text}; + + json auth_response; + if (r.status_code == 200) { + isj >> auth_response; + } else { + spdlog::error("Login error {} {}",r.status_code, r.status_line); + } + + + if (auth_response.count("access_token")) { + access_token = auth_response["access_token"]; + std::cout << "Successfully logged into spotify" << std::endl << "Access token: " << access_token << std::endl; + } else { + std::cout << "Unable to log into spotify:" << std::endl << auth_response.dump(4) << std::endl; + exit (1); + } + + auth_header = cpr::Header{{"Authorization", "Bearer " + access_token}}; + verify_logged_in(); +} + +std::optional spotify::get_track(std::string track_id) { + + auto r = cpr::Get(cpr::Url{API_NAME_BASE + "v1/tracks/" + track_id}, auth_header); + if (r.status_code == 200) { + std::istringstream isj {r.text}; + json info ; + try { + isj >> info; + } catch (json::exception & e) { + std::cout << "Json error" << std::endl; + } + return info; + } else { + std::cout << r.text << std::endl; + return {}; + } +} + +/* + * Parse track link from a spotify url like: + * + * https://open.spotify.com/track/4UO1pfxi5fDbxshrwwznJ2?si=BtN9Yn_JQXSHGUa4CEZKvQ&utm_source=copy-link + * + * + */ +std::optional spotify::track_id_from_link(std::string link) { + const std::string start = "spotify.com/track/"; + auto f = link.find(start); + if (f == std::string::npos) { + return {}; + } + + auto end = link.find("?", f); + auto begin = f + start.length(); + return link.substr(begin, end - begin); +} diff --git a/src/spotify.h b/src/spotify.h new file mode 100644 index 0000000..dc8a9a4 --- /dev/null +++ b/src/spotify.h @@ -0,0 +1,31 @@ +#pragma once + +#include "spotify.h" +#include "cpr/cprtypes.h" +#include "cpr/parameters.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "Base64.hpp" +#include + +using json = nlohmann::json; + +class spotify { + + const std::string API_NAME_BASE = "https://api.spotify.com/"; + std::string access_token; + cpr::Header auth_header; + + public: + void verify_logged_in(); + spotify(std::string access_token); + spotify(std::string client_id, std::string client_secret); + std::optional get_track(std::string track_id) ; + std::optional track_id_from_link(std::string link); +}; diff --git a/src/telegram_bot.cpp b/src/telegram_bot.cpp new file mode 100644 index 0000000..6084a82 --- /dev/null +++ b/src/telegram_bot.cpp @@ -0,0 +1,184 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "util.h" +#include "spotify.h" +#include "songdb.h" + +using namespace TgBot; +using json = nlohmann::json; + +int main() { + spdlog::set_level(spdlog::level::info); // Set global log level to debug + spdlog::enable_backtrace(32); + + char *teletok = getenv("TELEGRAM_TOKEN"); + + // needed if file not exist + char *spotid = getenv("SPOTIFY_ID"); + char *spotsecret = getenv("SPOTIFY_SECRET"); + char *spotaccess_token = getenv("SPOTIFY_TOKEN"); + + if (!teletok) { + std::cout << "Need to set environment variable TELEGRAM_TOKEN" << std::endl; + exit(1); + } + + if (!spotaccess_token) { + if (!spotid) { + std::cout << "Need to set environment variable SPOTIFY_ID or SPOTIFY_TOKEN" << std::endl; + exit(1); + } + if (!spotsecret) { + std::cout << "Need to set environment variable SPOTIFY_SECRET of SPOTIFY_TOKEN" << std::endl; + exit(1); + } + } + + spotify *s; + if (spotaccess_token) + s = new spotify(spotaccess_token); + else + s = new spotify(spotid, spotsecret); + + signal(SIGINT, [](int s) { + spdlog::info("Shutting down..."); + exit(0); + }); + + songdb data {"test.db"}; + std::string teletoken {teletok}; + Bot bot(teletoken); + + InlineKeyboardMarkup::Ptr keyboard(new InlineKeyboardMarkup); + std::vector row0; + + for(int i = 0; i <= 4; i++) { + InlineKeyboardButton::Ptr btn(new InlineKeyboardButton); + btn->text = std::to_string(i); + btn->callbackData= std::to_string(i); + row0.push_back(btn); + } + + keyboard->inlineKeyboard.push_back(row0); + + bot.getEvents().onCallbackQuery([&bot, &keyboard, &data](CallbackQuery::Ptr query) { + if ((query->data == "1") || (query->data == "2") || (query->data == "3") || (query->data == "4") || (query->data == "0")) { + std::istringstream is {query->data}; + int value; + is >> value; + + std::string songidflag = "songid:"; + auto a = query->message->text.find(songidflag); + auto b = query->message->text.find("\n", a); + + if (a == std::string::npos || b == std::string::npos) { + spdlog::error("Parse songid"); + spdlog::dump_backtrace(); + return; + } + a += songidflag.length(); + + std::istringstream is2 {query->message->text.substr(a, b - a)}; + int64_t songid; + is2 >> songid; + auto song = data.get_song(songid); + if (!song) { + spdlog::error ("bad song id"); + } + + data.insert_vote(query->from->id, query->message->chat->id, value, songid); + } + }); + + bot.getEvents().onCommand("add", [&bot, &keyboard, &data, s](Message::Ptr message) { + std::string title; + std::string artist; + int songid; + + if (message->text.find("spotify.com") != std::string::npos) { + std::string link = util::trim_whitespace(message->text.substr(message->text.find("add") + 3)); + auto resp = s->track_id_from_link(link); + if (!resp) { + bot.getApi().sendMessage(message->chat->id, "Sorry, I don't understand that link."); + return; + } + + auto spot_resp = s->get_track(*resp); + + if (!spot_resp) { + bot.getApi().sendMessage(message->chat->id, "Sorry, I cannot find that track in spotify."); + return; + } + + json track_data = *spot_resp; + + title = track_data["name"]; + artist = track_data["artists"][0]["name"]; + auto song = data.insert_song(title, artist, *resp); + songid = song->id; + } else { + title = util::trim_whitespace(message->text.substr(message->text.find("add") + 3)); + artist = ""; + auto song = data.insert_song(title, artist, {}); + songid = song->id; + } + + std::string response = "Added song: " + title; + if (artist != "") + response += ", by " + artist; + + response += "\n\n"; + std::ostringstream os; + os << songid; + + response += "songid:" + os.str() + "\n\r\n\r"; + response += "Everyone, please rate how well you know this song /5"; + + bot.getApi().sendMessage(message->chat->id, response, false, 0, keyboard, "Markdown"); + }); + + bot.getEvents().onCommand("vote", [&bot](Message::Ptr message) { + bot.getApi().sendMessage(message->chat->id, "Hi!"); + }); + + + bot.getEvents().onCommand("start", [&bot, &data](Message::Ptr message) { + bot.getApi().sendMessage(message->chat->id, "Hi!"); + }); + + bot.getEvents().onCommand("list", [&bot, &data](Message::Ptr message) { + try { + std::string response = data.get_top_5_songs(message->chat->id); + bot.getApi().sendMessage(message->chat->id, response); + } catch (std::exception const &e) { + spdlog::error("exp: {}", e.what()); + spdlog::dump_backtrace(); + } + }); + + try { + printf("Bot username: %s\n", bot.getApi().getMe()->username.c_str()); + bot.getApi().deleteWebhook(); + TgLongPoll longPoll(bot); + while (true) { + printf("Long poll started\n"); + longPoll.start(); + } + } catch (std::exception& e) { + printf("error: %s\n", e.what()); + } + + return 0; +} + diff --git a/src/util.cpp b/src/util.cpp new file mode 100644 index 0000000..bc60f6b --- /dev/null +++ b/src/util.cpp @@ -0,0 +1,42 @@ +#include "util.h" +#include +#include +#include +#include +#include +#include +#include + +namespace util { + +std::string +trim_whitespace(std::string s) +{ + int ff = s.find_first_not_of(" \n\t"); + int ll = s.find_last_not_of(" \n\t"); + return s.substr(ff, ll - ff + 1); +} + + + +std::string +read_file(std::string const &fpath) +{ + std::ostringstream sstr; + std::ifstream in (fpath); + sstr << in.rdbuf(); + return sstr.str(); +} + +void write_file(std::string const &fpath, std::string const &content) { + std::fstream s; + s.open(fpath, std::ios_base::out); + if (!s.is_open()) { + std::cerr << "Error: failed to open file "<< fpath; + return; + } + s << content; + s.close(); +} + +}; diff --git a/src/util.h b/src/util.h new file mode 100644 index 0000000..9909afd --- /dev/null +++ b/src/util.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace util { + std::string trim_whitespace(std::string s); + std::string read_file(std::string const &fpath); + void write_file(std::string const &fpath, std::string const &content); +} diff --git a/telegram_bot.cpp b/telegram_bot.cpp deleted file mode 100644 index 452338b..0000000 --- a/telegram_bot.cpp +++ /dev/null @@ -1,1070 +0,0 @@ -#include "cpr/cprtypes.h" -#include "cpr/parameters.h" -#include "tgbot/net/BoostHttpOnlySslClient.h" -#include "tgbot/net/HttpParser.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "Base64.hpp" -#include -#include -#include -#include -#include -#include "sqlite3.h" -#include -#include -#include -#include - -using namespace TgBot; -using json = nlohmann::json; - - -namespace util { - -std::string -trim_whitespace(std::string s) -{ - int ff = s.find_first_not_of(" \n\t"); - int ll = s.find_last_not_of(" \n\t"); - return s.substr(ff, ll - ff + 1); -} - - - -std::string -read_file(std::string const &fpath) -{ - std::ostringstream sstr; - std::ifstream in (fpath); - sstr << in.rdbuf(); - return sstr.str(); -} - -void write_file(std::string const &fpath, std::string const &content) { - std::fstream s; - s.open(fpath, std::ios_base::out); - if (!s.is_open()) { - std::cerr << "Error: failed to open file "<< fpath; - return; - } - s << content; - s.close(); -} - -}; - -namespace kare { - -class spotify { - - const std::string API_NAME_BASE = "https://api.spotify.com/"; - std::string access_token; - cpr::Header auth_header; - - - public: - - void - verify_logged_in() - { - /* - * - * Would be a good idea to just requrst a page to chek. but the /me - * endpoint is for logged in users i guess? and this iis jut an app. - * - cpr::Response r = cpr::Get(cpr::Url{API_NAME_BASE + "v1/me"},cpr::Header{{"Authorization", "Bearer " + access_token}}); - std::istringstream isj {r.text}; - json auth_response; - try { - isj >> auth_response; - } catch (json::exception & e) { - std::cout << e.id << r.text; - } - - if (auth_response.count("display_name")) { - std::cout << "Logged in as " << auth_response["display_name"] << std::endl << auth_response["href"] << std::endl; - } else { - std::cout << "Failed to log in" << std::endl << auth_response.dump(4) << std::endl; - } - */ - } - - spotify(std::string access_token) : access_token(access_token) - { - auth_header = cpr::Header{{"Authorization", "Bearer " + access_token}}; - verify_logged_in(); - } - - spotify (std::string client_id, std::string client_secret) - { - auto ascii_token = client_id + ":" + client_secret; - - size_t buf_length = base64::get_encoded_length(ascii_token.length()); - auto buf = std::make_unique(buf_length); - std::string auth_token; - - base64::encode((uint8_t *)(ascii_token.c_str()), ascii_token.length(), buf.get(), buf_length); - auth_token = std::string {(char *)buf.get(), buf_length}; - - cpr::Response r = cpr::Post(cpr::Url{"https://accounts.spotify.com/api/token"},cpr::Header{{"Authorization", "Basic " + auth_token}}, cpr::Parameters{{"grant_type", "client_credentials"}}); - - std::istringstream isj {r.text}; - - json auth_response; - if (r.status_code == 200) { - isj >> auth_response; - } else { - spdlog::error("Login error {} {}",r.status_code, r.status_line); - } - - - if (auth_response.count("access_token")) { - access_token = auth_response["access_token"]; - std::cout << "Successfully logged into spotify" << std::endl << "Access token: " << access_token << std::endl; - } else { - std::cout << "Unable to log into spotify:" << std::endl << auth_response.dump(4) << std::endl; - exit (1); - } - - auth_header = cpr::Header{{"Authorization", "Bearer " + access_token}}; - verify_logged_in(); - } - - std::optional - get_track(std::string track_id) - { - - auto r = cpr::Get(cpr::Url{API_NAME_BASE + "v1/tracks/" + track_id}, auth_header); - if (r.status_code == 200) { - std::istringstream isj {r.text}; - json info ; - try { - isj >> info; - } catch (json::exception & e) { - std::cout << "Json error" << std::endl; - } - return info; - } else { - std::cout << r.text << std::endl; - return {}; - } - - } - - /* - * Parse track link from a spotify url like: - * - * https://open.spotify.com/track/4UO1pfxi5fDbxshrwwznJ2?si=BtN9Yn_JQXSHGUa4CEZKvQ&utm_source=copy-link - * - * - */ - std::optional - track_id_from_link(std::string link) - { - const std::string start = "spotify.com/track/"; - auto f = link.find(start); - if (f == std::string::npos) { - return {}; - } - - auto end = link.find("?", f); - auto begin = f + start.length(); - return link.substr(begin, end - begin); - } - -}; - - -class songdb { - /* need one table for every chat: keep it all in memory? */ - - std::string filepath; - sqlite3 *db; - - public: - - struct runtime_vals { - - }; - - runtime_vals runtime_data; - - protected: - - enum error_codes { - NOT_FOUND, - ALREADY_ADDED - }; - - static int - callback(void *valmap, int argc, char **argv, char **azColName) - { - std::map *values = (std::map *)valmap; - - for (int i = 0; i < argc; i++) { - values->insert({azColName[i], argv[i]}); - } - return 0; - } - - bool - setup_tables() - { - char *errmsg; - int err; - - /* - * songlists: [list id (autoincrement int), (int64) telegram group id, (string) list name ] - * - * songs : [songid int autoincrement, song name NOT NULL, song artist NULL, Song spotify ID (string) ] - * - * votes: [song id foreign key, list id foreign key, telegram user id, int rating ] - * - */ - - std::string create_songlist_table = - "CREATE TABLE IF NOT EXISTS songlists (" - "id INTEGER PRIMARY KEY AUTOINCREMENT, " - "groupid INT NOT NULL, " - "name VARCHAR(200) NOT NULL " - ");"; - - std::string create_votes_table = - "CREATE TABLE IF NOT EXISTS votes (" - "song INT NOT NULL, " - "list INT NOT NULL, " - "user INT NOT NULL, " - "rating NOT NULL, " - " FOREIGN KEY (song) REFERENCES tracks (id)," - " FOREIGN KEY (list) REFERENCES songlists (id)," - " PRIMARY KEY (song, list, user)" - ");"; - - std::string create_tracks_table = - "CREATE TABLE IF NOT EXISTS tracks (" - "id INTEGER PRIMARY KEY AUTOINCREMENT, " - "name VARCHAR(200) NOT NULL, " - "artist VARCHAR(200), " - "spotifyid VARCHAR(200) " - ");"; - - err = sqlite3_exec(db, create_songlist_table.c_str(), NULL, 0, &errmsg); - - if (err != SQLITE_OK) { - std::cout << "SQLite Error " << create_songlist_table << "\n" << errmsg << std::endl; - exit (1); - } - - err = sqlite3_exec(db, create_tracks_table.c_str(), NULL, 0, &errmsg); - - if (err != SQLITE_OK) { - std::cout << "SQLite Error: " << create_tracks_table << "\n" << errmsg << std::endl; - exit (1); - } - - - err = sqlite3_exec(db, create_votes_table.c_str(), NULL, 0, &errmsg); - - if (err != SQLITE_OK) { - std::cout << "SQLite Error: " << create_votes_table << "\n" << errmsg << std::endl; - exit (1); - } - - if (errmsg) - sqlite3_free(errmsg); - return false; - } - - int check_error(int rc) { - if (rc != SQLITE_OK) { - spdlog::error("SQLite: {}", sqlite3_errmsg(db)); - spdlog::dump_backtrace(); - exit(1); - } - return 0; - } - - - - public: - - void create_new_list(int64_t group_id) { - spdlog::debug("create_new_list {}", group_id); - - sqlite3_stmt *statement; - int err; - std::string insert_query = "INSERT INTO songlists (groupid, name) VALUES (? , \"default\");"; - err = sqlite3_prepare_v2(db, insert_query.c_str(), insert_query.length(), &statement, NULL); - check_error(err); - err = sqlite3_bind_int64(statement, 1, group_id); - check_error(err); - err = sqlite3_step(statement); - - if (err != SQLITE_DONE) { - check_error(err); - } - - sqlite3_finalize(statement); - - } - - /* Retrieve or create the song list id for a group. - * - * As is, song lists are unique to telegram groups - */ - int64_t get_song_list_id(int64_t group_id) { - spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); - - std::string list_query = "SELECT * FROM songlists WHERE groupid = ?;"; - /* A constraint of this implementation is that every group can only have - * one list - */ - int err; - - - sqlite3_stmt *statement; - sqlite3_prepare_v2(db, list_query.c_str(), list_query.length(), &statement, NULL); - err = sqlite3_bind_int64(statement, 1, group_id); - check_error(err); - err = sqlite3_step(statement); - int list_id; - - if (err == SQLITE_ROW && sqlite3_column_count(statement)) { - - list_id = sqlite3_column_int(statement, 0); - - err = sqlite3_step(statement); - if (err != SQLITE_DONE) { - // should only ever be 1 entry - spdlog::error("Should only be one list per group?"); - spdlog::dump_backtrace(); - exit(1); - } - - - } else if (err == SQLITE_DONE) { - // create the list - sqlite3_finalize(statement); - create_new_list(group_id); - - err = sqlite3_prepare_v2(db, list_query.c_str(), list_query.length(), &statement, NULL); - check_error(err); - err = sqlite3_bind_int64(statement, 1, group_id); - check_error(err); - - err = sqlite3_step(statement); - - if (err == SQLITE_ROW){ - list_id = sqlite3_column_int64(statement, 0); - } - else if (err != SQLITE_DONE) { - - check_error(err); - } - } else { - check_error(err); - } - - - return list_id; - - } - - struct track_entry { - int64_t id; - std::string name; - std::string artist; - std::optional spotify_id; - - }; - - std::optional get_song(int64_t id) { - std::string check_exist = "SELECT * FROM tracks WHERE id = ?;"; - - int err; - sqlite3_stmt *statement; - err = sqlite3_prepare_v2(db, check_exist.c_str(), check_exist.length(), &statement, NULL); - check_error(err); - err = sqlite3_bind_int64(statement, 1, id); - check_error(err); - - err = sqlite3_step(statement); - if (err == SQLITE_ROW) { - - int64_t id = sqlite3_column_int64(statement, 0); - std::string name {(char *)sqlite3_column_text(statement, 1)}; - std::string artist {(char *)sqlite3_column_text(statement, 2)}; - char * spotid = (char *)sqlite3_column_text(statement, 3); - - if (spotid) { - track_entry e {id, name, artist, std::string(spotid)}; - sqlite3_finalize(statement); - return e; - } - - track_entry e {id, name, artist}; - sqlite3_finalize(statement); - return e; - } - sqlite3_finalize(statement); - return {}; - } - - std::optional get_song(std::string name, std::string artist) { - return get_song(name, artist, {}); - } - - std::optional get_song(std::string name, std::string artist, std::optional spotify_id) { - - spdlog::debug("get_song"); - - std::string check_exist_spot = "SELECT * FROM tracks WHERE spotifyid = ?;"; - std::string check_exist_nospot = "SELECT * FROM tracks WHERE name = ? AND artist = ?;"; - - int err; - sqlite3_stmt *statement; - - /* check whether the song exists */ - if (spotify_id) { - err = sqlite3_prepare_v2(db, check_exist_spot.c_str(), check_exist_spot.length(), &statement, NULL); - check_error(err); - - err = sqlite3_bind_text(statement, 1, spotify_id->c_str(), spotify_id->length(), NULL); - check_error(err); - } else { - err = sqlite3_prepare_v2(db, check_exist_nospot.c_str(), check_exist_nospot.length(), &statement, NULL); - check_error(err); - - err = sqlite3_bind_text(statement, 1, name.c_str(), name.length(), NULL); - check_error(err); - - err = sqlite3_bind_text(statement, 2, artist.c_str(), artist.length(), NULL); - check_error(err); - } - - err = sqlite3_step(statement); - if (err == SQLITE_ROW) { - - int64_t id = sqlite3_column_int64(statement, 0); - std::string name {(char *)sqlite3_column_text(statement, 1)}; - std::string artist {(char *)sqlite3_column_text(statement, 2)}; - char * spotid = (char *)sqlite3_column_text(statement, 3); - - if (spotid) { - track_entry e {id, name, artist, std::string(spotid)}; - sqlite3_finalize(statement); - return e; - } - track_entry e {id, name, artist}; - sqlite3_finalize(statement); - return e; - spdlog::info("Entry exists."); - } else if (err != SQLITE_DONE) { - check_error(err); - sqlite3_finalize(statement); - return {}; - } - - sqlite3_finalize(statement); - return {}; - - } - - std::optional insert_song(std::string name, std::string artist, std::optional spotify_id) { - - spdlog::debug("insert_song"); - - std::string ins_query_spot = "INSERT INTO tracks (name, artist, spotifyid) VALUES(?, ?, ?);"; - std::string ins_query_nospot = "INSERT INTO tracks (name, artist) VALUES(?, ?);"; - -// int list_id = get_song_list_id(group_id); - sqlite3_stmt *statement; - int err; - - auto e = get_song(name, artist, spotify_id); - if (e) { - return e; - } - - if (spotify_id) { - err = sqlite3_prepare_v2(db, ins_query_spot.c_str(), ins_query_spot.length(), &statement, NULL); - } else { - err = sqlite3_prepare_v2(db, ins_query_nospot.c_str(), ins_query_nospot.length(), &statement, NULL); - } - check_error(err); - - err = sqlite3_bind_text(statement, 1, name.c_str(), name.length(), NULL); - check_error(err); - err = sqlite3_bind_text(statement, 2, artist.c_str(), artist.length(), NULL); - check_error(err); - - if (spotify_id) { - err = sqlite3_bind_text(statement, 3, spotify_id->c_str(), spotify_id->length(), NULL); - check_error(err); - } - - err = sqlite3_step(statement); - if (err != SQLITE_DONE) { - sqlite3_finalize(statement); - spdlog::warn("Failed insertion Sqlite: {}", sqlite3_errstr(err)); - spdlog::dump_backtrace(); - sqlite3_finalize(statement); - return {}; - } - - sqlite3_finalize(statement); - return get_song(name, artist, spotify_id); - } - - bool - insert_vote(int64_t user, int64_t group, int value, int64_t songid) - { - spdlog::debug("insert_vote"); - auto song = get_song(songid); - if (!song) { - spdlog::error("Failed to add vote, couldnt find song id: {}", songid); - return true; - } - - int64_t list = get_song_list_id(group); - - - /* find existing vote */ - - - std::string ins_query = "INSERT OR REPLACE INTO votes (song, list, user, rating) " - "VALUES (?, ?, ?, ?);"; - - sqlite3_stmt *statement; - int err; - err = sqlite3_prepare_v2(db, ins_query.c_str(), ins_query.length(), &statement, NULL); - check_error(err); - err = sqlite3_bind_int64(statement, 1, song->id); - check_error(err); - err = sqlite3_bind_int64(statement, 2, list); - check_error(err); - err = sqlite3_bind_int64(statement, 3, user); - check_error(err); - err= sqlite3_bind_int64(statement, 4, value); - check_error(err); - - err = sqlite3_step(statement); - if (err != SQLITE_DONE) { - check_error(err); - sqlite3_finalize(statement); - return true; - } - - sqlite3_finalize(statement); - return false; - } - - - struct vote { - int song; - int list; - int64_t user; - double value; - }; - - std::vector - get_votes_list(int64_t song_list) - { - std::string query = "SELECT * FROM votes WHERE list = ?"; - - sqlite3_stmt *statement; - int err; - err = sqlite3_prepare_v2(db, query.c_str(), query.length(), &statement, NULL); - check_error(err); - err = sqlite3_bind_int64(statement, 1, song_list); - check_error(err); - - - std::vector votes; - - err = sqlite3_step(statement); - while (err == SQLITE_ROW) { - int song = sqlite3_column_int(statement, 0); - int list = sqlite3_column_int(statement, 1); - int user = sqlite3_column_int(statement, 2); - int value= sqlite3_column_int(statement, 3); - - votes.push_back({song, list,user,value}); - err = sqlite3_step(statement); - } - - if (err != SQLITE_DONE) { - check_error(err); - } - - sqlite3_finalize(statement); - return votes; - }; - - struct base_weight_vector { - std::vector person_order; - std::vector song_order; - std::vector> weights; - }; - - //std::map> - base_weight_vector - get_base_weights (int64_t song_list) - { - spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); - std::vector list = get_votes_list(song_list); - std::set chat_members; - - // {song, {user, vote}} - std::map> vote_info {}; - - for (auto v : list) { - chat_members.insert(v.user); - - vote_info[v.song] = {{v.user, v.value}}; - } - - for (auto v : vote_info) { - // Insert zero values for members who have not voted - for (auto m : chat_members) { - if (!v.second.count(m)) { - v.second.insert({m, 0.0}); - } - } - - // normalise weightings - double total = 0; - for (auto m : v.second) { - total += m.second; - } - - for (auto m : v.second) { - v.second[m.first] = m.second / total; - } - } - - if (vote_info.size() == 0) { - return {}; - } - - /* turn it into a nice easy to use vector - * Relying on the fact maps are sorted and things will always be in the - * same order. - */ - base_weight_vector v; - - for (auto song : vote_info) { - v.song_order.push_back(song.first); - } - - auto a = vote_info.begin(); - for (auto user : a->second) { - v.person_order.push_back(user.first); - } - - for (auto song: vote_info) { - std::vector user_votes; - for (auto user : song.second) { - user_votes.push_back(user.second); - } - - assert(user_votes.size() == v.person_order.size()); - - v.weights.push_back(user_votes); - } - - return v; - - } - - - double - dot_product(const std::vector &a, const std::vector &b) - { - assert(a.size() == b.size()); - - double dot = 0; - - for (int i = 0; i < a.size(); i++) { - dot += a.at(i) * b.at(i); - } - - double dot2 = sqrt(dot); - return dot2; - } - - - double - weight_badness_inner_product(const std::vector ¤t_badness, const std::vector &song_goodness) - { - return dot_product(current_badness,song_goodness); - } - - std::vector - update_badness(std::vector old_badness, std::vector song_goodness) - { - auto new_badness = old_badness; - for (int i = 0; i < old_badness.size(); i++) { - new_badness[i] = new_badness[i] - song_goodness[i]; - } - - return new_badness; - } - - /** - * The returned base weight vector has the songs in the sorted order, - * the weights field is the badness vector used for the next song in the - * list. - * - * /param num: the number songs to return - */ - base_weight_vector - get_top_songs(base_weight_vector input, std::vector starting_badness, int num) - { - spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); - - if (num > input.song_order.size()) { - num = input.song_order.size(); - } - - base_weight_vector result {}; - auto current_badness = starting_badness; - result.person_order = input.person_order; - - struct score { - int64_t song; - double score; - std::vector base_weight; - }; - - - // create scores vector - std::vector scores {}; - for (int i = 0; i < input.song_order.size(); i++) { - - scores.push_back({input.song_order.at(i), - 0, input.weights.at(i)}); - } - - for (int i = 0; i < num; i++) { - // Compute scores based on badness - for (int j = 0; j < scores.size(); j++) { - scores[j].score = weight_badness_inner_product(current_badness, scores[j].base_weight); - } - - // sort scores - std::sort(scores.rbegin(), scores.rend(), - [](const score &a, const score &b) - { - return a.score > b.score; - } - ); - - // chose the song with the best score - auto chosen = scores.at(scores.size() - 1); - result.song_order.push_back(chosen.song); - - // update badness vector - current_badness = update_badness(current_badness, chosen.base_weight); - result.weights.push_back(current_badness); - - // run algorithm again on the subset not containing the chosen - // score, with the updated badness vector - scores.pop_back(); - } - - return result; - } - - std::string get_top_5_songs(int64_t telegram_group) { - spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); - int64_t song_list = get_song_list_id(telegram_group); - auto songs = get_base_weights(song_list); - - if (songs.weights.size() == 0) { - return {}; - } - - std::vector starting_badness {}; - for (int i = 0; i < songs.person_order.size(); i++) { - starting_badness.push_back(1); - } - - auto chosen = get_top_songs(songs, starting_badness, 5); - - std::string slist = "Top 5 Songs:\n\n"; - for (int i = 0; i < chosen.song_order.size(); i++) { - int64_t songid = chosen.song_order.at(i); - auto song = get_song(songid); - if (!song) { - throw std::runtime_error("Invalid state achieved."); - } - - slist += song->name; - slist += " "; - slist += song->artist; - slist += "\n"; - - } - return slist; - } - - - std::vector - generate_track_list(int64_t song_list) - { - auto base_weights = get_base_weights(song_list); - - for (int i = 0; i < base_weights.song_order.size(); i++) { - spdlog::info("song {} nppl {}", base_weights.song_order[i], base_weights.weights.size()); - } - - std::vector retlist {}; - return retlist; - } - - songdb (std::string filepath): filepath(filepath) { - int err = sqlite3_open(filepath.c_str(), &db); - if (err) { - std::cout << "Failed to open database: " << sqlite3_errmsg(db); - exit(1); - } - - setup_tables(); - } - - ~songdb () { - sqlite3_close(db); - } - - -}; -}; - -int main() { - - spdlog::set_level(spdlog::level::info); // Set global log level to debug - spdlog::enable_backtrace(32); - - char *teletok = getenv("TELEGRAM_TOKEN"); - - // needed if file not exist - char *spotid = getenv("SPOTIFY_ID"); - char *spotsecret = getenv("SPOTIFY_SECRET"); - char *spotaccess_token = getenv("SPOTIFY_TOKEN"); - - - if (!teletok) { - std::cout << "Need to set environment variable TELEGRAM_TOKEN" << std::endl; - exit(1); - } - - if (!spotaccess_token) { - if (!spotid) { - std::cout << "Need to set environment variable SPOTIFY_ID or SPOTIFY_TOKEN" << std::endl; - exit(1); - } - if (!spotsecret) { - std::cout << "Need to set environment variable SPOTIFY_SECRET of SPOTIFY_TOKEN" << std::endl; - exit(1); - } - } - - kare::songdb data {"test.db"}; - - std::string teletoken {teletok}; - - kare::spotify *s; - if (spotaccess_token) - s = new kare::spotify(spotaccess_token); - else - s = new kare::spotify(spotid, spotsecret); - - - signal(SIGINT, [](int s) { - spdlog::info("Shutting down..."); - exit(0); - }); - - - - - Bot bot(teletoken); - - - InlineKeyboardMarkup::Ptr keyboard(new InlineKeyboardMarkup); - std::vector row0; - - - InlineKeyboardButton::Ptr button5(new InlineKeyboardButton); - button5->text = "0"; - button5->callbackData= "0"; - row0.push_back(button5); - - InlineKeyboardButton::Ptr button1(new InlineKeyboardButton); - button1->text = "1"; - button1->callbackData = "1"; - row0.push_back(button1); - - InlineKeyboardButton::Ptr button2(new InlineKeyboardButton); - button2->text = "2"; - button2->callbackData = "2"; - row0.push_back(button2); - - InlineKeyboardButton::Ptr button3(new InlineKeyboardButton); - button3->text = "3"; - button3->callbackData = "3"; - row0.push_back(button3); - - InlineKeyboardButton::Ptr button4(new InlineKeyboardButton); - button4->text = "4"; - button4->callbackData = "4"; - row0.push_back(button4); - - - keyboard->inlineKeyboard.push_back(row0); - - bot.getEvents().onCallbackQuery([&bot, &keyboard, &data](CallbackQuery::Ptr query) { - - if ((query->data == "1") || (query->data == "2") || (query->data == "3") || (query->data == "4") || (query->data == "0")) { - - std::istringstream is {query->data}; - int value; - is >> value; - - std::string songidflag = "songid:"; - auto a = query->message->text.find(songidflag); - auto b = query->message->text.find("\n", a); - - if (a==std::string::npos || b==std::string::npos) { - spdlog::error("Parse songid"); - spdlog::dump_backtrace(); - return; - } - a += songidflag.length(); - - std::istringstream is2 {query->message->text.substr(a, b - a)}; - int64_t songid; - is2 >> songid; - auto song = data.get_song(songid); - if (!song) { - spdlog::error ("bad song id"); - } - - - data.insert_vote(query->from->id, query->message->chat->id, value, songid); - - - } - }); - - bot.getEvents().onCommand("add", [&bot, &keyboard, &data, s](Message::Ptr message) { - - std::string title; - std::string artist; - int songid; - - if (message->text.find("spotify.com") != std::string::npos) { - std::string link = util::trim_whitespace(message->text.substr(message->text.find("add") + 3)); - auto resp = s->track_id_from_link(link); - if (!resp) { - bot.getApi().sendMessage(message->chat->id, "Sorry, I don't understand that link."); - return; - } - - auto spot_resp = s->get_track(*resp); - - if (!spot_resp) { - bot.getApi().sendMessage(message->chat->id, "Sorry, I cannot find that track in spotify."); - return; - } - - json track_data = *spot_resp; - - title = track_data["name"]; - artist = track_data["artists"][0]["name"]; - auto song = data.insert_song(title, artist, *resp); - songid = song->id; - - } else { - title = util::trim_whitespace(message->text.substr(message->text.find("add") + 3)); - artist = ""; - auto song = data.insert_song(title, artist, {}); - songid = song->id; - } - - std::string response = "Added song: " + title; - - if (artist != "") - response += ", by " + artist; - - response += "\n\n"; - std::ostringstream os; - os << songid; - - response += "songid:" + os.str() + "\n\r\n\r"; - response += "Everyone, please rate how well you know this song /5"; - - bot.getApi().sendMessage(message->chat->id, response, false, 0, keyboard, "Markdown"); - }); - - bot.getEvents().onCommand("vote", [&bot](Message::Ptr message) { - bot.getApi().sendMessage(message->chat->id, "Hi!"); - }); - - - bot.getEvents().onCommand("start", [&bot, &data](Message::Ptr message) { - - - bot.getApi().sendMessage(message->chat->id, "Hi!"); - - }); - - bot.getEvents().onCommand("list", [&bot, &data](Message::Ptr message) { - - try { - - std::string response = data.get_top_5_songs(message->chat->id); - bot.getApi().sendMessage(message->chat->id, response); - } catch (std::exception const &e) { - spdlog::error("exp: {}", e.what()); - spdlog::dump_backtrace(); - } - - - }); - - - std::string * a = new std::string("hello world"); - - - try { - printf("Bot username: %s\n", bot.getApi().getMe()->username.c_str()); - bot.getApi().deleteWebhook(); - - TgLongPoll longPoll(bot); - while (true) { - printf("Long poll started\n"); - longPoll.start(); - } - } catch (std::exception& e) { - printf("error: %s\n", e.what()); - } - - return 0; -} -