mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2025-07-09 13:26:41 +02:00
feat: added ACL to all server-server routes
This commit is contained in:
parent
e1d0ec4c4d
commit
d325537308
2 changed files with 85 additions and 38 deletions
113
src/main.rs
113
src/main.rs
|
@ -13,8 +13,9 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
extract::{DefaultBodyLimit, FromRequestParts, Host, MatchedPath},
|
||||||
response::IntoResponse,
|
middleware::Next,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
routing::{get, on, MethodFilter},
|
routing::{get, on, MethodFilter},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
@ -30,12 +31,15 @@ use http::{
|
||||||
};
|
};
|
||||||
use hyper::Server;
|
use hyper::Server;
|
||||||
use hyperlocal::SocketIncoming;
|
use hyperlocal::SocketIncoming;
|
||||||
use ruma::api::{
|
use ruma::{
|
||||||
|
api::{
|
||||||
client::{
|
client::{
|
||||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||||
uiaa::UiaaResponse,
|
uiaa::UiaaResponse,
|
||||||
},
|
},
|
||||||
IncomingRequest,
|
IncomingRequest,
|
||||||
|
},
|
||||||
|
ServerName,
|
||||||
};
|
};
|
||||||
use tokio::{net::UnixListener, signal, sync::oneshot};
|
use tokio::{net::UnixListener, signal, sync::oneshot};
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
|
@ -280,8 +284,8 @@ async fn run_server() -> io::Result<()> {
|
||||||
|
|
||||||
async fn spawn_task<B: Send + 'static>(
|
async fn spawn_task<B: Send + 'static>(
|
||||||
req: axum::http::Request<B>,
|
req: axum::http::Request<B>,
|
||||||
next: axum::middleware::Next<B>,
|
next: Next<B>,
|
||||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
) -> std::result::Result<Response, StatusCode> {
|
||||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||||
}
|
}
|
||||||
|
@ -292,8 +296,8 @@ async fn spawn_task<B: Send + 'static>(
|
||||||
|
|
||||||
async fn unrecognized_method<B: Send + 'static>(
|
async fn unrecognized_method<B: Send + 'static>(
|
||||||
req: axum::http::Request<B>,
|
req: axum::http::Request<B>,
|
||||||
next: axum::middleware::Next<B>,
|
next: Next<B>,
|
||||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
) -> std::result::Result<Response, StatusCode> {
|
||||||
let method = req.method().clone();
|
let method = req.method().clone();
|
||||||
let uri = req.uri().clone();
|
let uri = req.uri().clone();
|
||||||
let inner = next.run(req).await;
|
let inner = next.run(req).await;
|
||||||
|
@ -452,33 +456,7 @@ fn routes() -> Router {
|
||||||
.ruma_route(client_server::get_relating_events_with_rel_type_route)
|
.ruma_route(client_server::get_relating_events_with_rel_type_route)
|
||||||
.ruma_route(client_server::get_relating_events_route)
|
.ruma_route(client_server::get_relating_events_route)
|
||||||
.ruma_route(client_server::get_hierarchy_route)
|
.ruma_route(client_server::get_hierarchy_route)
|
||||||
.ruma_route(server_server::get_server_version_route)
|
.nest("*", server_routes())
|
||||||
.route(
|
|
||||||
"/_matrix/key/v2/server",
|
|
||||||
get(server_server::get_server_keys_route),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/_matrix/key/v2/server/:key_id",
|
|
||||||
get(server_server::get_server_keys_deprecated_route),
|
|
||||||
)
|
|
||||||
.ruma_route(server_server::get_public_rooms_route)
|
|
||||||
.ruma_route(server_server::get_public_rooms_filtered_route)
|
|
||||||
.ruma_route(server_server::send_transaction_message_route)
|
|
||||||
.ruma_route(server_server::get_event_route)
|
|
||||||
.ruma_route(server_server::get_backfill_route)
|
|
||||||
.ruma_route(server_server::get_missing_events_route)
|
|
||||||
.ruma_route(server_server::get_event_authorization_route)
|
|
||||||
.ruma_route(server_server::get_room_state_route)
|
|
||||||
.ruma_route(server_server::get_room_state_ids_route)
|
|
||||||
.ruma_route(server_server::create_join_event_template_route)
|
|
||||||
.ruma_route(server_server::create_join_event_v1_route)
|
|
||||||
.ruma_route(server_server::create_join_event_v2_route)
|
|
||||||
.ruma_route(server_server::create_invite_route)
|
|
||||||
.ruma_route(server_server::get_devices_route)
|
|
||||||
.ruma_route(server_server::get_room_information_route)
|
|
||||||
.ruma_route(server_server::get_profile_information_route)
|
|
||||||
.ruma_route(server_server::get_keys_route)
|
|
||||||
.ruma_route(server_server::claim_keys_route)
|
|
||||||
.route(
|
.route(
|
||||||
"/_matrix/client/r0/rooms/:room_id/initialSync",
|
"/_matrix/client/r0/rooms/:room_id/initialSync",
|
||||||
get(initial_sync),
|
get(initial_sync),
|
||||||
|
@ -636,7 +614,74 @@ fn method_to_filter(method: Method) -> MethodFilter {
|
||||||
m => panic!("Unsupported HTTP method: {m:?}"),
|
m => panic!("Unsupported HTTP method: {m:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn server_routes() -> Router {
|
||||||
|
Router::default()
|
||||||
|
.ruma_route(server_server::get_server_version_route)
|
||||||
|
.route(
|
||||||
|
"/_matrix/key/v2/server",
|
||||||
|
get(server_server::get_server_keys_route),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/_matrix/key/v2/server/:key_id",
|
||||||
|
get(server_server::get_server_keys_deprecated_route),
|
||||||
|
)
|
||||||
|
.ruma_route(server_server::get_public_rooms_route)
|
||||||
|
.ruma_route(server_server::get_public_rooms_filtered_route)
|
||||||
|
.ruma_route(server_server::send_transaction_message_route)
|
||||||
|
.ruma_route(server_server::get_event_route)
|
||||||
|
.ruma_route(server_server::get_backfill_route)
|
||||||
|
.ruma_route(server_server::get_missing_events_route)
|
||||||
|
.ruma_route(server_server::get_event_authorization_route)
|
||||||
|
.ruma_route(server_server::get_room_state_route)
|
||||||
|
.ruma_route(server_server::get_room_state_ids_route)
|
||||||
|
.ruma_route(server_server::create_join_event_template_route)
|
||||||
|
.ruma_route(server_server::create_join_event_v1_route)
|
||||||
|
.ruma_route(server_server::create_join_event_v2_route)
|
||||||
|
.ruma_route(server_server::create_invite_route)
|
||||||
|
.ruma_route(server_server::get_devices_route)
|
||||||
|
.ruma_route(server_server::get_room_information_route)
|
||||||
|
.ruma_route(server_server::get_profile_information_route)
|
||||||
|
.ruma_route(server_server::get_keys_route)
|
||||||
|
.ruma_route(server_server::claim_keys_route)
|
||||||
|
.route_layer(axum::middleware::from_fn(deny_if_not_allowed_by_acl))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn deny_if_not_allowed_by_acl<T>(
|
||||||
|
host: Option<Host>,
|
||||||
|
request: http::Request<T>,
|
||||||
|
next: Next<T>,
|
||||||
|
) -> Response {
|
||||||
|
let Some(host) = host else {
|
||||||
|
return Error::BadRequest(
|
||||||
|
ruma::api::client::error::ErrorKind::MissingParam,
|
||||||
|
"host header not given",
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
};
|
||||||
|
let parsed_host = match url::Host::parse(&host.0) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(error) => {
|
||||||
|
warn!("got error {} when parsing {:?}", error, host);
|
||||||
|
return Error::BadRequest(
|
||||||
|
ruma::api::client::error::ErrorKind::InvalidParam,
|
||||||
|
"host header malformed",
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if services()
|
||||||
|
.acl
|
||||||
|
.is_federation_with_allowed(parsed_host.clone())
|
||||||
|
{
|
||||||
|
next.run(request).await
|
||||||
|
} else {
|
||||||
|
Error::ACLBlock(
|
||||||
|
ServerName::parse(parsed_host.to_string()).expect("host is valid after parsing"),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
#[tracing::instrument(err)]
|
#[tracing::instrument(err)]
|
||||||
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
|
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
|
||||||
|
|
|
@ -7,6 +7,8 @@ use ring::digest;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId,
|
canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId,
|
||||||
};
|
};
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
cmp::Ordering,
|
cmp::Ordering,
|
||||||
fmt,
|
fmt,
|
||||||
|
|
Loading…
Add table
Reference in a new issue