forked from blasthavers/blastmud
Implement features of listener to limit abuse of resources / DoS
This commit is contained in:
parent
e5e6c45e81
commit
55d3087d21
@ -23,6 +23,7 @@ use tokio_stream::wrappers::ReceiverStream;
|
|||||||
use warp;
|
use warp;
|
||||||
use warp::filters::ws;
|
use warp::filters::ws;
|
||||||
use warp::Filter;
|
use warp::Filter;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
struct Config {
|
struct Config {
|
||||||
@ -227,13 +228,18 @@ struct SessionRecord {
|
|||||||
disconnect_channel: mpsc::UnboundedSender<()>
|
disconnect_channel: mpsc::UnboundedSender<()>
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionMap = Arc<Mutex<BTreeMap<Uuid, SessionRecord>>>;
|
struct SessionIndexes {
|
||||||
|
by_uuid: BTreeMap<Uuid, SessionRecord>,
|
||||||
|
count_by_source: BTreeMap<String, u64>
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionMap = Arc<Mutex<SessionIndexes>>;
|
||||||
|
|
||||||
async fn handle_server_message(session_map: SessionMap, message: MessageToListener) {
|
async fn handle_server_message(session_map: SessionMap, message: MessageToListener) {
|
||||||
match message {
|
match message {
|
||||||
MessageToListener::AcknowledgeMessage => {}
|
MessageToListener::AcknowledgeMessage => {}
|
||||||
MessageToListener::DisconnectSession { session } => {
|
MessageToListener::DisconnectSession { session } => {
|
||||||
match session_map.lock().await.get(&session) {
|
match session_map.lock().await.by_uuid.get(&session) {
|
||||||
// Just silently ignore it if they are disconnected.
|
// Just silently ignore it if they are disconnected.
|
||||||
None => {}
|
None => {}
|
||||||
Some(SessionRecord { channel, disconnect_channel, .. }) => {
|
Some(SessionRecord { channel, disconnect_channel, .. }) => {
|
||||||
@ -247,7 +253,7 @@ async fn handle_server_message(session_map: SessionMap, message: MessageToListen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
MessageToListener::SendToSession { session, msg } => {
|
MessageToListener::SendToSession { session, msg } => {
|
||||||
match session_map.lock().await.get(&session) {
|
match session_map.lock().await.by_uuid.get(&session) {
|
||||||
// Just silently ignore it if they are disconnected.
|
// Just silently ignore it if they are disconnected.
|
||||||
None => {}
|
None => {}
|
||||||
Some(SessionRecord { channel, .. }) => {
|
Some(SessionRecord { channel, .. }) => {
|
||||||
@ -273,6 +279,56 @@ fn start_server_task(listener_id: Uuid,
|
|||||||
const MAX_CAPACITY: usize = 20;
|
const MAX_CAPACITY: usize = 20;
|
||||||
const STOP_READING_CAPACITY: usize = 10;
|
const STOP_READING_CAPACITY: usize = 10;
|
||||||
|
|
||||||
|
struct TokenBucket {
|
||||||
|
level: f64,
|
||||||
|
last_topup: Instant,
|
||||||
|
max_level: f64,
|
||||||
|
alloc_per_ms: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenBucket {
|
||||||
|
pub fn new(initial_level: f64, max_level: f64, alloc_per_ms: f64) -> TokenBucket {
|
||||||
|
TokenBucket {
|
||||||
|
level: initial_level,
|
||||||
|
last_topup: Instant::now(),
|
||||||
|
max_level,
|
||||||
|
alloc_per_ms
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(self: &mut Self) {
|
||||||
|
self.level =
|
||||||
|
(self.level + self.alloc_per_ms * (self.last_topup.elapsed().as_millis() as f64))
|
||||||
|
.min(self.max_level);
|
||||||
|
self.last_topup = Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn consume(self: &mut Self) {
|
||||||
|
self.level = self.level - 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn consume_minor(self: &mut Self) {
|
||||||
|
self.level = self.level - 0.1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn nearly_empty(self: &Self) -> bool {
|
||||||
|
self.level < 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_capacity(self: &Self) -> bool {
|
||||||
|
self.level > 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn time_to_capacity(self: &Self) -> Duration {
|
||||||
|
Duration::from_millis((-self.level / self.alloc_per_ms) as u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const CLIENT_INITIAL_TOKENS: f64 = 10.0;
|
||||||
|
const CLIENT_MAX_LEVEL: f64 = 60.0;
|
||||||
|
const CLIENT_ALLOC_PER_MS: f64 = 0.005;
|
||||||
|
const MAX_CONNS_PER_IP: u64 = 5;
|
||||||
|
|
||||||
async fn handle_client_socket(
|
async fn handle_client_socket(
|
||||||
server: mpsc::Sender<ServerTaskCommand>,
|
server: mpsc::Sender<ServerTaskCommand>,
|
||||||
active_sessions: SessionMap,
|
active_sessions: SessionMap,
|
||||||
@ -285,22 +341,41 @@ async fn handle_client_socket(
|
|||||||
codec::LinesCodec::new_with_max_length(512)
|
codec::LinesCodec::new_with_max_length(512)
|
||||||
);
|
);
|
||||||
let session = Uuid::new_v4();
|
let session = Uuid::new_v4();
|
||||||
|
let mut tok_bucket =
|
||||||
|
TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS);
|
||||||
info!("Accepted session {} from {}", session, addr);
|
info!("Accepted session {} from {}", session, addr);
|
||||||
|
|
||||||
|
|
||||||
let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
|
let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
|
||||||
let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
|
let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
active_sessions.lock().await.insert(
|
let mut sess_idx_lock = active_sessions.lock().await;
|
||||||
|
let addr_str = addr.ip().to_string();
|
||||||
|
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP {
|
||||||
|
drop(sess_idx_lock);
|
||||||
|
info!("Rejecting session {} because of too many concurrent connections", session);
|
||||||
|
match wstream.write_all("Too many connections from same IP\r\n".as_bytes()).await {
|
||||||
|
Err(e) => {
|
||||||
|
info!("Client connection {} got error {}", session, e);
|
||||||
|
}
|
||||||
|
Ok(()) => {}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1);
|
||||||
|
sess_idx_lock.by_uuid.insert(
|
||||||
session, SessionRecord {
|
session, SessionRecord {
|
||||||
channel: lsender.clone(),
|
channel: lsender.clone(),
|
||||||
disconnect_channel: discon_sender.clone()
|
disconnect_channel: discon_sender.clone()
|
||||||
});
|
});
|
||||||
|
drop(sess_idx_lock);
|
||||||
|
|
||||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
|
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
|
||||||
session, source: addr.to_string()
|
session, source: addr.to_string()
|
||||||
}}).await.unwrap();
|
}}).await.unwrap();
|
||||||
|
|
||||||
'client_loop: loop {
|
'client_loop: loop {
|
||||||
|
tok_bucket.update();
|
||||||
tokio::select!(
|
tokio::select!(
|
||||||
Some(()) = discon_receiver.recv() => {
|
Some(()) = discon_receiver.recv() => {
|
||||||
info!("Client connection {} instructed for immediate disconnect", session);
|
info!("Client connection {} instructed for immediate disconnect", session);
|
||||||
@ -320,8 +395,10 @@ async fn handle_client_socket(
|
|||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => {
|
_ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {}
|
||||||
|
line_read = rbuf.try_next(), if lsender.capacity() > STOP_READING_CAPACITY &&
|
||||||
|
tok_bucket.has_capacity() => {
|
||||||
match line_read {
|
match line_read {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!("Client connection {} got error {}", session, e);
|
info!("Client connection {} got error {}", session, e);
|
||||||
@ -332,9 +409,21 @@ async fn handle_client_socket(
|
|||||||
break 'client_loop;
|
break 'client_loop;
|
||||||
}
|
}
|
||||||
Ok(Some(msg)) => {
|
Ok(Some(msg)) => {
|
||||||
server.send(ServerTaskCommand::Send {
|
if tok_bucket.nearly_empty() {
|
||||||
message: MessageFromListener::SessionSentLine {session, msg }
|
match wstream.write_all("You're sending too fast; dropped message.\r\n"
|
||||||
}).await.unwrap();
|
.as_bytes()).await {
|
||||||
|
Err(e) => {
|
||||||
|
info!("Client connection {} got error {}", session, e);
|
||||||
|
}
|
||||||
|
Ok(()) => {}
|
||||||
|
}
|
||||||
|
tok_bucket.consume_minor();
|
||||||
|
} else {
|
||||||
|
tok_bucket.consume();
|
||||||
|
server.send(ServerTaskCommand::Send {
|
||||||
|
message: MessageFromListener::SessionSentLine {session, msg }
|
||||||
|
}).await.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -344,7 +433,13 @@ async fn handle_client_socket(
|
|||||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
|
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
|
||||||
session
|
session
|
||||||
}}).await.unwrap();
|
}}).await.unwrap();
|
||||||
active_sessions.lock().await.remove(&session);
|
|
||||||
|
sess_idx_lock = active_sessions.lock().await;
|
||||||
|
sess_idx_lock.by_uuid.remove(&session);
|
||||||
|
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; });
|
||||||
|
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 {
|
||||||
|
sess_idx_lock.count_by_source.remove(&addr_str);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_pinger(listener: Uuid, server: mpsc::Sender<ServerTaskCommand>) {
|
fn start_pinger(listener: Uuid, server: mpsc::Sender<ServerTaskCommand>) {
|
||||||
@ -369,17 +464,36 @@ async fn handle_websocket(
|
|||||||
|
|
||||||
let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
|
let (lsender, mut lreceiver) = mpsc::channel(MAX_CAPACITY);
|
||||||
let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
|
let (discon_sender, mut discon_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
active_sessions.lock().await.insert(
|
let mut sess_idx_lock = active_sessions.lock().await;
|
||||||
|
let addr_str: String = src.split(" ").last().unwrap_or("").to_string();
|
||||||
|
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&0) >= MAX_CONNS_PER_IP {
|
||||||
|
drop(sess_idx_lock);
|
||||||
|
info!("Rejecting session {} because of too many concurrent connections", session);
|
||||||
|
match ws.send(ws::Message::text("Too many connections from same IP\r\n")).await {
|
||||||
|
Err(e) => {
|
||||||
|
info!("Client connection {} got error {}", session, e);
|
||||||
|
}
|
||||||
|
Ok(()) => {}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|c| { *c += 1; }).or_insert(1);
|
||||||
|
sess_idx_lock.by_uuid.insert(
|
||||||
session, SessionRecord {
|
session, SessionRecord {
|
||||||
channel: lsender.clone(),
|
channel: lsender.clone(),
|
||||||
disconnect_channel: discon_sender.clone()
|
disconnect_channel: discon_sender.clone()
|
||||||
});
|
});
|
||||||
|
drop(sess_idx_lock);
|
||||||
|
|
||||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
|
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionConnected {
|
||||||
session, source: src
|
session, source: src
|
||||||
}}).await.unwrap();
|
}}).await.unwrap();
|
||||||
|
let mut tok_bucket =
|
||||||
|
TokenBucket::new(CLIENT_INITIAL_TOKENS, CLIENT_MAX_LEVEL, CLIENT_ALLOC_PER_MS);
|
||||||
|
|
||||||
'client_loop: loop {
|
'client_loop: loop {
|
||||||
|
tok_bucket.update();
|
||||||
tokio::select!(
|
tokio::select!(
|
||||||
Some(()) = discon_receiver.recv() => {
|
Some(()) = discon_receiver.recv() => {
|
||||||
info!("Client connection {} instructed for immediate disconnect", session);
|
info!("Client connection {} instructed for immediate disconnect", session);
|
||||||
@ -399,8 +513,10 @@ async fn handle_websocket(
|
|||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY => {
|
_ = time::sleep(tok_bucket.time_to_capacity()), if !tok_bucket.has_capacity() => {}
|
||||||
|
msg_read = ws.try_next(), if lsender.capacity() > STOP_READING_CAPACITY &&
|
||||||
|
tok_bucket.has_capacity() => {
|
||||||
match msg_read {
|
match msg_read {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!("Client connection {} got error {}", session, e);
|
info!("Client connection {} got error {}", session, e);
|
||||||
@ -418,12 +534,24 @@ async fn handle_websocket(
|
|||||||
match wsmsg.to_str() {
|
match wsmsg.to_str() {
|
||||||
Err(_) => {}
|
Err(_) => {}
|
||||||
Ok(msg) => {
|
Ok(msg) => {
|
||||||
server.send(ServerTaskCommand::Send {
|
if tok_bucket.nearly_empty() {
|
||||||
message: MessageFromListener::SessionSentLine {
|
match ws.send(ws::Message::text("You're sending too fast; dropped message.\r\n")).await {
|
||||||
session,
|
Err(e) => {
|
||||||
msg: msg.to_owned()
|
info!("Client connection {} got error {}", session, e);
|
||||||
|
}
|
||||||
|
Ok(()) => {}
|
||||||
}
|
}
|
||||||
}).await.unwrap_or(());
|
tok_bucket.consume_minor();
|
||||||
|
} else {
|
||||||
|
tok_bucket.consume();
|
||||||
|
|
||||||
|
server.send(ServerTaskCommand::Send {
|
||||||
|
message: MessageFromListener::SessionSentLine {
|
||||||
|
session,
|
||||||
|
msg: msg.to_owned()
|
||||||
|
}
|
||||||
|
}).await.unwrap_or(());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -444,7 +572,13 @@ async fn handle_websocket(
|
|||||||
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
|
server.send(ServerTaskCommand::Send { message: MessageFromListener::SessionDisconnected {
|
||||||
session
|
session
|
||||||
}}).await.unwrap();
|
}}).await.unwrap();
|
||||||
active_sessions.lock().await.remove(&session);
|
|
||||||
|
sess_idx_lock = active_sessions.lock().await;
|
||||||
|
sess_idx_lock.by_uuid.remove(&session);
|
||||||
|
sess_idx_lock.count_by_source.entry(addr_str.clone()).and_modify(|v| { *v -= 1; });
|
||||||
|
if *sess_idx_lock.count_by_source.get(&addr_str).unwrap_or(&1) <= 0 {
|
||||||
|
sess_idx_lock.count_by_source.remove(&addr_str);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn upgrade_websocket(src: String, wsreq: ws::Ws,
|
async fn upgrade_websocket(src: String, wsreq: ws::Ws,
|
||||||
@ -482,7 +616,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||||||
let listener_id = Uuid::new_v4();
|
let listener_id = Uuid::new_v4();
|
||||||
let mut config = read_latest_config()?;
|
let mut config = read_latest_config()?;
|
||||||
let active_sessions: SessionMap =
|
let active_sessions: SessionMap =
|
||||||
Arc::new(Mutex::new(BTreeMap::new()));
|
Arc::new(Mutex::new(SessionIndexes { by_uuid: BTreeMap::new(), count_by_source: BTreeMap::new() }));
|
||||||
let server_sender = start_server_task(listener_id, config.gameserver, active_sessions.clone());
|
let server_sender = start_server_task(listener_id, config.gameserver, active_sessions.clone());
|
||||||
|
|
||||||
start_pinger(listener_id, server_sender.clone());
|
start_pinger(listener_id, server_sender.clone());
|
||||||
|
Loading…
Reference in New Issue
Block a user