From 55d3087d2197bb5a7061d949e4478cc86c2501a7 Mon Sep 17 00:00:00 2001 From: Shagnor Date: Sun, 25 Dec 2022 12:20:06 +1100 Subject: [PATCH] Implement features of listener to limit abuse of resources / DoS --- blastmud_listener/src/main.rs | 176 ++++++++++++++++++++++++++++++---- 1 file changed, 155 insertions(+), 21 deletions(-) diff --git a/blastmud_listener/src/main.rs b/blastmud_listener/src/main.rs index 524d8fd6..6e8fee99 100644 --- a/blastmud_listener/src/main.rs +++ b/blastmud_listener/src/main.rs @@ -23,6 +23,7 @@ use tokio_stream::wrappers::ReceiverStream; use warp; use warp::filters::ws; use warp::Filter; +use std::time::Instant; #[derive(Deserialize, Debug)] struct Config { @@ -227,13 +228,18 @@ struct SessionRecord { disconnect_channel: mpsc::UnboundedSender<()> } -type SessionMap = Arc>>; +struct SessionIndexes { + by_uuid: BTreeMap, + count_by_source: BTreeMap +} + +type SessionMap = Arc>; async fn handle_server_message(session_map: SessionMap, message: MessageToListener) { match message { MessageToListener::AcknowledgeMessage => {} MessageToListener::DisconnectSession { session } => { - match session_map.lock().await.get(&session) { + match session_map.lock().await.by_uuid.get(&session) { // Just silently ignore it if they are disconnected. None => {} Some(SessionRecord { channel, disconnect_channel, .. }) => { @@ -247,7 +253,7 @@ async fn handle_server_message(session_map: SessionMap, message: MessageToListen } } MessageToListener::SendToSession { session, msg } => { - match session_map.lock().await.get(&session) { + match session_map.lock().await.by_uuid.get(&session) { // Just silently ignore it if they are disconnected. None => {} Some(SessionRecord { channel, .. }) => { @@ -273,6 +279,56 @@ fn start_server_task(listener_id: Uuid, const MAX_CAPACITY: usize = 20; const STOP_READING_CAPACITY: usize = 10; +struct TokenBucket { + level: f64, + last_topup: Instant, + max_level: f64, + alloc_per_ms: f64, +} + +impl TokenBucket { + pub fn new(initial_level: f64, max_level: f64, alloc_per_ms: f64) -> TokenBucket { + TokenBucket { + level: initial_level, + last_topup: Instant::now(), + max_level, + alloc_per_ms + } + } + + pub fn update(self: &mut Self) { + self.level = + (self.level + self.alloc_per_ms * (self.last_topup.elapsed().as_millis() as f64)) + .min(self.max_level); + self.last_topup = Instant::now(); + } + + pub fn consume(self: &mut Self) { + self.level = self.level - 1.0; + } + + pub fn consume_minor(self: &mut Self) { + self.level = self.level - 0.1; + } + + pub fn nearly_empty(self: &Self) -> bool { + self.level < 1.0 + } + + pub fn has_capacity(self: &Self) -> bool { + self.level > 0.0 + } + + pub fn time_to_capacity(self: &Self) -> Duration { + Duration::from_millis((-self.level / self.alloc_per_ms) as u64) + } +} + +const CLIENT_INITIAL_TOKENS: f64 = 10.0; +const CLIENT_MAX_LEVEL: f64 = 60.0; +const CLIENT_ALLOC_PER_MS: f64 = 0.005; +const MAX_CONNS_PER_IP: u64 = 5; + async fn handle_client_socket( server: mpsc::Sender, active_sessions: SessionMap, @@ -285,22 +341,41 @@ async fn handle_client_socket( codec::LinesCodec::new_with_max_length(512) ); let session = Uuid::new_v4(); + let mut tok_bucket = + TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS); info!("Accepted session {} from {}", session, addr); let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY); let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel(); - active_sessions.lock().await.insert( + let mut sess_idx_lock = active_sessions.lock().await; + let addr_str = addr.ip().to_string(); + if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP { + drop(sess_idx_lock); + info!("Rejecting session {} because of too many concurrent connections", session); + match wstream.write_all("Too many connections from same IP\r\n".as_bytes()).await { + Err(e) => { + info!("Client connection {} got error {}", session, e); + } + Ok(()) => {} + } + return; + } + sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1); + sess_idx_lock.by_uuid.insert( session, SessionRecord { channel: lsender.clone(), disconnect_channel: discon_sender.clone() }); + drop(sess_idx_lock); + server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected { session, source: addr.to_string() }}).await.unwrap(); 'client_loop: loop { + tok_bucket.update(); tokio::select!( Some(()) = discon_receiver.recv() => { info!("Client connection {} instructed for immediate disconnect", session); @@ -320,8 +395,10 @@ async fn handle_client_socket( Ok(()) => {} } } - }, - line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => { + } + _ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {} + line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY && + tok_bucket.has_capacity() => { match line_read { Err(e) => { info!("Client connection {} got error {}", session, e); @@ -332,9 +409,21 @@ async fn handle_client_socket( break 'client_loop; } Ok(Some(msg)) => { - server.send(ServerTaskCommand::Send { - message: MessageFromListener::SessionSentLine {session, msg } - }).await.unwrap(); + if tok_bucket.nearly_empty() { + match wstream.write_all("You're sending too fast; dropped message.\r\n" + .as_bytes()).await { + Err(e) => { + info!("Client connection {} got error {}", session, e); + } + Ok(()) => {} + } + tok_bucket.consume_minor(); + } else { + tok_bucket.consume(); + server.send(ServerTaskCommand::Send { + message: MessageFromListener::SessionSentLine {session, msg } + }).await.unwrap(); + } } } } @@ -344,7 +433,13 @@ async fn handle_client_socket( server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected { session }}).await.unwrap(); - active_sessions.lock().await.remove(&session); + + sess_idx_lock = active_sessions.lock().await; + sess_idx_lock.by_uuid.remove(&session); + sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; }); + if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 { + sess_idx_lock.count_by_source.remove(&addr_str); + } } fn start_pinger(listener: Uuid, server: mpsc::Sender) { @@ -369,17 +464,36 @@ async fn handle_websocket( let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY); let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel(); - - active_sessions.lock().await.insert( + + let mut sess_idx_lock = active_sessions.lock().await; + let addr_str: String = src.split(" ").last().unwrap_or("").to_string(); + if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP { + drop(sess_idx_lock); + info!("Rejecting session {} because of too many concurrent connections", session); + match ws.send(ws::Message::text("Too many connections from same IP\r\n")).await { + Err(e) => { + info!("Client connection {} got error {}", session, e); + } + Ok(()) => {} + } + return; + } + sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1); + sess_idx_lock.by_uuid.insert( session, SessionRecord { channel: lsender.clone(), disconnect_channel: discon_sender.clone() }); + drop(sess_idx_lock); + server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected { session, source: src }}).await.unwrap(); + let mut tok_bucket = + TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS); 'client_loop: loop { + tok_bucket.update(); tokio::select!( Some(()) = discon_receiver.recv() => { info!("Client connection {} instructed for immediate disconnect", session); @@ -399,8 +513,10 @@ async fn handle_websocket( Ok(()) => {} } } - }, - msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => { + } + _ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {} + msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY && + tok_bucket.has_capacity() => { match msg_read { Err(e) => { info!("Client connection {} got error {}", session, e); @@ -418,12 +534,24 @@ async fn handle_websocket( match wsmsg.to_str() { Err(_) => {} Ok(msg) => { - server.send(ServerTaskCommand::Send { - message: MessageFromListener::SessionSentLine { - session, - msg: msg.to_owned() + if tok_bucket.nearly_empty() { + match ws.send(ws::Message::text("You're sending too fast; dropped message.\r\n")).await { + Err(e) => { + info!("Client connection {} got error {}", session, e); + } + Ok(()) => {} } - }).await.unwrap_or(()); + tok_bucket.consume_minor(); + } else { + tok_bucket.consume(); + + server.send(ServerTaskCommand::Send { + message: MessageFromListener::SessionSentLine { + session, + msg: msg.to_owned() + } + }).await.unwrap_or(()); + } } } } @@ -444,7 +572,13 @@ async fn handle_websocket( server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected { session }}).await.unwrap(); - active_sessions.lock().await.remove(&session); + + sess_idx_lock = active_sessions.lock().await; + sess_idx_lock.by_uuid.remove(&session); + sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; }); + if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 { + sess_idx_lock.count_by_source.remove(&addr_str); + } } async fn upgrade_websocket(src: String, wsreq: ws::Ws, @@ -482,7 +616,7 @@ async fn main() -> Result<(), Box> { let listener_id = Uuid::new_v4(); let mut config = read_latest_config()?; let active_sessions: SessionMap = - Arc::new(Mutex::new(BTreeMap::new())); + Arc::new(Mutex::new(SessionIndexes { by_uuid: BTreeMap::new(), count_by_source: BTreeMap::new() })); let server_sender = start_server_task(listener_id, config.gameserver, active_sessions.clone()); start_pinger(listener_id, server_sender.clone());