feat: added ACL to all server-server routes

This commit is contained in:
NinekoTheCat 2023-12-25 00:50:49 +01:00
parent e1d0ec4c4d
commit d325537308
No known key found for this signature in database
GPG key ID: 700DB3F678A4AB66
2 changed files with 85 additions and 38 deletions

View file

@ -13,8 +13,9 @@ use std::{
}; };
use axum::{ use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, extract::{DefaultBodyLimit, FromRequestParts, Host, MatchedPath},
response::IntoResponse, middleware::Next,
response::{IntoResponse, Response},
routing::{get, on, MethodFilter}, routing::{get, on, MethodFilter},
Router, Router,
}; };
@ -30,12 +31,15 @@ use http::{
}; };
use hyper::Server; use hyper::Server;
use hyperlocal::SocketIncoming; use hyperlocal::SocketIncoming;
use ruma::api::{ use ruma::{
client::{ api::{
error::{Error as RumaError, ErrorBody, ErrorKind}, client::{
uiaa::UiaaResponse, error::{Error as RumaError, ErrorBody, ErrorKind},
uiaa::UiaaResponse,
},
IncomingRequest,
}, },
IncomingRequest, ServerName,
}; };
use tokio::{net::UnixListener, signal, sync::oneshot}; use tokio::{net::UnixListener, signal, sync::oneshot};
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -280,8 +284,8 @@ async fn run_server() -> io::Result<()> {
async fn spawn_task<B: Send + 'static>( async fn spawn_task<B: Send + 'static>(
req: axum::http::Request<B>, req: axum::http::Request<B>,
next: axum::middleware::Next<B>, next: Next<B>,
) -> std::result::Result<axum::response::Response, StatusCode> { ) -> std::result::Result<Response, StatusCode> {
if services().globals.shutdown.load(atomic::Ordering::Relaxed) { if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
return Err(StatusCode::SERVICE_UNAVAILABLE); return Err(StatusCode::SERVICE_UNAVAILABLE);
} }
@ -292,8 +296,8 @@ async fn spawn_task<B: Send + 'static>(
async fn unrecognized_method<B: Send + 'static>( async fn unrecognized_method<B: Send + 'static>(
req: axum::http::Request<B>, req: axum::http::Request<B>,
next: axum::middleware::Next<B>, next: Next<B>,
) -> std::result::Result<axum::response::Response, StatusCode> { ) -> std::result::Result<Response, StatusCode> {
let method = req.method().clone(); let method = req.method().clone();
let uri = req.uri().clone(); let uri = req.uri().clone();
let inner = next.run(req).await; let inner = next.run(req).await;
@ -452,33 +456,7 @@ fn routes() -> Router {
.ruma_route(client_server::get_relating_events_with_rel_type_route) .ruma_route(client_server::get_relating_events_with_rel_type_route)
.ruma_route(client_server::get_relating_events_route) .ruma_route(client_server::get_relating_events_route)
.ruma_route(client_server::get_hierarchy_route) .ruma_route(client_server::get_hierarchy_route)
.ruma_route(server_server::get_server_version_route) .nest("*", server_routes())
.route(
"/_matrix/key/v2/server",
get(server_server::get_server_keys_route),
)
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
.ruma_route(server_server::get_backfill_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)
.ruma_route(server_server::get_room_state_ids_route)
.ruma_route(server_server::create_join_event_template_route)
.ruma_route(server_server::create_join_event_v1_route)
.ruma_route(server_server::create_join_event_v2_route)
.ruma_route(server_server::create_invite_route)
.ruma_route(server_server::get_devices_route)
.ruma_route(server_server::get_room_information_route)
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
.route( .route(
"/_matrix/client/r0/rooms/:room_id/initialSync", "/_matrix/client/r0/rooms/:room_id/initialSync",
get(initial_sync), get(initial_sync),
@ -636,7 +614,74 @@ fn method_to_filter(method: Method) -> MethodFilter {
m => panic!("Unsupported HTTP method: {m:?}"), m => panic!("Unsupported HTTP method: {m:?}"),
} }
} }
fn server_routes() -> Router {
Router::default()
.ruma_route(server_server::get_server_version_route)
.route(
"/_matrix/key/v2/server",
get(server_server::get_server_keys_route),
)
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
.ruma_route(server_server::get_backfill_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)
.ruma_route(server_server::get_room_state_ids_route)
.ruma_route(server_server::create_join_event_template_route)
.ruma_route(server_server::create_join_event_v1_route)
.ruma_route(server_server::create_join_event_v2_route)
.ruma_route(server_server::create_invite_route)
.ruma_route(server_server::get_devices_route)
.ruma_route(server_server::get_room_information_route)
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
.route_layer(axum::middleware::from_fn(deny_if_not_allowed_by_acl))
}
pub async fn deny_if_not_allowed_by_acl<T>(
host: Option<Host>,
request: http::Request<T>,
next: Next<T>,
) -> Response {
let Some(host) = host else {
return Error::BadRequest(
ruma::api::client::error::ErrorKind::MissingParam,
"host header not given",
)
.into_response();
};
let parsed_host = match url::Host::parse(&host.0) {
Ok(x) => x,
Err(error) => {
warn!("got error {} when parsing {:?}", error, host);
return Error::BadRequest(
ruma::api::client::error::ErrorKind::InvalidParam,
"host header malformed",
)
.into_response();
}
};
if services()
.acl
.is_federation_with_allowed(parsed_host.clone())
{
next.run(request).await
} else {
Error::ACLBlock(
ServerName::parse(parsed_host.to_string()).expect("host is valid after parsing"),
)
.into_response()
}
}
#[cfg(unix)] #[cfg(unix)]
#[tracing::instrument(err)] #[tracing::instrument(err)]
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {

View file

@ -7,6 +7,8 @@ use ring::digest;
use ruma::{ use ruma::{
canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId, canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId,
}; };
use tracing::warn;
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
fmt, fmt,