You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
602 lines
17 KiB
602 lines
17 KiB
#include "songdb.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <istream>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <spdlog/spdlog.h>

#include "sqlite3.h"
|
|
|
/* sqlite3_exec() row callback: copies the columns of a result row into the
 * std::map<std::string, std::string> passed through `valmap`.
 *
 * Keys are column names, values are the column text. SQL NULL values arrive
 * as null pointers in argv; they are stored as empty strings because
 * constructing a std::string from a null pointer is undefined behaviour.
 * std::map::insert does not overwrite, so for multi-row results the first
 * row's value for each column is kept.
 *
 * Returns 0 so sqlite3_exec() keeps processing rows.
 */
int songdb::callback(void *valmap, int argc, char **argv, char **azColName) {
    auto *values = static_cast<std::map<std::string, std::string> *>(valmap);

    for (int i = 0; i < argc; i++) {
        const char *column = azColName[i] ? azColName[i] : "";
        const char *value = argv[i] ? argv[i] : "";
        values->insert({column, value});
    }

    return 0;
}
|
|
|
bool songdb::setup_tables() { |
|
char *errmsg; |
|
int err; |
|
|
|
/* |
|
* songlists: [list id (autoincrement int), (int64) telegram group id, (string) list name ] |
|
* |
|
* songs : [songid int autoincrement, song name NOT NULL, song artist NULL, Song spotify ID (string) ] |
|
* |
|
* votes: [song id foreign key, list id foreign key, telegram user id, int rating ] |
|
* |
|
*/ |
|
|
|
std::string create_songlist_table = |
|
"CREATE TABLE IF NOT EXISTS songlists (" |
|
"id INTEGER PRIMARY KEY AUTOINCREMENT, " |
|
"groupid INT NOT NULL, " |
|
"name VARCHAR(200) NOT NULL " |
|
");"; |
|
|
|
std::string create_votes_table = |
|
"CREATE TABLE IF NOT EXISTS votes (" |
|
"song INT NOT NULL, " |
|
"list INT NOT NULL, " |
|
"user INT NOT NULL, " |
|
"rating NOT NULL, " |
|
" FOREIGN KEY (song) REFERENCES tracks (id)," |
|
" FOREIGN KEY (list) REFERENCES songlists (id)," |
|
" PRIMARY KEY (song, list, user)" |
|
");"; |
|
|
|
std::string create_tracks_table = |
|
"CREATE TABLE IF NOT EXISTS tracks (" |
|
"id INTEGER PRIMARY KEY AUTOINCREMENT, " |
|
"name VARCHAR(200) NOT NULL, " |
|
"artist VARCHAR(200), " |
|
"spotifyid VARCHAR(200) " |
|
");"; |
|
|
|
err = sqlite3_exec(db, create_songlist_table.c_str(), NULL, 0, &errmsg); |
|
|
|
if (err != SQLITE_OK) { |
|
std::cout << "SQLite Error " << create_songlist_table << "\n" << errmsg << std::endl; |
|
exit (1); |
|
} |
|
|
|
err = sqlite3_exec(db, create_tracks_table.c_str(), NULL, 0, &errmsg); |
|
|
|
if (err != SQLITE_OK) { |
|
std::cout << "SQLite Error: " << create_tracks_table << "\n" << errmsg << std::endl; |
|
exit (1); |
|
} |
|
|
|
|
|
err = sqlite3_exec(db, create_votes_table.c_str(), NULL, 0, &errmsg); |
|
|
|
if (err != SQLITE_OK) { |
|
std::cout << "SQLite Error: " << create_votes_table << "\n" << errmsg << std::endl; |
|
exit (1); |
|
} |
|
|
|
if (errmsg) |
|
sqlite3_free(errmsg); |
|
return false; |
|
} |
|
|
|
int songdb::check_error(int rc) { |
|
if (rc != SQLITE_OK) { |
|
spdlog::error("SQLite: {}", sqlite3_errmsg(db)); |
|
spdlog::dump_backtrace(); |
|
exit(1); |
|
} |
|
return 0; |
|
} |
|
|
|
void songdb::create_new_list(int64_t group_id) { |
|
spdlog::debug("create_new_list {}", group_id); |
|
|
|
sqlite3_stmt *statement; |
|
int err; |
|
std::string insert_query = "INSERT INTO songlists (groupid, name) VALUES (? , \"default\");"; |
|
err = sqlite3_prepare_v2(db, insert_query.c_str(), insert_query.length(), &statement, NULL); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 1, group_id); |
|
check_error(err); |
|
err = sqlite3_step(statement); |
|
|
|
if (err != SQLITE_DONE) { |
|
check_error(err); |
|
} |
|
|
|
sqlite3_finalize(statement); |
|
|
|
} |
|
|
|
/* Retrieve or create the song list id for a group. |
|
* |
|
* As is, song lists are unique to telegram groups |
|
*/ |
|
int64_t songdb::get_song_list_id(int64_t group_id) { |
|
spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); |
|
|
|
std::string list_query = "SELECT * FROM songlists WHERE groupid = ?;"; |
|
/* A constraint of this implementation is that every group can only have |
|
* one list |
|
*/ |
|
int err; |
|
|
|
|
|
sqlite3_stmt *statement; |
|
sqlite3_prepare_v2(db, list_query.c_str(), list_query.length(), &statement, NULL); |
|
err = sqlite3_bind_int64(statement, 1, group_id); |
|
check_error(err); |
|
err = sqlite3_step(statement); |
|
int list_id; |
|
|
|
if (err == SQLITE_ROW && sqlite3_column_count(statement)) { |
|
|
|
list_id = sqlite3_column_int(statement, 0); |
|
|
|
err = sqlite3_step(statement); |
|
if (err != SQLITE_DONE) { |
|
// should only ever be 1 entry |
|
spdlog::error("Should only be one list per group?"); |
|
spdlog::dump_backtrace(); |
|
exit(1); |
|
} |
|
|
|
|
|
} else if (err == SQLITE_DONE) { |
|
// create the list |
|
sqlite3_finalize(statement); |
|
create_new_list(group_id); |
|
|
|
err = sqlite3_prepare_v2(db, list_query.c_str(), list_query.length(), &statement, NULL); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 1, group_id); |
|
check_error(err); |
|
|
|
err = sqlite3_step(statement); |
|
|
|
if (err == SQLITE_ROW){ |
|
list_id = sqlite3_column_int64(statement, 0); |
|
} |
|
else if (err != SQLITE_DONE) { |
|
|
|
check_error(err); |
|
} |
|
} else { |
|
check_error(err); |
|
} |
|
|
|
|
|
return list_id; |
|
|
|
} |
|
|
|
std::optional<songdb::track_entry> songdb::get_song(int64_t id) { |
|
std::string check_exist = "SELECT * FROM tracks WHERE id = ?;"; |
|
|
|
int err; |
|
sqlite3_stmt *statement; |
|
err = sqlite3_prepare_v2(db, check_exist.c_str(), check_exist.length(), &statement, NULL); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 1, id); |
|
check_error(err); |
|
|
|
err = sqlite3_step(statement); |
|
if (err == SQLITE_ROW) { |
|
|
|
int64_t id = sqlite3_column_int64(statement, 0); |
|
std::string name {(char *)sqlite3_column_text(statement, 1)}; |
|
std::string artist {(char *)sqlite3_column_text(statement, 2)}; |
|
char * spotid = (char *)sqlite3_column_text(statement, 3); |
|
|
|
if (spotid) { |
|
track_entry e {id, name, artist, std::string(spotid)}; |
|
sqlite3_finalize(statement); |
|
return e; |
|
} |
|
|
|
track_entry e {id, name, artist}; |
|
sqlite3_finalize(statement); |
|
return e; |
|
} |
|
sqlite3_finalize(statement); |
|
return {}; |
|
} |
|
|
|
std::optional<songdb::track_entry> songdb::get_song(std::string name, std::string artist) { |
|
return get_song(name, artist, {}); |
|
} |
|
|
|
std::optional<songdb::track_entry> songdb::get_song(std::string name, std::string artist, std::optional<std::string> spotify_id) { |
|
|
|
spdlog::debug("get_song"); |
|
|
|
std::string check_exist_spot = "SELECT * FROM tracks WHERE spotifyid = ?;"; |
|
std::string check_exist_nospot = "SELECT * FROM tracks WHERE name = ? AND artist = ?;"; |
|
|
|
int err; |
|
sqlite3_stmt *statement; |
|
|
|
/* check whether the song exists */ |
|
if (spotify_id) { |
|
err = sqlite3_prepare_v2(db, check_exist_spot.c_str(), check_exist_spot.length(), &statement, NULL); |
|
check_error(err); |
|
|
|
err = sqlite3_bind_text(statement, 1, spotify_id->c_str(), spotify_id->length(), NULL); |
|
check_error(err); |
|
} else { |
|
err = sqlite3_prepare_v2(db, check_exist_nospot.c_str(), check_exist_nospot.length(), &statement, NULL); |
|
check_error(err); |
|
|
|
err = sqlite3_bind_text(statement, 1, name.c_str(), name.length(), NULL); |
|
check_error(err); |
|
|
|
err = sqlite3_bind_text(statement, 2, artist.c_str(), artist.length(), NULL); |
|
check_error(err); |
|
} |
|
|
|
err = sqlite3_step(statement); |
|
if (err == SQLITE_ROW) { |
|
|
|
int64_t id = sqlite3_column_int64(statement, 0); |
|
std::string name {(char *)sqlite3_column_text(statement, 1)}; |
|
std::string artist {(char *)sqlite3_column_text(statement, 2)}; |
|
char * spotid = (char *)sqlite3_column_text(statement, 3); |
|
|
|
if (spotid) { |
|
track_entry e {id, name, artist, std::string(spotid)}; |
|
sqlite3_finalize(statement); |
|
return e; |
|
} |
|
track_entry e {id, name, artist}; |
|
sqlite3_finalize(statement); |
|
return e; |
|
spdlog::info("Entry exists."); |
|
} else if (err != SQLITE_DONE) { |
|
check_error(err); |
|
sqlite3_finalize(statement); |
|
return {}; |
|
} |
|
|
|
sqlite3_finalize(statement); |
|
return {}; |
|
} |
|
|
|
std::optional<songdb::track_entry> songdb::insert_song(std::string name, std::string artist, std::optional<std::string> spotify_id) { |
|
|
|
spdlog::debug("insert_song"); |
|
|
|
std::string ins_query_spot = "INSERT INTO tracks (name, artist, spotifyid) VALUES(?, ?, ?);"; |
|
std::string ins_query_nospot = "INSERT INTO tracks (name, artist) VALUES(?, ?);"; |
|
|
|
// int list_id = get_song_list_id(group_id); |
|
sqlite3_stmt *statement; |
|
int err; |
|
|
|
auto e = get_song(name, artist, spotify_id); |
|
if (e) { |
|
return e; |
|
} |
|
|
|
if (spotify_id) { |
|
err = sqlite3_prepare_v2(db, ins_query_spot.c_str(), ins_query_spot.length(), &statement, NULL); |
|
} else { |
|
err = sqlite3_prepare_v2(db, ins_query_nospot.c_str(), ins_query_nospot.length(), &statement, NULL); |
|
} |
|
check_error(err); |
|
|
|
err = sqlite3_bind_text(statement, 1, name.c_str(), name.length(), NULL); |
|
check_error(err); |
|
err = sqlite3_bind_text(statement, 2, artist.c_str(), artist.length(), NULL); |
|
check_error(err); |
|
|
|
if (spotify_id) { |
|
err = sqlite3_bind_text(statement, 3, spotify_id->c_str(), spotify_id->length(), NULL); |
|
check_error(err); |
|
} |
|
|
|
err = sqlite3_step(statement); |
|
if (err != SQLITE_DONE) { |
|
sqlite3_finalize(statement); |
|
spdlog::warn("Failed insertion Sqlite: {}", sqlite3_errstr(err)); |
|
spdlog::dump_backtrace(); |
|
sqlite3_finalize(statement); |
|
return {}; |
|
} |
|
|
|
sqlite3_finalize(statement); |
|
return get_song(name, artist, spotify_id); |
|
} |
|
|
|
bool songdb::insert_vote(int64_t user, int64_t group, int value, int64_t songid) { |
|
spdlog::debug("insert_vote"); |
|
auto song = get_song(songid); |
|
if (!song) { |
|
spdlog::error("Failed to add vote, couldnt find song id: {}", songid); |
|
return true; |
|
} |
|
|
|
int64_t list = get_song_list_id(group); |
|
|
|
|
|
/* find existing vote */ |
|
std::string ins_query = "INSERT OR REPLACE INTO votes (song, list, user, rating) " |
|
"VALUES (?, ?, ?, ?);"; |
|
|
|
sqlite3_stmt *statement; |
|
int err; |
|
err = sqlite3_prepare_v2(db, ins_query.c_str(), ins_query.length(), &statement, NULL); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 1, song->id); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 2, list); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 3, user); |
|
check_error(err); |
|
err= sqlite3_bind_int64(statement, 4, value); |
|
check_error(err); |
|
|
|
err = sqlite3_step(statement); |
|
if (err != SQLITE_DONE) { |
|
check_error(err); |
|
sqlite3_finalize(statement); |
|
return true; |
|
} |
|
|
|
sqlite3_finalize(statement); |
|
return false; |
|
} |
|
|
|
std::vector<songdb::vote> songdb::get_votes_list(int64_t song_list) { |
|
std::string query = "SELECT * FROM votes WHERE list = ?"; |
|
|
|
sqlite3_stmt *statement; |
|
int err; |
|
err = sqlite3_prepare_v2(db, query.c_str(), query.length(), &statement, NULL); |
|
check_error(err); |
|
err = sqlite3_bind_int64(statement, 1, song_list); |
|
check_error(err); |
|
|
|
|
|
std::vector<vote> votes; |
|
|
|
err = sqlite3_step(statement); |
|
while (err == SQLITE_ROW) { |
|
int song = sqlite3_column_int(statement, 0); |
|
int list = sqlite3_column_int(statement, 1); |
|
int user = sqlite3_column_int(statement, 2); |
|
int value= sqlite3_column_int(statement, 3); |
|
|
|
votes.push_back({song, list,user,value}); |
|
err = sqlite3_step(statement); |
|
} |
|
|
|
if (err != SQLITE_DONE) { |
|
check_error(err); |
|
} |
|
|
|
sqlite3_finalize(statement); |
|
return votes; |
|
}; |
|
|
|
/* Build the per-song, per-user weight matrix for a song list.
 *
 * Each song with at least one vote gets one row of weights. The row has one
 * column per user who voted anywhere in the list:
 *  - users who did not vote on a song get a weight of 0;
 *  - each song's row is divided by the sum of its ratings so that it sums to 1.
 *    A row whose ratings sum to 0 is left as-is, to avoid dividing by zero.
 *
 * song_order[i] is the song id of weights[i], and person_order[j] is the user
 * id of column j. Both are in ascending id order because they come from an
 * ordered std::map / std::set.
 * Returns an empty base_weight_vector when the list has no votes.
 */
songdb::base_weight_vector songdb::get_base_weights(int64_t song_list) {
    spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__);
    const std::vector<vote> votes = get_votes_list(song_list);
    std::set<int64_t> chat_members;

    // {song, {user, vote}}
    std::map<int64_t, std::map<int64_t, double>> vote_info;

    for (const auto &v : votes) {
        chat_members.insert(v.user);

        // Accumulate per song; assigning a fresh map here would keep only the last voter.
        vote_info[v.song][v.user] = v.value;
    }

    if (vote_info.empty()) {
        return {};
    }

    // Iterate by reference: zero-filling and normalisation must modify vote_info itself.
    for (auto &entry : vote_info) {
        std::map<int64_t, double> &user_votes = entry.second;

        // Insert zero values for members who have not voted on this song
        for (const int64_t member : chat_members) {
            user_votes.try_emplace(member, 0.0);
        }

        // normalise weightings
        double total = 0;
        for (const auto &user_vote : user_votes) {
            total += user_vote.second;
        }
        if (total != 0) {
            for (auto &user_vote : user_votes) {
                user_vote.second /= total;
            }
        }
    }

    /* turn it into a nice easy to use vector
     * Relying on the fact maps and sets are sorted, so every song row lists
     * the same users in the same order as person_order.
     */
    base_weight_vector v;

    for (const auto &entry : vote_info) {
        v.song_order.push_back(entry.first);
    }

    for (const int64_t member : chat_members) {
        v.person_order.push_back(member);
    }

    for (const auto &entry : vote_info) {
        std::vector<double> user_votes;
        user_votes.reserve(entry.second.size());
        for (const auto &user_vote : entry.second) {
            user_votes.push_back(user_vote.second);
        }

        assert(user_votes.size() == v.person_order.size());

        v.weights.push_back(std::move(user_votes));
    }

    return v;
}
|
|
|
/* Dot product of two equal-length vectors: the sum of a[i] * b[i].
 *
 * The previous version returned sqrt() of this sum. That is not a dot product,
 * and it yields NaN whenever the sum is negative. Sums can be negative because
 * update_badness() subtracts goodness from badness, which can push entries
 * below zero. NaN scores break the comparison used to rank songs. For
 * non-negative sums the ranking is unchanged, since sqrt is monotonic.
 * Elements are accessed with at(), so a too-short b throws std::out_of_range
 * even when asserts are disabled.
 */
double songdb::dot_product(const std::vector<double> &a, const std::vector<double> &b) {
    assert(a.size() == b.size());

    double dot = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot += a.at(i) * b.at(i);
    }

    return dot;
}
|
|
|
double songdb::weight_badness_inner_product(const std::vector<double> ¤t_badness, const std::vector<double> &song_goodness) { |
|
return dot_product(current_badness,song_goodness); |
|
} |
|
|
|
/* Lower each user's badness by how much the chosen song satisfies them:
 * result[i] = old_badness[i] - song_goodness[i].
 * Both vectors are expected to be indexed the same way (per user, in the order
 * used by get_top_songs()). A size mismatch trips the assert, and in release
 * builds at() throws instead of reading out of bounds.
 */
std::vector<double> songdb::update_badness(std::vector<double> old_badness, std::vector<double> song_goodness) {
    assert(old_badness.size() == song_goodness.size());

    // old_badness is our own copy (taken by value), so update it in place.
    for (std::size_t i = 0; i < old_badness.size(); ++i) {
        old_badness[i] -= song_goodness.at(i);
    }

    return old_badness;
}
|
|
|
/** |
|
* The returned base weight vector has the songs in the sorted order, |
|
* the weights field is the badness vector used for the next song in the |
|
* list. |
|
* |
|
* /param num: the number songs to return |
|
*/ |
|
/**
 * Greedily pick up to `num` songs. Each round takes the song whose goodness
 * scores highest against the users' current badness, then lowers that badness
 * by the song's goodness.
 *
 * The returned base weight vector has the songs in the chosen order. Its
 * weights field holds, for each chosen song, the badness vector used for the
 * next song in the list. person_order is copied from the input.
 *
 * /param input: candidate songs with their per-user goodness weights
 * /param starting_badness: initial per-user badness (same order as input.person_order)
 * /param num: the number songs to return; clamped to [0, number of songs]
 */
songdb::base_weight_vector songdb::get_top_songs(songdb::base_weight_vector input, std::vector<double> starting_badness, int num) {
    spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__);

    struct candidate {
        int64_t song;
        double score;
        std::vector<double> base_weight;
    };

    // create the candidate list: one entry per song
    std::vector<candidate> candidates;
    candidates.reserve(input.song_order.size());
    for (std::size_t i = 0; i < input.song_order.size(); i++) {
        candidates.push_back({input.song_order.at(i), 0, input.weights.at(i)});
    }

    // A negative num used to wrap around in the signed/unsigned comparison and
    // return every song; treat it as "no songs" instead.
    const std::size_t wanted =
        num > 0 ? std::min(static_cast<std::size_t>(num), candidates.size()) : 0;

    base_weight_vector result {};
    result.person_order = input.person_order;
    std::vector<double> current_badness = std::move(starting_badness);

    for (std::size_t picked = 0; picked < wanted; picked++) {
        // Compute scores based on badness
        for (auto &c : candidates) {
            c.score = weight_badness_inner_product(current_badness, c.base_weight);
        }

        // choose the song with the best (highest) score
        const auto best = std::max_element(candidates.begin(), candidates.end(),
            [](const candidate &a, const candidate &b) { return a.score < b.score; });

        result.song_order.push_back(best->song);

        // update badness vector
        current_badness = update_badness(current_badness, best->base_weight);
        result.weights.push_back(current_badness);

        // run the algorithm again on the remaining songs with the updated badness
        candidates.erase(best);
    }

    return result;
}
|
|
|
std::string songdb::get_top_5_songs(int64_t telegram_group) { |
|
spdlog::debug("{} {}", __PRETTY_FUNCTION__, __LINE__); |
|
int64_t song_list = get_song_list_id(telegram_group); |
|
auto songs = get_base_weights(song_list); |
|
|
|
if (songs.weights.size() == 0) { |
|
return {}; |
|
} |
|
|
|
std::vector<double> starting_badness {}; |
|
for (int i = 0; i < songs.person_order.size(); i++) { |
|
starting_badness.push_back(1); |
|
} |
|
|
|
auto chosen = get_top_songs(songs, starting_badness, 5); |
|
|
|
std::string slist = "Top 5 Songs:\n\n"; |
|
for (int i = 0; i < chosen.song_order.size(); i++) { |
|
int64_t songid = chosen.song_order.at(i); |
|
auto song = get_song(songid); |
|
if (!song) { |
|
throw std::runtime_error("Invalid state achieved."); |
|
} |
|
|
|
slist += song->name; |
|
slist += " "; |
|
slist += song->artist; |
|
slist += "\n"; |
|
|
|
} |
|
return slist; |
|
} |
|
|
|
std::vector<songdb::track_entry> songdb::generate_track_list(int64_t song_list) { |
|
auto base_weights = get_base_weights(song_list); |
|
|
|
for (int i = 0; i < base_weights.song_order.size(); i++) { |
|
spdlog::info("song {} nppl {}", base_weights.song_order[i], base_weights.weights.size()); |
|
} |
|
|
|
std::vector<track_entry> retlist {}; |
|
return retlist; |
|
} |
|
|
|
/* Open (or create) the SQLite database at `filepath` and ensure the schema
 * exists. Exits the process with status 1 if the database cannot be opened.
 */
songdb::songdb(std::string filepath): filepath(filepath) {
    const int open_rc = sqlite3_open(filepath.c_str(), &db);
    if (open_rc != SQLITE_OK) {
        std::cout << "Failed to open database: " << sqlite3_errmsg(db);
        exit(1);
    }

    setup_tables();
}
|
|
|
/* Close the database connection opened by the constructor.
 * NOTE(review): the class owns a raw sqlite3* and closes it here. If songdb is
 * copyable (not visible from this file), a copy would double-close the handle.
 * Consider deleting the copy operations in the header.
 */
songdb::~songdb() {
    // sqlite3_close(NULL) is a harmless no-op, so this guard changes nothing.
    if (db != nullptr) {
        sqlite3_close(db);
    }
}
|
|
|