forked from blasthavers/blastmud
Implement features of listener to limit abuse of resources / DoS
This commit is contained in:
parent
e5e6c45e81
commit
55d3087d21
@ -23,6 +23,7 @@ use tokio_stream::wrappers::ReceiverStream;
|
||||
use warp;
|
||||
use warp::filters::ws;
|
||||
use warp::Filter;
|
||||
use std::time::Instant;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct Config {
|
||||
@ -227,13 +228,18 @@ struct SessionRecord {
|
||||
disconnect_channel: mpsc::UnboundedSender<()>
|
||||
}
|
||||
|
||||
type SessionMap = Arc<Mutex<BTreeMap<Uuid, SessionRecord>>>;
|
||||
struct SessionIndexes {
|
||||
by_uuid: BTreeMap<Uuid, SessionRecord>,
|
||||
count_by_source: BTreeMap<String, u64>
|
||||
}
|
||||
|
||||
type SessionMap = Arc<Mutex<SessionIndexes>>;
|
||||
|
||||
async fn handle_server_message(session_map: SessionMap, message: MessageToListener) {
|
||||
match message {
|
||||
MessageToListener::AcknowledgeMessage => {}
|
||||
MessageToListener::DisconnectSession { session } => {
|
||||
match session_map.lock().await.get(&session) {
|
||||
match session_map.lock().await.by_uuid.get(&session) {
|
||||
// Just silently ignore it if they are disconnected.
|
||||
None => {}
|
||||
Some(SessionRecord { channel, disconnect_channel, .. }) => {
|
||||
@ -247,7 +253,7 @@ async fn handle_server_message(session_map: SessionMap, message: MessageToListen
|
||||
}
|
||||
}
|
||||
MessageToListener::SendToSession { session, msg } => {
|
||||
match session_map.lock().await.get(&session) {
|
||||
match session_map.lock().await.by_uuid.get(&session) {
|
||||
// Just silently ignore it if they are disconnected.
|
||||
None => {}
|
||||
Some(SessionRecord { channel, .. }) => {
|
||||
@ -273,6 +279,56 @@ fn start_server_task(listener_id: Uuid,
|
||||
const MAX_CAPACITY: usize = 20;
|
||||
const STOP_READING_CAPACITY: usize = 10;
|
||||
|
||||
struct TokenBucket {
|
||||
level: f64,
|
||||
last_topup: Instant,
|
||||
max_level: f64,
|
||||
alloc_per_ms: f64,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
pub fn new(initial_level: f64, max_level: f64, alloc_per_ms: f64) -> TokenBucket {
|
||||
TokenBucket {
|
||||
level: initial_level,
|
||||
last_topup: Instant::now(),
|
||||
max_level,
|
||||
alloc_per_ms
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(self: &mut Self) {
|
||||
self.level =
|
||||
(self.level + self.alloc_per_ms * (self.last_topup.elapsed().as_millis() as f64))
|
||||
.min(self.max_level);
|
||||
self.last_topup = Instant::now();
|
||||
}
|
||||
|
||||
pub fn consume(self: &mut Self) {
|
||||
self.level = self.level - 1.0;
|
||||
}
|
||||
|
||||
pub fn consume_minor(self: &mut Self) {
|
||||
self.level = self.level - 0.1;
|
||||
}
|
||||
|
||||
pub fn nearly_empty(self: &Self) -> bool {
|
||||
self.level < 1.0
|
||||
}
|
||||
|
||||
pub fn has_capacity(self: &Self) -> bool {
|
||||
self.level > 0.0
|
||||
}
|
||||
|
||||
pub fn time_to_capacity(self: &Self) -> Duration {
|
||||
Duration::from_millis((-self.level / self.alloc_per_ms) as u64)
|
||||
}
|
||||
}
|
||||
|
||||
const CLIENT_INITIAL_TOKENS: f64 = 10.0;
|
||||
const CLIENT_MAX_LEVEL: f64 = 60.0;
|
||||
const CLIENT_ALLOC_PER_MS: f64 = 0.005;
|
||||
const MAX_CONNS_PER_IP: u64 = 5;
|
||||
|
||||
async fn handle_client_socket(
|
||||
server: mpsc::Sender<ServerTaskCommand>,
|
||||
active_sessions: SessionMap,
|
||||
@ -285,22 +341,41 @@ async fn handle_client_socket(
|
||||
codec::LinesCodec::new_with_max_length(512)
|
||||
);
|
||||
let session = Uuid::new_v4();
|
||||
let mut tok_bucket =
|
||||
TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS);
|
||||
info!("Accepted session {} from {}", session, addr);
|
||||
|
||||
|
||||
let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
|
||||
let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
active_sessions.lock().await.insert(
|
||||
let mut sess_idx_lock = active_sessions.lock().await;
|
||||
let addr_str = addr.ip().to_string();
|
||||
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP {
|
||||
drop(sess_idx_lock);
|
||||
info!("Rejecting session {} because of too many concurrent connections", session);
|
||||
match wstream.write_all("Too many connections from same IP\r\n".as_bytes()).await {
|
||||
Err(e) => {
|
||||
info!("Client connection {} got error {}", session, e);
|
||||
}
|
||||
Ok(()) => {}
|
||||
}
|
||||
return;
|
||||
}
|
||||
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1);
|
||||
sess_idx_lock.by_uuid.insert(
|
||||
session, SessionRecord {
|
||||
channel: lsender.clone(),
|
||||
disconnect_channel: discon_sender.clone()
|
||||
});
|
||||
drop(sess_idx_lock);
|
||||
|
||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
|
||||
session, source: addr.to_string()
|
||||
}}).await.unwrap();
|
||||
|
||||
'client_loop: loop {
|
||||
tok_bucket.update();
|
||||
tokio::select!(
|
||||
Some(()) = discon_receiver.recv() => {
|
||||
info!("Client connection {} instructed for immediate disconnect", session);
|
||||
@ -320,8 +395,10 @@ async fn handle_client_socket(
|
||||
Ok(()) => {}
|
||||
}
|
||||
}
|
||||
},
|
||||
line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => {
|
||||
}
|
||||
_ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {}
|
||||
line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY &&
|
||||
tok_bucket.has_capacity() => {
|
||||
match line_read {
|
||||
Err(e) => {
|
||||
info!("Client connection {} got error {}", session, e);
|
||||
@ -332,9 +409,21 @@ async fn handle_client_socket(
|
||||
break 'client_loop;
|
||||
}
|
||||
Ok(Some(msg)) => {
|
||||
server.send(ServerTaskCommand::Send {
|
||||
message: MessageFromListener::SessionSentLine {session, msg }
|
||||
}).await.unwrap();
|
||||
if tok_bucket.nearly_empty() {
|
||||
match wstream.write_all("You're sending too fast; dropped message.\r\n"
|
||||
.as_bytes()).await {
|
||||
Err(e) => {
|
||||
info!("Client connection {} got error {}", session, e);
|
||||
}
|
||||
Ok(()) => {}
|
||||
}
|
||||
tok_bucket.consume_minor();
|
||||
} else {
|
||||
tok_bucket.consume();
|
||||
server.send(ServerTaskCommand::Send {
|
||||
message: MessageFromListener::SessionSentLine {session, msg }
|
||||
}).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -344,7 +433,13 @@ async fn handle_client_socket(
|
||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
|
||||
session
|
||||
}}).await.unwrap();
|
||||
active_sessions.lock().await.remove(&session);
|
||||
|
||||
sess_idx_lock = active_sessions.lock().await;
|
||||
sess_idx_lock.by_uuid.remove(&session);
|
||||
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; });
|
||||
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 {
|
||||
sess_idx_lock.count_by_source.remove(&addr_str);
|
||||
}
|
||||
}
|
||||
|
||||
fn start_pinger(listener: Uuid, server: mpsc::Sender<ServerTaskCommand>) {
|
||||
@ -369,17 +464,36 @@ async fn handle_websocket(
|
||||
|
||||
let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
|
||||
let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
active_sessions.lock().await.insert(
|
||||
|
||||
let mut sess_idx_lock = active_sessions.lock().await;
|
||||
let addr_str: String = src.split(" ").last().unwrap_or("").to_string();
|
||||
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP {
|
||||
drop(sess_idx_lock);
|
||||
info!("Rejecting session {} because of too many concurrent connections", session);
|
||||
match ws.send(ws::Message::text("Too many connections from same IP\r\n")).await {
|
||||
Err(e) => {
|
||||
info!("Client connection {} got error {}", session, e);
|
||||
}
|
||||
Ok(()) => {}
|
||||
}
|
||||
return;
|
||||
}
|
||||
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1);
|
||||
sess_idx_lock.by_uuid.insert(
|
||||
session, SessionRecord {
|
||||
channel: lsender.clone(),
|
||||
disconnect_channel: discon_sender.clone()
|
||||
});
|
||||
drop(sess_idx_lock);
|
||||
|
||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
|
||||
session, source: src
|
||||
}}).await.unwrap();
|
||||
let mut tok_bucket =
|
||||
TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS);
|
||||
|
||||
'client_loop: loop {
|
||||
tok_bucket.update();
|
||||
tokio::select!(
|
||||
Some(()) = discon_receiver.recv() => {
|
||||
info!("Client connection {} instructed for immediate disconnect", session);
|
||||
@ -399,8 +513,10 @@ async fn handle_websocket(
|
||||
Ok(()) => {}
|
||||
}
|
||||
}
|
||||
},
|
||||
msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => {
|
||||
}
|
||||
_ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {}
|
||||
msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY &&
|
||||
tok_bucket.has_capacity() => {
|
||||
match msg_read {
|
||||
Err(e) => {
|
||||
info!("Client connection {} got error {}", session, e);
|
||||
@ -418,12 +534,24 @@ async fn handle_websocket(
|
||||
match wsmsg.to_str() {
|
||||
Err(_) => {}
|
||||
Ok(msg) => {
|
||||
server.send(ServerTaskCommand::Send {
|
||||
message: MessageFromListener::SessionSentLine {
|
||||
session,
|
||||
msg: msg.to_owned()
|
||||
if tok_bucket.nearly_empty() {
|
||||
match ws.send(ws::Message::text("You're sending too fast; dropped message.\r\n")).await {
|
||||
Err(e) => {
|
||||
info!("Client connection {} got error {}", session, e);
|
||||
}
|
||||
Ok(()) => {}
|
||||
}
|
||||
}).await.unwrap_or(());
|
||||
tok_bucket.consume_minor();
|
||||
} else {
|
||||
tok_bucket.consume();
|
||||
|
||||
server.send(ServerTaskCommand::Send {
|
||||
message: MessageFromListener::SessionSentLine {
|
||||
session,
|
||||
msg: msg.to_owned()
|
||||
}
|
||||
}).await.unwrap_or(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -444,7 +572,13 @@ async fn handle_websocket(
|
||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
|
||||
session
|
||||
}}).await.unwrap();
|
||||
active_sessions.lock().await.remove(&session);
|
||||
|
||||
sess_idx_lock = active_sessions.lock().await;
|
||||
sess_idx_lock.by_uuid.remove(&session);
|
||||
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; });
|
||||
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 {
|
||||
sess_idx_lock.count_by_source.remove(&addr_str);
|
||||
}
|
||||
}
|
||||
|
||||
async fn upgrade_websocket(src: String, wsreq: ws::Ws,
|
||||
@ -482,7 +616,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let listener_id = Uuid::new_v4();
|
||||
let mut config = read_latest_config()?;
|
||||
let active_sessions: SessionMap =
|
||||
Arc::new(Mutex::new(BTreeMap::new()));
|
||||
Arc::new(Mutex::new(SessionIndexes { by_uuid: BTreeMap::new(), count_by_source: BTreeMap::new() }));
|
||||
let server_sender = start_server_task(listener_id, config.gameserver, active_sessions.clone());
|
||||
|
||||
start_pinger(listener_id, server_sender.clone());
|
||||
|
Loading…
Reference in New Issue
Block a user