Implement features of listener to limit abuse of resources / DoS

Condorra 2022-12-25 12:20:06 +11:00
parent e5e6c45e81
commit 55d3087d21


@@ -23,6 +23,7 @@ use tokio_stream::wrappers::ReceiverStream;
 use warp;
 use warp::filters::ws;
 use warp::Filter;
+use std::time::Instant;
 #[derive(Deserialize, Debug)]
 struct Config {
@@ -227,13 +228,18 @@ struct SessionRecord {
     disconnect_channel: mpsc::UnboundedSender<()>
 }
-type SessionMap = Arc<Mutex<BTreeMap<Uuid, SessionRecord>>>;
+struct SessionIndexes {
+    by_uuid: BTreeMap<Uuid, SessionRecord>,
+    count_by_source: BTreeMap<String, u64>
+}
+type SessionMap = Arc<Mutex<SessionIndexes>>;
 async fn handle_server_message(session_map: SessionMap, message: MessageToListener) {
     match message {
         MessageToListener::AcknowledgeMessage => {}
         MessageToListener::DisconnectSession { session } => {
-            match session_map.lock().await.get(&session) {
+            match session_map.lock().await.by_uuid.get(&session) {
                 // Just silently ignore it if they are disconnected.
                 None => {}
                 Some(SessionRecord { channel, disconnect_channel, .. }) => {
@@ -247,7 +253,7 @@ async fn handle_server_message(session_map: SessionMap, message: MessageToListen
             }
         }
         MessageToListener::SendToSession { session, msg } => {
-            match session_map.lock().await.get(&session) {
+            match session_map.lock().await.by_uuid.get(&session) {
                 // Just silently ignore it if they are disconnected.
                 None => {}
                 Some(SessionRecord { channel, .. }) => {
@@ -273,6 +279,56 @@ fn start_server_task(listener_id: Uuid,
 const MAX_CAPACITY: usize = 20;
 const STOP_READING_CAPACITY: usize = 10;
+struct TokenBucket {
+    level: f64,
+    last_topup: Instant,
+    max_level: f64,
+    alloc_per_ms: f64,
+}
+impl TokenBucket {
+    pub fn new(initial_level: f64, max_level: f64, alloc_per_ms: f64) -> TokenBucket {
+        TokenBucket {
+            level: initial_level,
+            last_topup: Instant::now(),
+            max_level,
+            alloc_per_ms
+        }
+    }
+    pub fn update(self: &mut Self) {
+        self.level =
+            (self.level + self.alloc_per_ms * (self.last_topup.elapsed().as_millis() as f64))
+            .min(self.max_level);
+        self.last_topup = Instant::now();
+    }
+    pub fn consume(self: &mut Self) {
+        self.level = self.level - 1.0;
+    }
+    pub fn consume_minor(self: &mut Self) {
+        self.level = self.level - 0.1;
+    }
+    pub fn nearly_empty(self: &Self) -> bool {
+        self.level < 1.0
+    }
+    pub fn has_capacity(self: &Self) -> bool {
+        self.level > 0.0
+    }
+    pub fn time_to_capacity(self: &Self) -> Duration {
+        Duration::from_millis((-self.level / self.alloc_per_ms) as u64)
+    }
+}
+const CLIENT_INITIAL_TOKENS: f64 = 10.0;
+const CLIENT_MAX_LEVEL: f64 = 60.0;
+const CLIENT_ALLOC_PER_MS: f64 = 0.005;
+const MAX_CONNS_PER_IP: u64 = 5;
 async fn handle_client_socket(
     server: mpsc::Sender<ServerTaskCommand>,
     active_sessions: SessionMap,
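A note on the constants just added: each connection starts with 10 tokens, refills at 0.005 tokens per millisecond (5 per second) up to a cap of 60, and, as the later hunks show, spends a full token per accepted line and a tenth of a token per dropped line. The short standalone sketch below is not part of the commit; it only works out the line rates those constants imply, assuming they are not changed elsewhere.

    // Standalone sketch: sustained and burst line rates implied by the
    // commit's constants (CLIENT_ALLOC_PER_MS, CLIENT_MAX_LEVEL).
    fn main() {
        let alloc_per_ms: f64 = 0.005; // CLIENT_ALLOC_PER_MS
        let max_level: f64 = 60.0;     // CLIENT_MAX_LEVEL
        let cost_per_line: f64 = 1.0;  // TokenBucket::consume()
        // Refill is 0.005 tokens/ms = 5 tokens/s, so at one token per line a
        // client settles at roughly 5 lines per second.
        println!("sustained: {} lines/s", alloc_per_ms * 1000.0 / cost_per_line);
        // An idle connection accumulates at most 60 tokens, so it can burst
        // about 60 lines back-to-back before throttling kicks in.
        println!("burst: {} lines", max_level / cost_per_line);
    }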
@@ -285,22 +341,41 @@ async fn handle_client_socket(
         codec::LinesCodec::new_with_max_length(512)
     );
     let session = Uuid::new_v4();
+    let mut tok_bucket =
+        TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS);
     info!("Accepted session {} from {}", session, addr);
     let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
     let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
-    active_sessions.lock().await.insert(
+    let mut sess_idx_lock = active_sessions.lock().await;
+    let addr_str = addr.ip().to_string();
+    if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP {
+        drop(sess_idx_lock);
+        info!("Rejecting session {} because of too many concurrent connections", session);
+        match wstream.write_all("Too many connections from same IP\r\n".as_bytes()).await {
+            Err(e) => {
+                info!("Client connection {} got error {}", session, e);
+            }
+            Ok(()) => {}
+        }
+        return;
+    }
+    sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1);
+    sess_idx_lock.by_uuid.insert(
         session, SessionRecord {
             channel: lsender.clone(),
             disconnect_channel: discon_sender.clone()
         });
+    drop(sess_idx_lock);
     server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
         session, source: addr.to_string()
     }}).await.unwrap();
     'client_loop: loop {
+        tok_bucket.update();
         tokio::select!(
             Some(()) = discon_receiver.recv() => {
                 info!("Client connection {} instructed for immediate disconnect", session);
@@ -320,8 +395,10 @@ async fn handle_client_socket(
                         Ok(()) => {}
                     }
                 }
-            },
-            line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => {
+            }
+            _ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {}
+            line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY &&
+                tok_bucket.has_capacity() => {
                 match line_read {
                     Err(e) => {
                         info!("Client connection {} got error {}", session, e);
@@ -332,9 +409,21 @@ async fn handle_client_socket(
                         break 'client_loop;
                     }
                     Ok(Some(msg)) => {
-                        server.send(ServerTaskCommand::Send {
-                            message: MessageFromListener::SessionSentLine {session, msg }
-                        }).await.unwrap();
+                        if tok_bucket.nearly_empty() {
+                            match wstream.write_all("You're sending too fast; dropped message.\r\n"
+                                .as_bytes()).await {
+                                Err(e) => {
+                                    info!("Client connection {} got error {}", session, e);
+                                }
+                                Ok(()) => {}
+                            }
+                            tok_bucket.consume_minor();
+                        } else {
+                            tok_bucket.consume();
+                            server.send(ServerTaskCommand::Send {
+                                message: MessageFromListener::SessionSentLine {session, msg }
+                            }).await.unwrap();
+                        }
                     }
                 }
             }
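On the throttling path above: while the bucket holds between 0 and 1 tokens, incoming lines are answered with the "sending too fast" notice and cost 0.1 tokens each; once the level drops to zero or below, has_capacity() turns false, the read arm of the select! is disabled, and the sleep arm parks the connection for time_to_capacity() before the next update() refills it. The sketch below is not from the commit; sleep_for is a hypothetical helper that just mirrors that arithmetic, assuming the same 0.005 tokens/ms refill rate.

    use std::time::Duration;

    // Hypothetical mirror of TokenBucket::time_to_capacity(): given a depleted
    // level (zero or negative) and the refill rate, how long the listener
    // sleeps before it resumes reading from this client.
    fn sleep_for(level: f64, alloc_per_ms: f64) -> Duration {
        Duration::from_millis((-level / alloc_per_ms) as u64)
    }

    fn main() {
        // With a 0.005 tokens/ms rate, a bucket driven down to -2.0 tokens
        // keeps the connection's read arm parked for about 400 ms.
        assert_eq!(sleep_for(-2.0, 0.005), Duration::from_millis(400));
        // A bucket sitting exactly at 0.0 yields a zero-length sleep, after
        // which update() tops it back up from the elapsed wall-clock time.
        assert_eq!(sleep_for(0.0, 0.005), Duration::from_millis(0));
    }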
@@ -344,7 +433,13 @@ async fn handle_client_socket(
     server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
         session
     }}).await.unwrap();
-    active_sessions.lock().await.remove(&session);
+    sess_idx_lock = active_sessions.lock().await;
+    sess_idx_lock.by_uuid.remove(&session);
+    sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; });
+    if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 {
+        sess_idx_lock.count_by_source.remove(&addr_str);
+    }
 }
 fn start_pinger(listener: Uuid, server: mpsc::Sender<ServerTaskCommand>) {
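The accept and disconnect paths above bracket the per-source bookkeeping: the counter in count_by_source is bumped when a connection is admitted (refused once it reaches MAX_CONNS_PER_IP), decremented here on disconnect, and the entry is dropped at zero so the map only holds sources with live connections; the same pattern is repeated for websockets below. A minimal standalone sketch of that logic, with hypothetical helpers try_register and unregister and the commit's cap of 5:

    use std::collections::BTreeMap;

    const MAX_CONNS_PER_IP: u64 = 5;

    // On accept: refuse if the source already holds the maximum, otherwise
    // bump its counter (creating the entry on the first connection).
    fn try_register(counts: &mut BTreeMap<String, u64>, addr: &str) -> bool {
        if *counts.get(addr).unwrap_or(&0) >= MAX_CONNS_PER_IP {
            return false;
        }
        *counts.entry(addr.to_string()).or_insert(0) += 1;
        true
    }

    // On disconnect: decrement, and remove the entry entirely once it hits
    // zero so the map does not grow without bound.
    fn unregister(counts: &mut BTreeMap<String, u64>, addr: &str) {
        if let Some(c) = counts.get_mut(addr) {
            *c -= 1;
            if *c == 0 {
                counts.remove(addr);
            }
        }
    }

    fn main() {
        let mut counts = BTreeMap::new();
        for _ in 0..5 {
            assert!(try_register(&mut counts, "203.0.113.7"));
        }
        assert!(!try_register(&mut counts, "203.0.113.7")); // sixth refused
        unregister(&mut counts, "203.0.113.7");
        assert!(try_register(&mut counts, "203.0.113.7"));  // slot freed again
    }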
@@ -369,17 +464,36 @@ async fn handle_websocket(
     let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
     let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
-    active_sessions.lock().await.insert(
+    let mut sess_idx_lock = active_sessions.lock().await;
+    let addr_str: String = src.split(" ").last().unwrap_or("").to_string();
+    if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP {
+        drop(sess_idx_lock);
+        info!("Rejecting session {} because of too many concurrent connections", session);
+        match ws.send(ws::Message::text("Too many connections from same IP\r\n")).await {
+            Err(e) => {
+                info!("Client connection {} got error {}", session, e);
+            }
+            Ok(()) => {}
+        }
+        return;
+    }
+    sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1);
+    sess_idx_lock.by_uuid.insert(
         session, SessionRecord {
             channel: lsender.clone(),
             disconnect_channel: discon_sender.clone()
         });
+    drop(sess_idx_lock);
     server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
         session, source: src
     }}).await.unwrap();
+    let mut tok_bucket =
+        TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS);
     'client_loop: loop {
+        tok_bucket.update();
         tokio::select!(
             Some(()) = discon_receiver.recv() => {
                 info!("Client connection {} instructed for immediate disconnect", session);
@@ -399,8 +513,10 @@ async fn handle_websocket(
                         Ok(()) => {}
                     }
                 }
-            },
-            msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => {
+            }
+            _ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {}
+            msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY &&
+                tok_bucket.has_capacity() => {
                 match msg_read {
                     Err(e) => {
                         info!("Client connection {} got error {}", session, e);
@@ -418,12 +534,24 @@ async fn handle_websocket(
                         match wsmsg.to_str() {
                             Err(_) => {}
                             Ok(msg) => {
-                                server.send(ServerTaskCommand::Send {
-                                    message: MessageFromListener::SessionSentLine {
-                                        session,
-                                        msg: msg.to_owned()
-                                    }
-                                }).await.unwrap_or(());
+                                if tok_bucket.nearly_empty() {
+                                    match ws.send(ws::Message::text("You're sending too fast; dropped message.\r\n")).await {
+                                        Err(e) => {
+                                            info!("Client connection {} got error {}", session, e);
+                                        }
+                                        Ok(()) => {}
+                                    }
+                                    tok_bucket.consume_minor();
+                                } else {
+                                    tok_bucket.consume();
+                                    server.send(ServerTaskCommand::Send {
+                                        message: MessageFromListener::SessionSentLine {
+                                            session,
+                                            msg: msg.to_owned()
+                                        }
+                                    }).await.unwrap_or(());
+                                }
                             }
                         }
                     }
@@ -444,7 +572,13 @@ async fn handle_websocket(
     server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
         session
     }}).await.unwrap();
-    active_sessions.lock().await.remove(&session);
+    sess_idx_lock = active_sessions.lock().await;
+    sess_idx_lock.by_uuid.remove(&session);
+    sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; });
+    if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 {
+        sess_idx_lock.count_by_source.remove(&addr_str);
+    }
 }
 async fn upgrade_websocket(src: String, wsreq: ws::Ws,
@@ -482,7 +616,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
     let listener_id = Uuid::new_v4();
     let mut config = read_latest_config()?;
     let active_sessions: SessionMap =
-        Arc::new(Mutex::new(BTreeMap::new()));
+        Arc::new(Mutex::new(SessionIndexes { by_uuid: BTreeMap::new(), count_by_source: BTreeMap::new() }));
     let server_sender = start_server_task(listener_id, config.gameserver, active_sessions.clone());
     start_pinger(listener_id, server_sender.clone());