Refactor dynamic Result type + add session cleanup

This commit is contained in:
Condorra 2022-12-18 23:44:04 +11:00
parent 7dd8b05855
commit b090d701aa
9 changed files with 91 additions and 34 deletions

1
Cargo.lock generated
View File

@ -54,6 +54,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"base64 0.20.0", "base64 0.20.0",
"blastmud_interfaces", "blastmud_interfaces",
"deadpool",
"deadpool-postgres", "deadpool-postgres",
"futures", "futures",
"log", "log",

View File

@ -63,4 +63,4 @@ Create a user with a secret password, and username `blast`. Create a production
To get to the latest schema: To get to the latest schema:
* Run `psql <schema/schema.sql` to create the temporary `blast_schemaonly` database. * Run `psql <schema/schema.sql` to create the temporary `blast_schemaonly` database.
* Run `migra "postgres:///blast" "postgres:///blast_schemaonly" > /tmp/update.sql` * Run `migra "postgres:///blast" "postgres:///blast_schemaonly" > /tmp/update.sql`
* Check `/tmp/update.sql` and if it looks good, apply it with `psql -d blast </tmp/update.sql` * Check `/tmp/update.sql` and if it looks good, apply it with `psql -u blast -d blast </tmp/update.sql`

View File

@ -8,6 +8,7 @@ edition = "2021"
[dependencies] [dependencies]
base64 = "0.20.0" base64 = "0.20.0"
blastmud_interfaces = { path = "../blastmud_interfaces" } blastmud_interfaces = { path = "../blastmud_interfaces" }
deadpool = "0.9.5"
deadpool-postgres = { version = "0.10.3", features = ["serde"] } deadpool-postgres = { version = "0.10.3", features = ["serde"] }
futures = "0.3.25" futures = "0.3.25"
log = "0.4.17" log = "0.4.17"

View File

@ -4,6 +4,7 @@ use serde::Deserialize;
use ring::signature; use ring::signature;
use base64; use base64;
use log::info; use log::info;
use crate::DResult;
#[derive(Deserialize)] #[derive(Deserialize)]
struct AV { struct AV {
@ -22,12 +23,12 @@ static KEY_BYTES: [u8;65] = [
0xa8, 0xb3, 0x02, 0x35, 0x7e 0xa8, 0xb3, 0x02, 0x35, 0x7e
]; ];
pub fn check() -> Result<(), Box<dyn Error>> { pub fn check() -> DResult<()> {
let av: AV = serde_yaml::from_str(&fs::read_to_string("age-verification.yml")?). let av: AV = serde_yaml::from_str(&fs::read_to_string("age-verification.yml")?).
map_err(|error| Box::new(error) as Box<dyn Error>)?; map_err(|error| Box::new(error) as Box<dyn Error + Send + Sync>)?;
if av.copyright != "This file is protected by copyright and may not be used or reproduced except as authorised by the copyright holder. All rights reserved." || if av.copyright != "This file is protected by copyright and may not be used or reproduced except as authorised by the copyright holder. All rights reserved." ||
av.assertion != "age>=18" { av.assertion != "age>=18" {
Err(Box::<dyn Error>::from("Invalid age-verification.yml"))?; Err(Box::<dyn Error + Send + Sync>::from("Invalid age-verification.yml"))?;
} }
let sign_text = format!("cn={};{};serial={}", av.cn, av.assertion, av.serial); let sign_text = format!("cn={};{};serial={}", av.cn, av.assertion, av.serial);
@ -35,5 +36,5 @@ pub fn check() -> Result<(), Box<dyn Error>> {
signature::UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_ASN1, &KEY_BYTES); signature::UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_ASN1, &KEY_BYTES);
info!("Checking sign_text: {}", sign_text); info!("Checking sign_text: {}", sign_text);
key.verify(&sign_text.as_bytes(), &base64::decode(av.sig)?) key.verify(&sign_text.as_bytes(), &base64::decode(av.sig)?)
.map_err(|e| Box::<dyn Error>::from(format!("Invalid age-verification.yml signature: {}", e))) .map_err(|_| Box::<dyn Error + Send + Sync>::from("Invalid age-verification.yml signature"))
} }

View File

@ -1,22 +1,67 @@
use tokio_postgres::config::Config as PgConfig; use tokio_postgres::config::Config as PgConfig;
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use deadpool_postgres::{Manager, Object, ManagerConfig, Pool,
RecyclingMethod};
use std::error::Error; use std::error::Error;
use std::str::FromStr; use std::str::FromStr;
use uuid::Uuid; use uuid::Uuid;
use tokio_postgres::NoTls; use tokio_postgres::NoTls;
use crate::DResult;
pub async fn record_listener_ping(_listener: Uuid, _pool: Pool) { #[derive(Clone, Debug)]
// pool.get().await?.query(""); pub struct DBPool {
pool: Pool
} }
pub fn start_pool(connstr: &str) -> Result<Pool, Box<dyn Error>> { pub async fn record_listener_ping(listener: Uuid, pool: DBPool) -> DResult<()> {
get_conn(pool).await?.execute(
"INSERT INTO listeners (listener, last_seen) \
VALUES ($1, NOW()) \
ON CONFLICT (listener) \
DO UPDATE SET last_seen = EXCLUDED.last_seen", &[&listener]).await?;
Ok(())
}
pub async fn get_dead_listeners(pool: DBPool) -> DResult<Vec<Uuid>> {
Ok(get_conn(pool).await?
.query("SELECT listener FROM listeners WHERE last_seen < NOW() - \
INTERVAL 2 minutes", &[])
.await?.into_iter().map(|r| r.get(0)).collect())
}
pub async fn cleanup_listener(pool: DBPool, listener: Uuid) -> DResult<()> {
let mut conn = get_conn(pool).await?;
let tx = conn.transaction().await?;
tx.execute("UPDATE users SET current_session = NULL, \
current_listener = NULL WHERE current_listener = $1",
&[&listener]).await?;
tx.execute("DELETE FROM sendqueue WHERE listener = $1",
&[&listener]).await?;
tx.execute("DELETE FROM sessions WHERE listener = $1",
&[&listener]).await?;
tx.execute("DELETE FROM listeners WHERE listener = $1",
&[&listener]).await?;
tx.commit().await?;
Ok(())
}
pub async fn get_conn(DBPool { pool }: DBPool) ->
DResult<Object> {
let conn = pool.get().await?;
conn.execute("SET synchronous_commit=off", &[]).await?;
Ok(conn)
}
pub fn start_pool(connstr: &str) -> DResult<DBPool> {
let mgr_config = ManagerConfig { let mgr_config = ManagerConfig {
recycling_method: RecyclingMethod::Fast recycling_method: RecyclingMethod::Fast
}; };
let mgr = Manager::from_config( let mgr = Manager::from_config(
PgConfig::from_str(connstr)?, PgConfig::from_str(connstr)
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)?,
NoTls, mgr_config NoTls, mgr_config
); );
Pool::builder(mgr).max_size(4).build().map_err(|e| Box::new(e) as Box<dyn Error>) Pool::builder(mgr).max_size(4).build()
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
.map(|pool| DBPool { pool })
} }

View File

@ -1,5 +1,4 @@
use std::error::Error; use tokio::{task, time};
use tokio::task;
use tokio::net::{TcpSocket, TcpStream, lookup_host}; use tokio::net::{TcpSocket, TcpStream, lookup_host};
use log::{info, warn}; use log::{info, warn};
use tokio_util::codec; use tokio_util::codec;
@ -12,6 +11,7 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::DResult;
#[derive(Debug)] #[derive(Debug)]
pub struct ListenerSend { pub struct ListenerSend {
@ -26,7 +26,7 @@ async fn handle_from_listener<FHandler, HandlerFut>(
listener_map: ListenerMap) listener_map: ListenerMap)
where where
FHandler: Fn(Uuid, MessageFromListener) -> HandlerFut + Send + 'static, FHandler: Fn(Uuid, MessageFromListener) -> HandlerFut + Send + 'static,
HandlerFut: Future<Output = Result<(), Box<dyn Error>>> + Send + 'static { HandlerFut: Future<Output = DResult<()>> + Send + 'static {
let mut conn_framed = tokio_serde::Framed::new( let mut conn_framed = tokio_serde::Framed::new(
codec::Framed::new(conn, LengthDelimitedCodec::new()), codec::Framed::new(conn, LengthDelimitedCodec::new()),
Cbor::<MessageFromListener, MessageToListener>::default() Cbor::<MessageFromListener, MessageToListener>::default()
@ -163,6 +163,8 @@ where
); );
} }
// We delay to avoid wasting resources if we do end up in a loop.
time::sleep(time::Duration::from_secs(1)).await;
listener_map.lock().await.remove(&listener_id); listener_map.lock().await.remove(&listener_id);
} }
@ -174,10 +176,10 @@ pub async fn start_listener<FHandler, HandlerFut>(
bind_to: String, bind_to: String,
listener_map: ListenerMap, listener_map: ListenerMap,
handle_message: FHandler handle_message: FHandler
) -> Result<(), Box<dyn Error>> ) -> DResult<()>
where where
FHandler: Fn(Uuid, MessageFromListener) -> HandlerFut + Send + Clone + 'static, FHandler: Fn(Uuid, MessageFromListener) -> HandlerFut + Send + Clone + 'static,
HandlerFut: Future<Output = Result<(), Box<dyn Error>>> + Send + 'static HandlerFut: Future<Output = DResult<()>> + Send + 'static
{ {
info!("Starting listener on {}", bind_to); info!("Starting listener on {}", bind_to);
let addr = lookup_host(bind_to).await?.next().expect("listener address didn't resolve"); let addr = lookup_host(bind_to).await?.next().expect("listener address didn't resolve");

View File

@ -10,6 +10,9 @@ mod listener;
mod message_handler; mod message_handler;
mod version_cutover; mod version_cutover;
mod av; mod av;
mod regular_tasks;
pub type DResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
struct Config { struct Config {
@ -18,16 +21,16 @@ struct Config {
database_conn_string: String database_conn_string: String
} }
fn read_latest_config() -> Result<Config, Box<dyn Error>> { fn read_latest_config() -> DResult<Config> {
serde_yaml::from_str(&fs::read_to_string("gameserver.conf")?). serde_yaml::from_str(&fs::read_to_string("gameserver.conf")?).
map_err(|error| Box::new(error) as Box<dyn Error>) map_err(|error| Box::new(error) as Box<dyn Error + Send + Sync>)
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> DResult<()> {
SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap(); SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap();
av::check().or_else(|e| -> Result<(), Box<dyn Error>> { av::check().or_else(|e| -> Result<(), Box<dyn Error + Send + Sync>> {
error!("Couldn't verify age-verification.yml - this is not a complete game. Check README.md: {}", e); error!("Couldn't verify age-verification.yml - this is not a complete game. Check README.md: {}", e);
Err(e) Err(e)
})?; })?;
@ -35,18 +38,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let pool = db::start_pool(&config.database_conn_string)?; let pool = db::start_pool(&config.database_conn_string)?;
// Test the database connection string works so we quit early if not... // Test the database connection string works so we quit early if not...
let _ = pool.get().await?.query("SELECT 1", &[]).await?; let _ = db::get_conn(pool.clone()).await?.query("SELECT 1", &[]).await?;
info!("Database pool initialised: {:?}", pool.status()); info!("Database pool initialised");
let listener_map = listener::make_listener_map(); let listener_map = listener::make_listener_map();
let mh_pool = pool.clone();
listener::start_listener(config.listener, listener_map.clone(), listener::start_listener(config.listener, listener_map.clone(),
move |listener_id, msg| { move |listener_id, msg| {
message_handler::handle(listener_id, msg, pool.clone(), listener_map.clone()) message_handler::handle(listener_id, msg, mh_pool.clone(), listener_map.clone())
} }
).await?; ).await?;
version_cutover::replace_old_gameserver(&config.pidfile)?; version_cutover::replace_old_gameserver(&config.pidfile)?;
regular_tasks::start_regular_tasks(pool.clone())?;
let mut sigusr1 = signal(SignalKind::user_defined1())?; let mut sigusr1 = signal(SignalKind::user_defined1())?;
sigusr1.recv().await; sigusr1.recv().await;

View File

@ -1,16 +1,16 @@
use blastmud_interfaces::*; use blastmud_interfaces::*;
use deadpool_postgres::Pool;
use crate::listener::ListenerMap; use crate::listener::ListenerMap;
use crate::db;
use MessageFromListener::*; use MessageFromListener::*;
use uuid::Uuid; use uuid::Uuid;
use tokio::{sync::oneshot, task}; use tokio::{sync::oneshot, task};
use crate::listener::ListenerSend; use crate::listener::ListenerSend;
use std::error::Error; use crate::DResult;
pub async fn handle(listener: Uuid, msg: MessageFromListener, _pool: Pool, listener_map: ListenerMap) pub async fn handle(listener: Uuid, msg: MessageFromListener, pool: db::DBPool, listener_map: ListenerMap)
-> Result<(), Box<dyn Error>> { -> DResult<()> {
match msg { match msg {
ListenerPing { uuid: _ } => {} ListenerPing { .. } => { db::record_listener_ping(listener, pool).await?; }
SessionConnected { session: _, source: _ } => {} SessionConnected { session: _, source: _ } => {}
SessionDisconnected { session: _ } => {} SessionDisconnected { session: _ } => {}
SessionSentLine { session, msg } => { SessionSentLine { session, msg } => {

View File

@ -3,8 +3,9 @@ use std::path::Path;
use std::error::Error; use std::error::Error;
use log::info; use log::info;
use nix::{sys::signal::{kill, Signal}, unistd::Pid}; use nix::{sys::signal::{kill, Signal}, unistd::Pid};
use crate::DResult;
pub fn replace_old_gameserver(pidfile: &str) -> Result<(), Box<dyn Error>> { pub fn replace_old_gameserver(pidfile: &str) -> DResult<()> {
match read_to_string(pidfile) { match read_to_string(pidfile) {
Err(e) => Err(e) =>
if e.kind() == std::io::ErrorKind::NotFound { if e.kind() == std::io::ErrorKind::NotFound {
@ -12,16 +13,16 @@ pub fn replace_old_gameserver(pidfile: &str) -> Result<(), Box<dyn Error>> {
Ok(()) Ok(())
} else { } else {
info!("Error reading pidfile (other than NotFound): {}", e); info!("Error reading pidfile (other than NotFound): {}", e);
Err(Box::new(e) as Box::<dyn Error>) Err(Box::new(e) as Box::<dyn Error + Send + Sync>)
} }
Ok(f) => { Ok(f) => {
let pid: Pid = Pid::from_raw(f.parse().map_err(|e| Box::new(e) as Box::<dyn Error>)?); let pid: Pid = Pid::from_raw(f.parse().map_err(|e| Box::new(e) as Box::<dyn Error + Send + Sync>)?);
match read_to_string(format!("/proc/{}/cmdline", pid)) { match read_to_string(format!("/proc/{}/cmdline", pid)) {
Ok(content) => Ok(content) =>
if content.contains("blastmud_game") { if content.contains("blastmud_game") {
info!("pid in pidfile references blastmud_game; starting cutover"); info!("pid in pidfile references blastmud_game; starting cutover");
kill(pid, Signal::SIGUSR1) kill(pid, Signal::SIGUSR1)
.map_err(|e| Box::new(e) as Box<dyn Error>) .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
} else { } else {
info!("Pid in pidfile is for process not including blastmud_game - ignoring pidfile"); info!("Pid in pidfile is for process not including blastmud_game - ignoring pidfile");
Ok(()) Ok(())
@ -35,5 +36,5 @@ pub fn replace_old_gameserver(pidfile: &str) -> Result<(), Box<dyn Error>> {
}?; }?;
info!("Writing new pidfile"); info!("Writing new pidfile");
write(Path::new(pidfile), format!("{}", std::process::id())) write(Path::new(pidfile), format!("{}", std::process::id()))
.map_err(|e| Box::new(e) as Box::<dyn Error>) .map_err(|e| Box::new(e) as Box::<dyn Error + Send + Sync>)
} }