diff --git a/.vscode/settings.json b/.vscode/settings.json index e47d9bb..868d3ca 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "rust-analyzer.linkedProjects": [ + "./Cargo.toml", "./Cargo.toml", "./Cargo.toml" ] diff --git a/Cargo.toml b/Cargo.toml index 6203242..8b8e01f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,9 @@ edition = "2021" [dependencies] chrono = { version = "0.4.24", features = ["serde"] } -dotenv = "0.15.0" +dotenvy = "0.15.7" rocket = { version = "0.5.0-rc.2", features = ["json"] } -tokio = { version = "1.27.0", features = ["full"] } -tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4"] } -tokio-test = "0.4.2" -serde = "1.0.154" \ No newline at end of file +serde = "1.0.154" +sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "postgres", "chrono" ] } +#tokio = { version = "1.27.0", features = ["full"] } +tokio-test = "0.4.2" \ No newline at end of file diff --git a/src/database/mod.rs b/src/database/mod.rs index ed45c38..5d25e54 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,148 +1,34 @@ +use dotenvy::dotenv; +use sqlx::{postgres::PgPoolOptions, PgPool}; use std::{env, error::Error}; -use dotenv::dotenv; -use tokio_postgres::{Client, NoTls}; - -use self::models::map::Map; - -mod models; +pub mod models; pub struct DatabaseHandler { - pub client: Client, + pub pool: PgPool, } impl DatabaseHandler { - pub async fn create() -> Result> { + pub async fn create() -> Result> { // Load the env file - dotenv()?; + dotenv().ok(); // Load in the environment variables - let connection_host = env::var("DB_HOST")?; - let connection_user = env::var("DB_USER")?; - let connection_db = env::var("DB_NAME")?; + let db_url = &env::var("DATABASE_URL")?; - let connection_string = format!( - "host={} user={} dbname={}", - connection_host, connection_user, connection_db - ); + let pool = PgPoolOptions::new() + .max_connections(10) + .connect(db_url) + .await?; - // Connect to the database - let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?; + // Sanity check + let row: (i64,) = sqlx::query_as("SELECT $1") + .bind(150_i64) + .fetch_one(&pool) + .await?; - // NOTE: This comes directly from the official documentation (https://docs.rs/tokio-postgres/latest/tokio_postgres/#example) - // I hace no idea what in the flying flop it means... - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("Error while connecting to the database: {}", e); - } - }); + assert_eq!(row.0, 150); - Ok(DatabaseHandler { client }) - } - - pub async fn get_all_maps(&self) -> Result, Box> { - // FIXME: Why does this have to be mutable? Should fix. - let mut results = self - .client - .query( - " - SELECT record_maps.map, type, points, stars, mapper, release, width, height, - CONCAT_WS(',', - CASE WHEN DEATH = '1' THEN 'DEATH' END, - CASE WHEN THROUGH = '1' THEN 'THROUGH' END, - CASE WHEN JUMP = '1' THEN 'JUMP' END, - CASE WHEN DFREEZE = '1' THEN 'DFREEZE' END, - CASE WHEN EHOOK_START = '1' THEN 'EHOOK_START' END, - CASE WHEN HIT_END = '1' THEN 'HIT_END' END, - CASE WHEN SOLO_START = '1' THEN 'SOLO_START' END, - CASE WHEN TELE_GUN = '1' THEN 'TELE_GUN' END, - CASE WHEN TELE_GRENADE = '1' THEN 'TELE_GRENADE' END, - CASE WHEN TELE_LASER = '1' THEN 'TELE_LASER' END, - CASE WHEN NPC_START = '1' THEN 'NPC_START' END, - CASE WHEN SUPER_START = '1' THEN 'SUPER_START' END, - CASE WHEN JETPACK_START = '1' THEN 'JETPACK_START' END, - CASE WHEN WALLJUMP = '1' THEN 'WALLJUMP' END, - CASE WHEN NPH_START = '1' THEN 'NPH_START' END, - CASE WHEN WEAPON_SHOTGUN = '1' THEN 'WEAPON_SHOTGUN' END, - CASE WHEN WEAPON_GRENADE = '1' THEN 'WEAPON_GRENADE' END, - CASE WHEN POWERUP_NINJA = '1' THEN 'POWERUP_NINJA' END, - CASE WHEN WEAPON_RIFLE = '1' THEN 'WEAPON_RIFLE' END, - CASE WHEN LASER_STOP = '1' THEN 'LASER_STOP' END, - CASE WHEN CRAZY_SHOTGUN = '1' THEN 'CRAZY_SHOTGUN' END, - CASE WHEN DRAGGER = '1' THEN 'DRAGGER' END, - CASE WHEN DOOR = '1' THEN 'DOOR' END, - CASE WHEN SWITCH_TIMED = '1' THEN 'SWITCH_TIMED' END, - CASE WHEN SWITCH = '1' THEN 'SWITCH' END, - CASE WHEN STOP = '1' THEN 'STOP' END, - CASE WHEN THROUGH_ALL = '1' THEN 'THROUGH_ALL' END, - CASE WHEN TUNE = '1' THEN 'TUNE' END, - CASE WHEN OLDLASER = '1' THEN 'OLDLASER' END, - CASE WHEN TELEINEVIL = '1' THEN 'TELEINEVIL' END, - CASE WHEN TELEIN = '1' THEN 'TELEIN' END, - CASE WHEN TELECHECK = '1' THEN 'TELECHECK' END, - CASE WHEN TELEINWEAPON = '1' THEN 'TELEINWEAPON' END, - CASE WHEN TELEINHOOK = '1' THEN 'TELEINHOOK' END, - CASE WHEN CHECKPOINT_FIRST = '1' THEN 'CHECKPOINT_FIRST' END, - CASE WHEN BONUS = '1' THEN 'BONUS' END, - CASE WHEN BOOST = '1' THEN 'BOOST' END, - CASE WHEN PLASMAF = '1' THEN 'PLASMAF' END, - CASE WHEN PLASMAE = '1' THEN 'PLASMAE' END, - CASE WHEN PLASMAU = '1' THEN 'PLASMAU' END) AS tiles - FROM record_maps JOIN record_mapinfo ON record_maps.map = record_mapinfo.map - ", - &[], - ) - .await? - .into_iter() - .map(|row| Map::from_db_row(&row)); - - // If the result has errors, return it. Otherwise, return all the rows. - match results.find(|row| row.is_err()) { - Some(row) => row.map(|row| vec![row]), - None => Ok(results.map(|row| row.unwrap()).collect()), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_database_connection() { - async fn test() { - let db = match DatabaseHandler::create().await { - Ok(db) => db, - Err(err) => panic!("Could not get a client!\n{}", err), - }; - - let msg = "Hello World!"; - let rows = match db.client.query("SELECT $1::TEXT", &[&msg]).await { - Ok(rows) => rows, - Err(err) => panic!("Could not create query!\n{}", err), - }; - - let value: &str = rows[0].get(0); - assert_eq!(value, msg) - } - - tokio_test::block_on(test()) - } - - #[test] - fn test_get_all_maps() { - async fn test() { - let db = match DatabaseHandler::create().await { - Ok(db) => db, - Err(err) => panic!("Could not get a client!\n{}", err), - }; - - match db.get_all_maps().await { - Ok(maps) => println!("Found maps: {:?}", maps.len()), - Err(err) => panic!("Could not get all maps!\n{}", err), - }; - } - - tokio_test::block_on(test()) + Ok(DatabaseHandler { pool }) } } diff --git a/src/database/models/map.rs b/src/database/models/map.rs index 74fa569..3e53116 100644 --- a/src/database/models/map.rs +++ b/src/database/models/map.rs @@ -1,43 +1,171 @@ -use std::error::Error; - use chrono::NaiveDateTime; +use rocket::futures::StreamExt; use serde::{Deserialize, Serialize}; -use tokio_postgres::Row; -#[derive(Debug, Clone, Serialize, Deserialize)] +use crate::database::DatabaseHandler; + +#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)] pub struct Map { - map: String, - mapper: String, - category: String, - points: u8, - stars: u8, - release: Option, - width: u16, - height: u16, - tiles: Vec, + pub map: String, + pub mapper: String, + pub category: String, + pub points: i16, + pub stars: i16, + pub release: Option, + pub width: Option, + pub height: Option, + pub tiles: Option>, } impl Map { - pub fn from_db_row(db_row: &Row) -> Result> { - let map: String = db_row.try_get(0)?; - let category: String = db_row.try_get(1)?; - let points_i16: i16 = db_row.try_get(2)?; - let points: u8 = points_i16 as u8; - let stars_i16: i16 = db_row.try_get(3)?; - let stars: u8 = stars_i16 as u8; - let mapper: String = db_row.try_get(4)?; - let release: Option = db_row.try_get(5)?; + pub async fn get_all_maps(db: &DatabaseHandler) -> Result, sqlx::Error> { + sqlx::query_as!( + Map, + " + SELECT record_maps.map, category, points, stars, mapper, release, width, height, + string_to_array(CONCAT_WS(',', + CASE WHEN DEATH = '1' THEN 'DEATH' END, + CASE WHEN THROUGH = '1' THEN 'THROUGH' END, + CASE WHEN JUMP = '1' THEN 'JUMP' END, + CASE WHEN DFREEZE = '1' THEN 'DFREEZE' END, + CASE WHEN EHOOK_START = '1' THEN 'EHOOK_START' END, + CASE WHEN HIT_END = '1' THEN 'HIT_END' END, + CASE WHEN SOLO_START = '1' THEN 'SOLO_START' END, + CASE WHEN TELE_GUN = '1' THEN 'TELE_GUN' END, + CASE WHEN TELE_GRENADE = '1' THEN 'TELE_GRENADE' END, + CASE WHEN TELE_LASER = '1' THEN 'TELE_LASER' END, + CASE WHEN NPC_START = '1' THEN 'NPC_START' END, + CASE WHEN SUPER_START = '1' THEN 'SUPER_START' END, + CASE WHEN JETPACK_START = '1' THEN 'JETPACK_START' END, + CASE WHEN WALLJUMP = '1' THEN 'WALLJUMP' END, + CASE WHEN NPH_START = '1' THEN 'NPH_START' END, + CASE WHEN WEAPON_SHOTGUN = '1' THEN 'WEAPON_SHOTGUN' END, + CASE WHEN WEAPON_GRENADE = '1' THEN 'WEAPON_GRENADE' END, + CASE WHEN POWERUP_NINJA = '1' THEN 'POWERUP_NINJA' END, + CASE WHEN WEAPON_RIFLE = '1' THEN 'WEAPON_RIFLE' END, + CASE WHEN LASER_STOP = '1' THEN 'LASER_STOP' END, + CASE WHEN CRAZY_SHOTGUN = '1' THEN 'CRAZY_SHOTGUN' END, + CASE WHEN DRAGGER = '1' THEN 'DRAGGER' END, + CASE WHEN DOOR = '1' THEN 'DOOR' END, + CASE WHEN SWITCH_TIMED = '1' THEN 'SWITCH_TIMED' END, + CASE WHEN SWITCH = '1' THEN 'SWITCH' END, + CASE WHEN STOP = '1' THEN 'STOP' END, + CASE WHEN THROUGH_ALL = '1' THEN 'THROUGH_ALL' END, + CASE WHEN TUNE = '1' THEN 'TUNE' END, + CASE WHEN OLDLASER = '1' THEN 'OLDLASER' END, + CASE WHEN TELEINEVIL = '1' THEN 'TELEINEVIL' END, + CASE WHEN TELEIN = '1' THEN 'TELEIN' END, + CASE WHEN TELECHECK = '1' THEN 'TELECHECK' END, + CASE WHEN TELEINWEAPON = '1' THEN 'TELEINWEAPON' END, + CASE WHEN TELEINHOOK = '1' THEN 'TELEINHOOK' END, + CASE WHEN CHECKPOINT_FIRST = '1' THEN 'CHECKPOINT_FIRST' END, + CASE WHEN BONUS = '1' THEN 'BONUS' END, + CASE WHEN BOOST = '1' THEN 'BOOST' END, + CASE WHEN PLASMAF = '1' THEN 'PLASMAF' END, + CASE WHEN PLASMAE = '1' THEN 'PLASMAE' END, + CASE WHEN PLASMAU = '1' THEN 'PLASMAU' END), ',') AS tiles + FROM record_maps JOIN record_mapinfo ON record_maps.map = record_mapinfo.map; + " + ) + .fetch_all(&db.pool) + .await + } - Ok(Map { - map, - mapper, - category, - points, - stars, - release, - width: 1, - height: 1, - tiles: vec![], - }) + pub async fn get_map_by_name(db: &DatabaseHandler, map: &str) -> Result { + sqlx::query_as!( + Map, + " + SELECT record_maps.map, category, points, stars, mapper, release, width, height, + string_to_array(CONCAT_WS(',', + CASE WHEN DEATH = '1' THEN 'DEATH' END, + CASE WHEN THROUGH = '1' THEN 'THROUGH' END, + CASE WHEN JUMP = '1' THEN 'JUMP' END, + CASE WHEN DFREEZE = '1' THEN 'DFREEZE' END, + CASE WHEN EHOOK_START = '1' THEN 'EHOOK_START' END, + CASE WHEN HIT_END = '1' THEN 'HIT_END' END, + CASE WHEN SOLO_START = '1' THEN 'SOLO_START' END, + CASE WHEN TELE_GUN = '1' THEN 'TELE_GUN' END, + CASE WHEN TELE_GRENADE = '1' THEN 'TELE_GRENADE' END, + CASE WHEN TELE_LASER = '1' THEN 'TELE_LASER' END, + CASE WHEN NPC_START = '1' THEN 'NPC_START' END, + CASE WHEN SUPER_START = '1' THEN 'SUPER_START' END, + CASE WHEN JETPACK_START = '1' THEN 'JETPACK_START' END, + CASE WHEN WALLJUMP = '1' THEN 'WALLJUMP' END, + CASE WHEN NPH_START = '1' THEN 'NPH_START' END, + CASE WHEN WEAPON_SHOTGUN = '1' THEN 'WEAPON_SHOTGUN' END, + CASE WHEN WEAPON_GRENADE = '1' THEN 'WEAPON_GRENADE' END, + CASE WHEN POWERUP_NINJA = '1' THEN 'POWERUP_NINJA' END, + CASE WHEN WEAPON_RIFLE = '1' THEN 'WEAPON_RIFLE' END, + CASE WHEN LASER_STOP = '1' THEN 'LASER_STOP' END, + CASE WHEN CRAZY_SHOTGUN = '1' THEN 'CRAZY_SHOTGUN' END, + CASE WHEN DRAGGER = '1' THEN 'DRAGGER' END, + CASE WHEN DOOR = '1' THEN 'DOOR' END, + CASE WHEN SWITCH_TIMED = '1' THEN 'SWITCH_TIMED' END, + CASE WHEN SWITCH = '1' THEN 'SWITCH' END, + CASE WHEN STOP = '1' THEN 'STOP' END, + CASE WHEN THROUGH_ALL = '1' THEN 'THROUGH_ALL' END, + CASE WHEN TUNE = '1' THEN 'TUNE' END, + CASE WHEN OLDLASER = '1' THEN 'OLDLASER' END, + CASE WHEN TELEINEVIL = '1' THEN 'TELEINEVIL' END, + CASE WHEN TELEIN = '1' THEN 'TELEIN' END, + CASE WHEN TELECHECK = '1' THEN 'TELECHECK' END, + CASE WHEN TELEINWEAPON = '1' THEN 'TELEINWEAPON' END, + CASE WHEN TELEINHOOK = '1' THEN 'TELEINHOOK' END, + CASE WHEN CHECKPOINT_FIRST = '1' THEN 'CHECKPOINT_FIRST' END, + CASE WHEN BONUS = '1' THEN 'BONUS' END, + CASE WHEN BOOST = '1' THEN 'BOOST' END, + CASE WHEN PLASMAF = '1' THEN 'PLASMAF' END, + CASE WHEN PLASMAE = '1' THEN 'PLASMAE' END, + CASE WHEN PLASMAU = '1' THEN 'PLASMAU' END), ',') AS tiles + FROM record_maps JOIN record_mapinfo ON record_maps.map = record_mapinfo.map + WHERE record_maps.map = $1; + ", + map + ) + .fetch_one(&db.pool) + .await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_all_maps() { + async fn test() { + let db = match DatabaseHandler::create().await { + Ok(db) => db, + Err(err) => panic!("Error while connecting to database! {:?}", err), + }; + let maps = match Map::get_all_maps(&db).await { + Ok(maps) => maps, + Err(err) => panic!("Error while getting all maps! {:?}", err), + }; + + for map in maps { + println!("{:?}", map); + } + } + + tokio_test::block_on(test()); + } + + #[test] + fn test_get_map_by_name() { + async fn test() { + let db = match DatabaseHandler::create().await { + Ok(db) => db, + Err(err) => panic!("Error while connecting to database! {:?}", err), + }; + let map = match Map::get_map_by_name(&db, "Kobra").await { + Ok(map) => map, + Err(err) => panic!("Error while getting all maps! {:?}", err), + }; + + println!("{:?}", map); + } + + tokio_test::block_on(test()); } } diff --git a/src/main.rs b/src/main.rs index 0449c22..685bd84 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ use database::DatabaseHandler; #[macro_use] extern crate rocket; -extern crate dotenv; mod database;