From d325537308d257f00b293a586fffe6ab58f0748a Mon Sep 17 00:00:00 2001 From: NinekoTheCat Date: Mon, 25 Dec 2023 00:50:49 +0100 Subject: [PATCH] feat: added ACL to all server-server routes --- src/main.rs | 121 ++++++++++++++++++++++++++++++++--------------- src/utils/mod.rs | 2 + 2 files changed, 85 insertions(+), 38 deletions(-) diff --git a/src/main.rs b/src/main.rs index 576e17c3..69c383d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,8 +13,9 @@ use std::{ }; use axum::{ - extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, - response::IntoResponse, + extract::{DefaultBodyLimit, FromRequestParts, Host, MatchedPath}, + middleware::Next, + response::{IntoResponse, Response}, routing::{get, on, MethodFilter}, Router, }; @@ -30,12 +31,15 @@ use http::{ }; use hyper::Server; use hyperlocal::SocketIncoming; -use ruma::api::{ - client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, +use ruma::{ + api::{ + client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + IncomingRequest, }, - IncomingRequest, + ServerName, }; use tokio::{net::UnixListener, signal, sync::oneshot}; use tower::ServiceBuilder; @@ -280,8 +284,8 @@ async fn run_server() -> io::Result<()> { async fn spawn_task( req: axum::http::Request, - next: axum::middleware::Next, -) -> std::result::Result { + next: Next, +) -> std::result::Result { if services().globals.shutdown.load(atomic::Ordering::Relaxed) { return Err(StatusCode::SERVICE_UNAVAILABLE); } @@ -292,8 +296,8 @@ async fn spawn_task( async fn unrecognized_method( req: axum::http::Request, - next: axum::middleware::Next, -) -> std::result::Result { + next: Next, +) -> std::result::Result { let method = req.method().clone(); let uri = req.uri().clone(); let inner = next.run(req).await; @@ -452,33 +456,7 @@ fn routes() -> Router { .ruma_route(client_server::get_relating_events_with_rel_type_route) .ruma_route(client_server::get_relating_events_route) .ruma_route(client_server::get_hierarchy_route) - .ruma_route(server_server::get_server_version_route) - .route( - "/_matrix/key/v2/server", - get(server_server::get_server_keys_route), - ) - .route( - "/_matrix/key/v2/server/:key_id", - get(server_server::get_server_keys_deprecated_route), - ) - .ruma_route(server_server::get_public_rooms_route) - .ruma_route(server_server::get_public_rooms_filtered_route) - .ruma_route(server_server::send_transaction_message_route) - .ruma_route(server_server::get_event_route) - .ruma_route(server_server::get_backfill_route) - .ruma_route(server_server::get_missing_events_route) - .ruma_route(server_server::get_event_authorization_route) - .ruma_route(server_server::get_room_state_route) - .ruma_route(server_server::get_room_state_ids_route) - .ruma_route(server_server::create_join_event_template_route) - .ruma_route(server_server::create_join_event_v1_route) - .ruma_route(server_server::create_join_event_v2_route) - .ruma_route(server_server::create_invite_route) - .ruma_route(server_server::get_devices_route) - .ruma_route(server_server::get_room_information_route) - .ruma_route(server_server::get_profile_information_route) - .ruma_route(server_server::get_keys_route) - .ruma_route(server_server::claim_keys_route) + .nest("*", server_routes()) .route( "/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync), @@ -636,7 +614,74 @@ fn method_to_filter(method: Method) -> MethodFilter { m => panic!("Unsupported HTTP method: {m:?}"), } } +fn server_routes() -> Router { + Router::default() + .ruma_route(server_server::get_server_version_route) + .route( + "/_matrix/key/v2/server", + get(server_server::get_server_keys_route), + ) + .route( + "/_matrix/key/v2/server/:key_id", + get(server_server::get_server_keys_deprecated_route), + ) + .ruma_route(server_server::get_public_rooms_route) + .ruma_route(server_server::get_public_rooms_filtered_route) + .ruma_route(server_server::send_transaction_message_route) + .ruma_route(server_server::get_event_route) + .ruma_route(server_server::get_backfill_route) + .ruma_route(server_server::get_missing_events_route) + .ruma_route(server_server::get_event_authorization_route) + .ruma_route(server_server::get_room_state_route) + .ruma_route(server_server::get_room_state_ids_route) + .ruma_route(server_server::create_join_event_template_route) + .ruma_route(server_server::create_join_event_v1_route) + .ruma_route(server_server::create_join_event_v2_route) + .ruma_route(server_server::create_invite_route) + .ruma_route(server_server::get_devices_route) + .ruma_route(server_server::get_room_information_route) + .ruma_route(server_server::get_profile_information_route) + .ruma_route(server_server::get_keys_route) + .ruma_route(server_server::claim_keys_route) + .route_layer(axum::middleware::from_fn(deny_if_not_allowed_by_acl)) +} +pub async fn deny_if_not_allowed_by_acl( + host: Option, + request: http::Request, + next: Next, +) -> Response { + let Some(host) = host else { + return Error::BadRequest( + ruma::api::client::error::ErrorKind::MissingParam, + "host header not given", + ) + .into_response(); + }; + let parsed_host = match url::Host::parse(&host.0) { + Ok(x) => x, + Err(error) => { + warn!("got error {} when parsing {:?}", error, host); + return Error::BadRequest( + ruma::api::client::error::ErrorKind::InvalidParam, + "host header malformed", + ) + .into_response(); + } + }; + + if services() + .acl + .is_federation_with_allowed(parsed_host.clone()) + { + next.run(request).await + } else { + Error::ACLBlock( + ServerName::parse(parsed_host.to_string()).expect("host is valid after parsing"), + ) + .into_response() + } +} #[cfg(unix)] #[tracing::instrument(err)] fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1bd6fde2..c6e2f214 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,6 +7,8 @@ use ring::digest; use ruma::{ canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId, }; +use tracing::warn; + use std::{ cmp::Ordering, fmt,