From b090d701aa168f4ebb7e1f1767564b9ef349cdda Mon Sep 17 00:00:00 2001 From: Shagnor Date: Sun, 18 Dec 2022 23:44:04 +1100 Subject: [PATCH] Refactor dynamic Result type + add session cleanup --- Cargo.lock | 1 + README.md | 2 +- blastmud_game/Cargo.toml | 1 + blastmud_game/src/av.rs | 9 +++-- blastmud_game/src/db.rs | 57 +++++++++++++++++++++++++--- blastmud_game/src/listener.rs | 12 +++--- blastmud_game/src/main.rs | 22 +++++++---- blastmud_game/src/message_handler.rs | 10 ++--- blastmud_game/src/version_cutover.rs | 11 +++--- 9 files changed, 91 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 728ec58..58a5b6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,6 +54,7 @@ version = "0.1.0" dependencies = [ "base64 0.20.0", "blastmud_interfaces", + "deadpool", "deadpool-postgres", "futures", "log", diff --git a/README.md b/README.md index ead77a0..4f1d90d 100644 --- a/README.md +++ b/README.md @@ -63,4 +63,4 @@ Create a user with a secret password, and username `blast`. Create a production To get to the latest schema: * Run `psql /tmp/update.sql` -* Check `/tmp/update.sql` and if it looks good, apply it with `psql -d blast Result<(), Box> { +pub fn check() -> DResult<()> { let av: AV = serde_yaml::from_str(&fs::read_to_string("age-verification.yml")?). - map_err(|error| Box::new(error) as Box)?; + map_err(|error| Box::new(error) as Box)?; if av.copyright != "This file is protected by copyright and may not be used or reproduced except as authorised by the copyright holder. All rights reserved." || av.assertion != "age>=18" { - Err(Box::::from("Invalid age-verification.yml"))?; + Err(Box::::from("Invalid age-verification.yml"))?; } let sign_text = format!("cn={};{};serial={}", av.cn, av.assertion, av.serial); @@ -35,5 +36,5 @@ pub fn check() -> Result<(), Box> { signature::UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_ASN1, &KEY_BYTES); info!("Checking sign_text: {}", sign_text); key.verify(&sign_text.as_bytes(), &base64::decode(av.sig)?) - .map_err(|e| Box::::from(format!("Invalid age-verification.yml signature: {}", e))) + .map_err(|_| Box::::from("Invalid age-verification.yml signature")) } diff --git a/blastmud_game/src/db.rs b/blastmud_game/src/db.rs index f77f9d4..58b7c38 100644 --- a/blastmud_game/src/db.rs +++ b/blastmud_game/src/db.rs @@ -1,22 +1,67 @@ use tokio_postgres::config::Config as PgConfig; -use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; +use deadpool_postgres::{Manager, Object, ManagerConfig, Pool, + RecyclingMethod}; use std::error::Error; use std::str::FromStr; use uuid::Uuid; use tokio_postgres::NoTls; +use crate::DResult; -pub async fn record_listener_ping(_listener: Uuid, _pool: Pool) { - // pool.get().await?.query(""); +#[derive(Clone, Debug)] +pub struct DBPool { + pool: Pool } -pub fn start_pool(connstr: &str) -> Result> { +pub async fn record_listener_ping(listener: Uuid, pool: DBPool) -> DResult<()> { + get_conn(pool).await?.execute( + "INSERT INTO listeners (listener, last_seen) \ + VALUES ($1, NOW()) \ + ON CONFLICT (listener) \ + DO UPDATE SET last_seen = EXCLUDED.last_seen", &[&listener]).await?; + Ok(()) +} + +pub async fn get_dead_listeners(pool: DBPool) -> DResult> { + Ok(get_conn(pool).await? + .query("SELECT listener FROM listeners WHERE last_seen < NOW() - \ + INTERVAL 2 minutes", &[]) + .await?.into_iter().map(|r| r.get(0)).collect()) +} + +pub async fn cleanup_listener(pool: DBPool, listener: Uuid) -> DResult<()> { + let mut conn = get_conn(pool).await?; + let tx = conn.transaction().await?; + tx.execute("UPDATE users SET current_session = NULL, \ + current_listener = NULL WHERE current_listener = $1", + &[&listener]).await?; + tx.execute("DELETE FROM sendqueue WHERE listener = $1", + &[&listener]).await?; + tx.execute("DELETE FROM sessions WHERE listener = $1", + &[&listener]).await?; + tx.execute("DELETE FROM listeners WHERE listener = $1", + &[&listener]).await?; + tx.commit().await?; + Ok(()) +} + +pub async fn get_conn(DBPool { pool }: DBPool) -> + DResult { + let conn = pool.get().await?; + conn.execute("SET synchronous_commit=off", &[]).await?; + Ok(conn) +} + +pub fn start_pool(connstr: &str) -> DResult { let mgr_config = ManagerConfig { recycling_method: RecyclingMethod::Fast }; let mgr = Manager::from_config( - PgConfig::from_str(connstr)?, + PgConfig::from_str(connstr) + .map_err(|e| Box::new(e) as Box)?, NoTls, mgr_config ); - Pool::builder(mgr).max_size(4).build().map_err(|e| Box::new(e) as Box) + Pool::builder(mgr).max_size(4).build() + .map_err(|e| Box::new(e) as Box) + .map(|pool| DBPool { pool }) } diff --git a/blastmud_game/src/listener.rs b/blastmud_game/src/listener.rs index 55f6f20..0c9b3fb 100644 --- a/blastmud_game/src/listener.rs +++ b/blastmud_game/src/listener.rs @@ -1,5 +1,4 @@ -use std::error::Error; -use tokio::task; +use tokio::{task, time}; use tokio::net::{TcpSocket, TcpStream, lookup_host}; use log::{info, warn}; use tokio_util::codec; @@ -12,6 +11,7 @@ use std::net::SocketAddr; use std::sync::Arc; use uuid::Uuid; use std::collections::BTreeMap; +use crate::DResult; #[derive(Debug)] pub struct ListenerSend { @@ -26,7 +26,7 @@ async fn handle_from_listener( listener_map: ListenerMap) where FHandler: Fn(Uuid, MessageFromListener) -> HandlerFut + Send + 'static, - HandlerFut: Future>> + Send + 'static { + HandlerFut: Future> + Send + 'static { let mut conn_framed = tokio_serde::Framed::new( codec::Framed::new(conn, LengthDelimitedCodec::new()), Cbor::::default() @@ -163,6 +163,8 @@ where ); } + // We delay to avoid wasting resources if we do end up in a loop. + time::sleep(time::Duration::from_secs(1)).await; listener_map.lock().await.remove(&listener_id); } @@ -174,10 +176,10 @@ pub async fn start_listener( bind_to: String, listener_map: ListenerMap, handle_message: FHandler -) -> Result<(), Box> +) -> DResult<()> where FHandler: Fn(Uuid, MessageFromListener) -> HandlerFut + Send + Clone + 'static, - HandlerFut: Future>> + Send + 'static + HandlerFut: Future> + Send + 'static { info!("Starting listener on {}", bind_to); let addr = lookup_host(bind_to).await?.next().expect("listener address didn't resolve"); diff --git a/blastmud_game/src/main.rs b/blastmud_game/src/main.rs index 7719075..8266573 100644 --- a/blastmud_game/src/main.rs +++ b/blastmud_game/src/main.rs @@ -10,6 +10,9 @@ mod listener; mod message_handler; mod version_cutover; mod av; +mod regular_tasks; + +pub type DResult = Result>; #[derive(Deserialize, Debug)] struct Config { @@ -18,16 +21,16 @@ struct Config { database_conn_string: String } -fn read_latest_config() -> Result> { +fn read_latest_config() -> DResult { serde_yaml::from_str(&fs::read_to_string("gameserver.conf")?). - map_err(|error| Box::new(error) as Box) + map_err(|error| Box::new(error) as Box) } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> DResult<()> { SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap(); - av::check().or_else(|e| -> Result<(), Box> { + av::check().or_else(|e| -> Result<(), Box> { error!("Couldn't verify age-verification.yml - this is not a complete game. Check README.md: {}", e); Err(e) })?; @@ -35,19 +38,22 @@ async fn main() -> Result<(), Box> { let pool = db::start_pool(&config.database_conn_string)?; // Test the database connection string works so we quit early if not... - let _ = pool.get().await?.query("SELECT 1", &[]).await?; + let _ = db::get_conn(pool.clone()).await?.query("SELECT 1", &[]).await?; - info!("Database pool initialised: {:?}", pool.status()); + info!("Database pool initialised"); let listener_map = listener::make_listener_map(); + + let mh_pool = pool.clone(); listener::start_listener(config.listener, listener_map.clone(), move |listener_id, msg| { - message_handler::handle(listener_id, msg, pool.clone(), listener_map.clone()) + message_handler::handle(listener_id, msg, mh_pool.clone(), listener_map.clone()) } ).await?; version_cutover::replace_old_gameserver(&config.pidfile)?; - + regular_tasks::start_regular_tasks(pool.clone())?; + let mut sigusr1 = signal(SignalKind::user_defined1())?; sigusr1.recv().await; diff --git a/blastmud_game/src/message_handler.rs b/blastmud_game/src/message_handler.rs index dd4ffcc..8ad9a27 100644 --- a/blastmud_game/src/message_handler.rs +++ b/blastmud_game/src/message_handler.rs @@ -1,16 +1,16 @@ use blastmud_interfaces::*; -use deadpool_postgres::Pool; use crate::listener::ListenerMap; +use crate::db; use MessageFromListener::*; use uuid::Uuid; use tokio::{sync::oneshot, task}; use crate::listener::ListenerSend; -use std::error::Error; +use crate::DResult; -pub async fn handle(listener: Uuid, msg: MessageFromListener, _pool: Pool, listener_map: ListenerMap) - -> Result<(), Box> { +pub async fn handle(listener: Uuid, msg: MessageFromListener, pool: db::DBPool, listener_map: ListenerMap) + -> DResult<()> { match msg { - ListenerPing { uuid: _ } => {} + ListenerPing { .. } => { db::record_listener_ping(listener, pool).await?; } SessionConnected { session: _, source: _ } => {} SessionDisconnected { session: _ } => {} SessionSentLine { session, msg } => { diff --git a/blastmud_game/src/version_cutover.rs b/blastmud_game/src/version_cutover.rs index 759693c..bddd2ba 100644 --- a/blastmud_game/src/version_cutover.rs +++ b/blastmud_game/src/version_cutover.rs @@ -3,8 +3,9 @@ use std::path::Path; use std::error::Error; use log::info; use nix::{sys::signal::{kill, Signal}, unistd::Pid}; +use crate::DResult; -pub fn replace_old_gameserver(pidfile: &str) -> Result<(), Box> { +pub fn replace_old_gameserver(pidfile: &str) -> DResult<()> { match read_to_string(pidfile) { Err(e) => if e.kind() == std::io::ErrorKind::NotFound { @@ -12,16 +13,16 @@ pub fn replace_old_gameserver(pidfile: &str) -> Result<(), Box> { Ok(()) } else { info!("Error reading pidfile (other than NotFound): {}", e); - Err(Box::new(e) as Box::) + Err(Box::new(e) as Box::) } Ok(f) => { - let pid: Pid = Pid::from_raw(f.parse().map_err(|e| Box::new(e) as Box::)?); + let pid: Pid = Pid::from_raw(f.parse().map_err(|e| Box::new(e) as Box::)?); match read_to_string(format!("/proc/{}/cmdline", pid)) { Ok(content) => if content.contains("blastmud_game") { info!("pid in pidfile references blastmud_game; starting cutover"); kill(pid, Signal::SIGUSR1) - .map_err(|e| Box::new(e) as Box) + .map_err(|e| Box::new(e) as Box) } else { info!("Pid in pidfile is for process not including blastmud_game - ignoring pidfile"); Ok(()) @@ -35,5 +36,5 @@ pub fn replace_old_gameserver(pidfile: &str) -> Result<(), Box> { }?; info!("Writing new pidfile"); write(Path::new(pidfile), format!("{}", std::process::id())) - .map_err(|e| Box::new(e) as Box::) + .map_err(|e| Box::new(e) as Box::) }