Check Origin, not CORS (no CORS for WS)

This commit is contained in:
Condorra 2024-09-07 16:42:29 +10:00
parent 89379e40d9
commit 958607eabb
3 changed files with 15 additions and 28 deletions

16
Cargo.lock generated
View File

@ -19,21 +19,6 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "actix-cors"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9e772b3bcafe335042b5db010ab7c09013dad6eac4915c91d8d50902769f331"
dependencies = [
"actix-utils",
"actix-web",
"derive_more",
"futures-util",
"log",
"once_cell",
"smallvec",
]
[[package]] [[package]]
name = "actix-http" name = "actix-http"
version = "3.9.0" version = "3.9.0"
@ -1736,7 +1721,6 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
name = "worldwideportal-server" name = "worldwideportal-server"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"actix-cors",
"actix-web", "actix-web",
"actix-ws", "actix-ws",
"anyhow", "anyhow",

View File

@ -6,7 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
actix-cors = "0.7.0"
actix-web = { version = "4.9.0", features = ["rustls-0_23"] } actix-web = { version = "4.9.0", features = ["rustls-0_23"] }
actix-ws = "0.3.0" actix-ws = "0.3.0"
anyhow = "1.0.86" anyhow = "1.0.86"

View File

@ -1,9 +1,8 @@
use actix_cors::Cors;
use actix_web::{ use actix_web::{
self, self,
error::InternalError, error::{ErrorForbidden, InternalError},
get, get,
http::StatusCode, http::{header::ORIGIN, StatusCode},
middleware::Logger, middleware::Logger,
rt::{self, net::TcpStream}, rt::{self, net::TcpStream},
web::{self, Data}, web::{self, Data},
@ -44,6 +43,19 @@ async fn ws(
req: HttpRequest, req: HttpRequest,
body: web::Payload, body: web::Payload,
) -> impl Responder { ) -> impl Responder {
match req.headers().get(&ORIGIN) {
None => Err(ErrorForbidden("Missing origin"))?,
Some(origin) => {
if !config_data
.allowed_origins
.iter()
.any(|o| o.matches(origin.to_str().unwrap_or("invalid")))
{
Err(ErrorForbidden("Disallowed origin"))?;
}
}
}
let (response, mut session, stream) = actix_ws::handle(&req, body)?; let (response, mut session, stream) = actix_ws::handle(&req, body)?;
let mut stream = stream.aggregate_continuations().max_continuation_size(1024); let mut stream = stream.aggregate_continuations().max_continuation_size(1024);
@ -182,16 +194,8 @@ async fn main() -> anyhow::Result<()> {
let server_data = data.clone(); let server_data = data.clone();
let server = HttpServer::new(move || { let server = HttpServer::new(move || {
let logger = Logger::default(); let logger = Logger::default();
let cors_server_data = server_data.clone();
let cors = Cors::default().allowed_origin_fn(move |origin, _| {
cors_server_data
.allowed_origins
.iter()
.any(|o| o.matches(origin.to_str().unwrap_or("invalid")))
});
App::new() App::new()
.wrap(logger) .wrap(logger)
.wrap(cors)
.app_data(server_data.clone()) .app_data(server_data.clone())
.service(ws) .service(ws)
}); });