diff --git a/blastmud_listener/src/main.rs b/blastmud_listener/src/main.rs index 4d50f17..e6c3d82 100644 --- a/blastmud_listener/src/main.rs +++ b/blastmud_listener/src/main.rs @@ -9,7 +9,7 @@ use tokio::time::{self, Duration}; use tokio::net::{TcpStream, TcpListener}; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::{mpsc, Mutex}; -use tokio::io::{BufReader}; +use tokio::io::{BufReader, AsyncWriteExt}; use log::{warn, info}; use simple_logger::SimpleLogger; use std::sync::Arc; @@ -37,13 +37,17 @@ enum ServerTaskCommand { Send { message: MessageFromListener } } -fn run_server_task () + Send + 'static>( +fn run_server_task( unfinished_business: Option, mut receiver: mpsc::Receiver, sender: mpsc::Sender, server: String, message_handler: FHandler -) { +) +where + FHandler: Fn(MessageToListener) -> HandlerFut + Send + 'static, + HandlerFut: Future +{ task::spawn(async move { let conn = loop { match TcpStream::connect(&server).await { @@ -186,25 +190,60 @@ fn run_server_task () + Send + 'static>( } enum SessionCommand { - Disconnect + Disconnect, + SendString { message : String } } struct SessionRecord { - channel: mpsc::Sender + channel: mpsc::Sender, + disconnect_channel: mpsc::UnboundedSender<()> } type SessionMap = Arc>>; -fn handle_server_message(session_map: SessionMap, message: MessageToListener) { +async fn handle_server_message(session_map: SessionMap, message: MessageToListener) { + match message { + MessageToListener::AcknowledgeMessage => { + warn!("Unexpected AcknowledgeMessage from gameserver. This suggests a bug in the gameserver"); + } + MessageToListener::DisconnectSession { session } => { + match session_map.lock().await.get(&session) { + // Just silently ignore it if they are disconnected. + None => {} + Some(SessionRecord { channel, disconnect_channel, .. }) => { + match channel.try_send(SessionCommand::Disconnect) { + Err(mpsc::error::TrySendError::Full(_)) => { + disconnect_channel.send(()).unwrap_or(()); + } + _ => {} + } + } + } + } + MessageToListener::SendToSession { session, msg } => { + match session_map.lock().await.get(&session) { + // Just silently ignore it if they are disconnected. + None => {} + Some(SessionRecord { channel, .. }) => { + channel.try_send(SessionCommand::SendString { message: msg }) + .unwrap_or(()); + } + } + } + } } fn start_server_task(server: String, session_map: SessionMap) -> mpsc::Sender { let (sender, receiver) = mpsc::channel(20); run_server_task(None, receiver, sender.clone(), server, - move |msg| { handle_server_message(session_map.clone(), msg); }); + move |msg| handle_server_message(session_map.clone(), + msg) ); sender } +const MAX_CAPACITY: usize = 20; +const STOP_READING_CAPACITY: usize = 10; + async fn handle_client_socket( server: mpsc::Sender, active_sessions: SessionMap, @@ -219,34 +258,58 @@ async fn handle_client_socket( let session = Uuid::new_v4(); info!("Accepted session {} from {}", session, addr); - let (sender, receiver) = mpsc::channel(20); - active_sessions.lock().await.insert(session, SessionRecord { channel: sender }); + + let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY); + let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel(); + + active_sessions.lock().await.insert( + session, SessionRecord { + channel: lsender.clone(), + disconnect_channel: discon_sender.clone() + }); server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected { session, source: addr.to_string() }}).await.unwrap(); - loop { - match rbuf.try_next().await { - Err(e) => { - info!("Client connection {} got error {}", session, e); - break; + 'client_loop: loop { + tokio::select!( + Some(()) = discon_receiver.recv() => { + info!("Client connection {} instructed for immediate disconnect", session); + break 'client_loop; } - Ok(None) => { - info!("Client connection {} closed", session); - break; - } - Ok(Some(msg)) => { - server.send(ServerTaskCommand::Send { - message: MessageFromListener::SessionSentLine { session, msg } - }).await.unwrap(); - /* match wstream.write_all((msg + "\r\n").as_bytes()).await { + Some(message) = lreceiver.recv() => { + match message { + SessionCommand::Disconnect => { + info!("Client connection {} instructed for disconnect", session); + break 'client_loop; + } + SessionCommand::SendString { message } => + match wstream.write_all((message + "\r\n").as_bytes()).await { + Err(e) => { + info!("Client connection {} got error {}", session, e); + } + Ok(()) => {} + } + } + }, + line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => { + match line_read { Err(e) => { info!("Client connection {} got error {}", session, e); + break 'client_loop; } - Ok(()) => {} - } */ + Ok(None) => { + info!("Client connection {} closed", session); + break 'client_loop; + } + Ok(Some(msg)) => { + server.send(ServerTaskCommand::Send { + message: MessageFromListener::SessionSentLine { session, msg } + }).await.unwrap(); + } + } } - } + ); } server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected { @@ -312,3 +375,13 @@ async fn main() -> Result<(), Box> { } } } + + +#[cfg(test)] +mod tests { + #[test] + fn doesnt_stop_reading_at_max_capacity() { + use crate::*; + assert!(MAX_CAPACITY > STOP_READING_CAPACITY); + } +}