Moved to SQLx <<<3333

main
BurnyLlama 2023-04-13 21:23:24 +02:00
parent 2dcfb991b4
commit 9be403d4cc
5 changed files with 185 additions and 171 deletions

View File

@ -1,5 +1,6 @@
{
"rust-analyzer.linkedProjects": [
"./Cargo.toml",
"./Cargo.toml",
"./Cargo.toml"
]

View File

@ -7,9 +7,9 @@ edition = "2021"
[dependencies]
chrono = { version = "0.4.24", features = ["serde"] }
dotenv = "0.15.0"
dotenvy = "0.15.7"
rocket = { version = "0.5.0-rc.2", features = ["json"] }
tokio = { version = "1.27.0", features = ["full"] }
tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4"] }
tokio-test = "0.4.2"
serde = "1.0.154"
serde = "1.0.154"
sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "postgres", "chrono" ] }
#tokio = { version = "1.27.0", features = ["full"] }
tokio-test = "0.4.2"

View File

@ -1,148 +1,34 @@
use dotenvy::dotenv;
use sqlx::{postgres::PgPoolOptions, PgPool};
use std::{env, error::Error};
use dotenv::dotenv;
use tokio_postgres::{Client, NoTls};
use self::models::map::Map;
mod models;
pub mod models;
pub struct DatabaseHandler {
pub client: Client,
pub pool: PgPool,
}
impl DatabaseHandler {
pub async fn create() -> Result<DatabaseHandler, Box<dyn Error>> {
pub async fn create() -> Result<Self, Box<dyn Error>> {
// Load the env file
dotenv()?;
dotenv().ok();
// Load in the environment variables
let connection_host = env::var("DB_HOST")?;
let connection_user = env::var("DB_USER")?;
let connection_db = env::var("DB_NAME")?;
let db_url = &env::var("DATABASE_URL")?;
let connection_string = format!(
"host={} user={} dbname={}",
connection_host, connection_user, connection_db
);
let pool = PgPoolOptions::new()
.max_connections(10)
.connect(db_url)
.await?;
// Connect to the database
let (client, connection) = tokio_postgres::connect(&connection_string, NoTls).await?;
// Sanity check
let row: (i64,) = sqlx::query_as("SELECT $1")
.bind(150_i64)
.fetch_one(&pool)
.await?;
// NOTE: This comes directly from the official documentation (https://docs.rs/tokio-postgres/latest/tokio_postgres/#example)
// I hace no idea what in the flying flop it means...
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("Error while connecting to the database: {}", e);
}
});
assert_eq!(row.0, 150);
Ok(DatabaseHandler { client })
}
pub async fn get_all_maps(&self) -> Result<Vec<Map>, Box<dyn Error>> {
// FIXME: Why does this have to be mutable? Should fix.
let mut results = self
.client
.query(
"
SELECT record_maps.map, type, points, stars, mapper, release, width, height,
CONCAT_WS(',',
CASE WHEN DEATH = '1' THEN 'DEATH' END,
CASE WHEN THROUGH = '1' THEN 'THROUGH' END,
CASE WHEN JUMP = '1' THEN 'JUMP' END,
CASE WHEN DFREEZE = '1' THEN 'DFREEZE' END,
CASE WHEN EHOOK_START = '1' THEN 'EHOOK_START' END,
CASE WHEN HIT_END = '1' THEN 'HIT_END' END,
CASE WHEN SOLO_START = '1' THEN 'SOLO_START' END,
CASE WHEN TELE_GUN = '1' THEN 'TELE_GUN' END,
CASE WHEN TELE_GRENADE = '1' THEN 'TELE_GRENADE' END,
CASE WHEN TELE_LASER = '1' THEN 'TELE_LASER' END,
CASE WHEN NPC_START = '1' THEN 'NPC_START' END,
CASE WHEN SUPER_START = '1' THEN 'SUPER_START' END,
CASE WHEN JETPACK_START = '1' THEN 'JETPACK_START' END,
CASE WHEN WALLJUMP = '1' THEN 'WALLJUMP' END,
CASE WHEN NPH_START = '1' THEN 'NPH_START' END,
CASE WHEN WEAPON_SHOTGUN = '1' THEN 'WEAPON_SHOTGUN' END,
CASE WHEN WEAPON_GRENADE = '1' THEN 'WEAPON_GRENADE' END,
CASE WHEN POWERUP_NINJA = '1' THEN 'POWERUP_NINJA' END,
CASE WHEN WEAPON_RIFLE = '1' THEN 'WEAPON_RIFLE' END,
CASE WHEN LASER_STOP = '1' THEN 'LASER_STOP' END,
CASE WHEN CRAZY_SHOTGUN = '1' THEN 'CRAZY_SHOTGUN' END,
CASE WHEN DRAGGER = '1' THEN 'DRAGGER' END,
CASE WHEN DOOR = '1' THEN 'DOOR' END,
CASE WHEN SWITCH_TIMED = '1' THEN 'SWITCH_TIMED' END,
CASE WHEN SWITCH = '1' THEN 'SWITCH' END,
CASE WHEN STOP = '1' THEN 'STOP' END,
CASE WHEN THROUGH_ALL = '1' THEN 'THROUGH_ALL' END,
CASE WHEN TUNE = '1' THEN 'TUNE' END,
CASE WHEN OLDLASER = '1' THEN 'OLDLASER' END,
CASE WHEN TELEINEVIL = '1' THEN 'TELEINEVIL' END,
CASE WHEN TELEIN = '1' THEN 'TELEIN' END,
CASE WHEN TELECHECK = '1' THEN 'TELECHECK' END,
CASE WHEN TELEINWEAPON = '1' THEN 'TELEINWEAPON' END,
CASE WHEN TELEINHOOK = '1' THEN 'TELEINHOOK' END,
CASE WHEN CHECKPOINT_FIRST = '1' THEN 'CHECKPOINT_FIRST' END,
CASE WHEN BONUS = '1' THEN 'BONUS' END,
CASE WHEN BOOST = '1' THEN 'BOOST' END,
CASE WHEN PLASMAF = '1' THEN 'PLASMAF' END,
CASE WHEN PLASMAE = '1' THEN 'PLASMAE' END,
CASE WHEN PLASMAU = '1' THEN 'PLASMAU' END) AS tiles
FROM record_maps JOIN record_mapinfo ON record_maps.map = record_mapinfo.map
",
&[],
)
.await?
.into_iter()
.map(|row| Map::from_db_row(&row));
// If the result has errors, return it. Otherwise, return all the rows.
match results.find(|row| row.is_err()) {
Some(row) => row.map(|row| vec![row]),
None => Ok(results.map(|row| row.unwrap()).collect()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_database_connection() {
async fn test() {
let db = match DatabaseHandler::create().await {
Ok(db) => db,
Err(err) => panic!("Could not get a client!\n{}", err),
};
let msg = "Hello World!";
let rows = match db.client.query("SELECT $1::TEXT", &[&msg]).await {
Ok(rows) => rows,
Err(err) => panic!("Could not create query!\n{}", err),
};
let value: &str = rows[0].get(0);
assert_eq!(value, msg)
}
tokio_test::block_on(test())
}
#[test]
fn test_get_all_maps() {
async fn test() {
let db = match DatabaseHandler::create().await {
Ok(db) => db,
Err(err) => panic!("Could not get a client!\n{}", err),
};
match db.get_all_maps().await {
Ok(maps) => println!("Found maps: {:?}", maps.len()),
Err(err) => panic!("Could not get all maps!\n{}", err),
};
}
tokio_test::block_on(test())
Ok(DatabaseHandler { pool })
}
}

View File

@ -1,43 +1,171 @@
use std::error::Error;
use chrono::NaiveDateTime;
use rocket::futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio_postgres::Row;
#[derive(Debug, Clone, Serialize, Deserialize)]
use crate::database::DatabaseHandler;
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)]
pub struct Map {
map: String,
mapper: String,
category: String,
points: u8,
stars: u8,
release: Option<NaiveDateTime>,
width: u16,
height: u16,
tiles: Vec<String>,
pub map: String,
pub mapper: String,
pub category: String,
pub points: i16,
pub stars: i16,
pub release: Option<NaiveDateTime>,
pub width: Option<i16>,
pub height: Option<i16>,
pub tiles: Option<Vec<String>>,
}
impl Map {
pub fn from_db_row(db_row: &Row) -> Result<Self, Box<dyn Error>> {
let map: String = db_row.try_get(0)?;
let category: String = db_row.try_get(1)?;
let points_i16: i16 = db_row.try_get(2)?;
let points: u8 = points_i16 as u8;
let stars_i16: i16 = db_row.try_get(3)?;
let stars: u8 = stars_i16 as u8;
let mapper: String = db_row.try_get(4)?;
let release: Option<NaiveDateTime> = db_row.try_get(5)?;
pub async fn get_all_maps(db: &DatabaseHandler) -> Result<Vec<Map>, sqlx::Error> {
sqlx::query_as!(
Map,
"
SELECT record_maps.map, category, points, stars, mapper, release, width, height,
string_to_array(CONCAT_WS(',',
CASE WHEN DEATH = '1' THEN 'DEATH' END,
CASE WHEN THROUGH = '1' THEN 'THROUGH' END,
CASE WHEN JUMP = '1' THEN 'JUMP' END,
CASE WHEN DFREEZE = '1' THEN 'DFREEZE' END,
CASE WHEN EHOOK_START = '1' THEN 'EHOOK_START' END,
CASE WHEN HIT_END = '1' THEN 'HIT_END' END,
CASE WHEN SOLO_START = '1' THEN 'SOLO_START' END,
CASE WHEN TELE_GUN = '1' THEN 'TELE_GUN' END,
CASE WHEN TELE_GRENADE = '1' THEN 'TELE_GRENADE' END,
CASE WHEN TELE_LASER = '1' THEN 'TELE_LASER' END,
CASE WHEN NPC_START = '1' THEN 'NPC_START' END,
CASE WHEN SUPER_START = '1' THEN 'SUPER_START' END,
CASE WHEN JETPACK_START = '1' THEN 'JETPACK_START' END,
CASE WHEN WALLJUMP = '1' THEN 'WALLJUMP' END,
CASE WHEN NPH_START = '1' THEN 'NPH_START' END,
CASE WHEN WEAPON_SHOTGUN = '1' THEN 'WEAPON_SHOTGUN' END,
CASE WHEN WEAPON_GRENADE = '1' THEN 'WEAPON_GRENADE' END,
CASE WHEN POWERUP_NINJA = '1' THEN 'POWERUP_NINJA' END,
CASE WHEN WEAPON_RIFLE = '1' THEN 'WEAPON_RIFLE' END,
CASE WHEN LASER_STOP = '1' THEN 'LASER_STOP' END,
CASE WHEN CRAZY_SHOTGUN = '1' THEN 'CRAZY_SHOTGUN' END,
CASE WHEN DRAGGER = '1' THEN 'DRAGGER' END,
CASE WHEN DOOR = '1' THEN 'DOOR' END,
CASE WHEN SWITCH_TIMED = '1' THEN 'SWITCH_TIMED' END,
CASE WHEN SWITCH = '1' THEN 'SWITCH' END,
CASE WHEN STOP = '1' THEN 'STOP' END,
CASE WHEN THROUGH_ALL = '1' THEN 'THROUGH_ALL' END,
CASE WHEN TUNE = '1' THEN 'TUNE' END,
CASE WHEN OLDLASER = '1' THEN 'OLDLASER' END,
CASE WHEN TELEINEVIL = '1' THEN 'TELEINEVIL' END,
CASE WHEN TELEIN = '1' THEN 'TELEIN' END,
CASE WHEN TELECHECK = '1' THEN 'TELECHECK' END,
CASE WHEN TELEINWEAPON = '1' THEN 'TELEINWEAPON' END,
CASE WHEN TELEINHOOK = '1' THEN 'TELEINHOOK' END,
CASE WHEN CHECKPOINT_FIRST = '1' THEN 'CHECKPOINT_FIRST' END,
CASE WHEN BONUS = '1' THEN 'BONUS' END,
CASE WHEN BOOST = '1' THEN 'BOOST' END,
CASE WHEN PLASMAF = '1' THEN 'PLASMAF' END,
CASE WHEN PLASMAE = '1' THEN 'PLASMAE' END,
CASE WHEN PLASMAU = '1' THEN 'PLASMAU' END), ',') AS tiles
FROM record_maps JOIN record_mapinfo ON record_maps.map = record_mapinfo.map;
"
)
.fetch_all(&db.pool)
.await
}
Ok(Map {
map,
mapper,
category,
points,
stars,
release,
width: 1,
height: 1,
tiles: vec![],
})
pub async fn get_map_by_name(db: &DatabaseHandler, map: &str) -> Result<Map, sqlx::Error> {
sqlx::query_as!(
Map,
"
SELECT record_maps.map, category, points, stars, mapper, release, width, height,
string_to_array(CONCAT_WS(',',
CASE WHEN DEATH = '1' THEN 'DEATH' END,
CASE WHEN THROUGH = '1' THEN 'THROUGH' END,
CASE WHEN JUMP = '1' THEN 'JUMP' END,
CASE WHEN DFREEZE = '1' THEN 'DFREEZE' END,
CASE WHEN EHOOK_START = '1' THEN 'EHOOK_START' END,
CASE WHEN HIT_END = '1' THEN 'HIT_END' END,
CASE WHEN SOLO_START = '1' THEN 'SOLO_START' END,
CASE WHEN TELE_GUN = '1' THEN 'TELE_GUN' END,
CASE WHEN TELE_GRENADE = '1' THEN 'TELE_GRENADE' END,
CASE WHEN TELE_LASER = '1' THEN 'TELE_LASER' END,
CASE WHEN NPC_START = '1' THEN 'NPC_START' END,
CASE WHEN SUPER_START = '1' THEN 'SUPER_START' END,
CASE WHEN JETPACK_START = '1' THEN 'JETPACK_START' END,
CASE WHEN WALLJUMP = '1' THEN 'WALLJUMP' END,
CASE WHEN NPH_START = '1' THEN 'NPH_START' END,
CASE WHEN WEAPON_SHOTGUN = '1' THEN 'WEAPON_SHOTGUN' END,
CASE WHEN WEAPON_GRENADE = '1' THEN 'WEAPON_GRENADE' END,
CASE WHEN POWERUP_NINJA = '1' THEN 'POWERUP_NINJA' END,
CASE WHEN WEAPON_RIFLE = '1' THEN 'WEAPON_RIFLE' END,
CASE WHEN LASER_STOP = '1' THEN 'LASER_STOP' END,
CASE WHEN CRAZY_SHOTGUN = '1' THEN 'CRAZY_SHOTGUN' END,
CASE WHEN DRAGGER = '1' THEN 'DRAGGER' END,
CASE WHEN DOOR = '1' THEN 'DOOR' END,
CASE WHEN SWITCH_TIMED = '1' THEN 'SWITCH_TIMED' END,
CASE WHEN SWITCH = '1' THEN 'SWITCH' END,
CASE WHEN STOP = '1' THEN 'STOP' END,
CASE WHEN THROUGH_ALL = '1' THEN 'THROUGH_ALL' END,
CASE WHEN TUNE = '1' THEN 'TUNE' END,
CASE WHEN OLDLASER = '1' THEN 'OLDLASER' END,
CASE WHEN TELEINEVIL = '1' THEN 'TELEINEVIL' END,
CASE WHEN TELEIN = '1' THEN 'TELEIN' END,
CASE WHEN TELECHECK = '1' THEN 'TELECHECK' END,
CASE WHEN TELEINWEAPON = '1' THEN 'TELEINWEAPON' END,
CASE WHEN TELEINHOOK = '1' THEN 'TELEINHOOK' END,
CASE WHEN CHECKPOINT_FIRST = '1' THEN 'CHECKPOINT_FIRST' END,
CASE WHEN BONUS = '1' THEN 'BONUS' END,
CASE WHEN BOOST = '1' THEN 'BOOST' END,
CASE WHEN PLASMAF = '1' THEN 'PLASMAF' END,
CASE WHEN PLASMAE = '1' THEN 'PLASMAE' END,
CASE WHEN PLASMAU = '1' THEN 'PLASMAU' END), ',') AS tiles
FROM record_maps JOIN record_mapinfo ON record_maps.map = record_mapinfo.map
WHERE record_maps.map = $1;
",
map
)
.fetch_one(&db.pool)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_all_maps() {
async fn test() {
let db = match DatabaseHandler::create().await {
Ok(db) => db,
Err(err) => panic!("Error while connecting to database! {:?}", err),
};
let maps = match Map::get_all_maps(&db).await {
Ok(maps) => maps,
Err(err) => panic!("Error while getting all maps! {:?}", err),
};
for map in maps {
println!("{:?}", map);
}
}
tokio_test::block_on(test());
}
#[test]
fn test_get_map_by_name() {
async fn test() {
let db = match DatabaseHandler::create().await {
Ok(db) => db,
Err(err) => panic!("Error while connecting to database! {:?}", err),
};
let map = match Map::get_map_by_name(&db, "Kobra").await {
Ok(map) => map,
Err(err) => panic!("Error while getting all maps! {:?}", err),
};
println!("{:?}", map);
}
tokio_test::block_on(test());
}
}

View File

@ -2,7 +2,6 @@ use database::DatabaseHandler;
#[macro_use]
extern crate rocket;
extern crate dotenv;
mod database;