feat: added ACL to all server-server routes

This commit is contained in:
NinekoTheCat 2023-12-25 00:50:49 +01:00
parent e1d0ec4c4d
commit d325537308
No known key found for this signature in database
GPG key ID: 700DB3F678A4AB66
2 changed files with 85 additions and 38 deletions

View file

@ -13,8 +13,9 @@ use std::{
};
use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
response::IntoResponse,
extract::{DefaultBodyLimit, FromRequestParts, Host, MatchedPath},
middleware::Next,
response::{IntoResponse, Response},
routing::{get, on, MethodFilter},
Router,
};
@ -30,12 +31,15 @@ use http::{
};
use hyper::Server;
use hyperlocal::SocketIncoming;
use ruma::api::{
use ruma::{
api::{
client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
uiaa::UiaaResponse,
},
IncomingRequest,
},
ServerName,
};
use tokio::{net::UnixListener, signal, sync::oneshot};
use tower::ServiceBuilder;
@ -280,8 +284,8 @@ async fn run_server() -> io::Result<()> {
async fn spawn_task<B: Send + 'static>(
req: axum::http::Request<B>,
next: axum::middleware::Next<B>,
) -> std::result::Result<axum::response::Response, StatusCode> {
next: Next<B>,
) -> std::result::Result<Response, StatusCode> {
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
@ -292,8 +296,8 @@ async fn spawn_task<B: Send + 'static>(
async fn unrecognized_method<B: Send + 'static>(
req: axum::http::Request<B>,
next: axum::middleware::Next<B>,
) -> std::result::Result<axum::response::Response, StatusCode> {
next: Next<B>,
) -> std::result::Result<Response, StatusCode> {
let method = req.method().clone();
let uri = req.uri().clone();
let inner = next.run(req).await;
@ -452,33 +456,7 @@ fn routes() -> Router {
.ruma_route(client_server::get_relating_events_with_rel_type_route)
.ruma_route(client_server::get_relating_events_route)
.ruma_route(client_server::get_hierarchy_route)
.ruma_route(server_server::get_server_version_route)
.route(
"/_matrix/key/v2/server",
get(server_server::get_server_keys_route),
)
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
.ruma_route(server_server::get_backfill_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)
.ruma_route(server_server::get_room_state_ids_route)
.ruma_route(server_server::create_join_event_template_route)
.ruma_route(server_server::create_join_event_v1_route)
.ruma_route(server_server::create_join_event_v2_route)
.ruma_route(server_server::create_invite_route)
.ruma_route(server_server::get_devices_route)
.ruma_route(server_server::get_room_information_route)
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
.nest("*", server_routes())
.route(
"/_matrix/client/r0/rooms/:room_id/initialSync",
get(initial_sync),
@ -636,7 +614,74 @@ fn method_to_filter(method: Method) -> MethodFilter {
m => panic!("Unsupported HTTP method: {m:?}"),
}
}
fn server_routes() -> Router {
Router::default()
.ruma_route(server_server::get_server_version_route)
.route(
"/_matrix/key/v2/server",
get(server_server::get_server_keys_route),
)
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
.ruma_route(server_server::get_backfill_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)
.ruma_route(server_server::get_room_state_ids_route)
.ruma_route(server_server::create_join_event_template_route)
.ruma_route(server_server::create_join_event_v1_route)
.ruma_route(server_server::create_join_event_v2_route)
.ruma_route(server_server::create_invite_route)
.ruma_route(server_server::get_devices_route)
.ruma_route(server_server::get_room_information_route)
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
.route_layer(axum::middleware::from_fn(deny_if_not_allowed_by_acl))
}
pub async fn deny_if_not_allowed_by_acl<T>(
host: Option<Host>,
request: http::Request<T>,
next: Next<T>,
) -> Response {
let Some(host) = host else {
return Error::BadRequest(
ruma::api::client::error::ErrorKind::MissingParam,
"host header not given",
)
.into_response();
};
let parsed_host = match url::Host::parse(&host.0) {
Ok(x) => x,
Err(error) => {
warn!("got error {} when parsing {:?}", error, host);
return Error::BadRequest(
ruma::api::client::error::ErrorKind::InvalidParam,
"host header malformed",
)
.into_response();
}
};
if services()
.acl
.is_federation_with_allowed(parsed_host.clone())
{
next.run(request).await
} else {
Error::ACLBlock(
ServerName::parse(parsed_host.to_string()).expect("host is valid after parsing"),
)
.into_response()
}
}
#[cfg(unix)]
#[tracing::instrument(err)]
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {

View file

@ -7,6 +7,8 @@ use ring::digest;
use ruma::{
canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId,
};
use tracing::warn;
use std::{
cmp::Ordering,
fmt,