From b9ce99d036572aee3c1b3be725e0ae9449a865a5 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:22:29 +0100 Subject: [PATCH 01/31] feat(policy-server): Policy server following --- src/core/matrix/state_res/event_auth.rs | 7 +- .../rooms/event_handler/call_policyserv.rs | 71 +++++++++++++++++++ src/service/rooms/event_handler/mod.rs | 1 + .../event_handler/upgrade_outlier_pdu.rs | 34 ++++++--- 4 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 src/service/rooms/event_handler/call_policyserv.rs diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 5c36ce03..819d05e2 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -5,7 +5,7 @@ use futures::{ future::{OptionFuture, join3}, }; use ruma::{ - Int, OwnedUserId, RoomVersionId, UserId, + EventId, Int, OwnedUserId, RoomVersionId, UserId, events::room::{ create::RoomCreateEventContent, join_rules::{JoinRule, RoomJoinRulesEventContent}, @@ -217,8 +217,9 @@ where } /* - // TODO: In the past this code caused problems federating with synapse, maybe this has been - // resolved already. Needs testing. + // TODO: In the past this code was commented as it caused problems with Synapse. This is no + // longer the case. This needs to be implemented. + // See also: https://github.com/ruma/ruma/pull/2064 // // 2. Reject if auth_events // a. auth_events cannot have duplicate keys since it's a BTree diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs new file mode 100644 index 00000000..4a52227d --- /dev/null +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -0,0 +1,71 @@ +use conduwuit::{ + Err, Event, PduEvent, Result, debug, implement, utils::to_canonical_object, warn, +}; +use ruma::{ + RoomId, ServerName, + api::federation::room::policy::v1::Request as PolicyRequest, + canonical_json::to_canonical_value, + events::{StateEventType, room::policy::RoomPolicyEventContent}, +}; + +/// Returns Ok if the policy server allows the event +#[implement(super::Service)] +#[tracing::instrument(skip_all, level = "debug")] +pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { + let Ok(policyserver) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomPolicy, "") + .await + .map(|c: RoomPolicyEventContent| c) + else { + return Ok(()); + }; + + let via = match policyserver.via { + | Some(ref via) => ServerName::parse(via)?, + | None => { + debug!("No policy server configured for room {room_id}"); + return Ok(()); + }, + }; + // TODO: dont do *this* + let pdu_json = self.services.timeline.get_pdu_json(pdu.event_id()).await?; + let outgoing = self + .services + .sending + .convert_to_outgoing_federation_event(pdu_json) + .await; + // let s = match serde_json::to_string(outgoing.as_ref()) { + // | Ok(s) => s, + // | Err(e) => { + // warn!("Failed to convert pdu {} to outgoing federation event: {e}", + // pdu.event_id()); return Err!(Request(InvalidParam("Failed to convert PDU + // to outgoing event."))); }, + // }; + debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); + let response = self + .services + .sending + .send_federation_request(via, PolicyRequest { + event_id: pdu.event_id().to_owned(), + pdu: Some(outgoing), + }) + .await; + let response = match response { + | Ok(response) => response, + | Err(e) => { + warn!("Failed to contact policy server {via} for room 
{room_id}: {e}"); + return Ok(()); + }, + }; + if response.recommendation == "spam" { + warn!( + "Event {} in room {room_id} was marked as spam by policy server {via}", + pdu.event_id().to_owned() + ); + return Err!(Request(Forbidden("Event was marked as spam by policy server"))); + }; + + Ok(()) +} diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index aed38e1e..5ed25c6e 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,4 +1,5 @@ mod acl_check; +mod call_policyserv; mod fetch_and_handle_outliers; mod fetch_prev; mod fetch_state; diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index 4093cb05..abb5c116 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant}; use conduwuit::{ - Err, Result, debug, debug_info, err, implement, is_equal_to, + Err, Result, debug, debug_info, err, implement, info, is_equal_to, matrix::{Event, EventTypeExt, PduEvent, StateKey, state_res}, trace, utils::stream::{BroadbandExt, ReadyExt}, @@ -47,7 +47,7 @@ where return Err!(Request(InvalidParam("Event has been soft failed"))); } - debug!("Upgrading to timeline pdu"); + debug!("Upgrading pdu {} from outlier to timeline pdu", incoming_pdu.event_id); let timer = Instant::now(); let room_version_id = get_room_version_id(create_event)?; @@ -55,7 +55,7 @@ where // backwards extremities doing all the checks in this list starting at 1. // These are not timeline events. - debug!("Resolving state at event"); + debug!("Resolving state at event {}", incoming_pdu.event_id); let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 { self.state_at_incoming_degree_one(&incoming_pdu).await? } else { @@ -74,7 +74,7 @@ where let room_version = to_room_version(&room_version_id); - debug!("Performing auth check"); + debug!("Performing auth check to upgrade {}", incoming_pdu.event_id); // 11. 
Check the auth of the event passes based on the state of the event let state_fetch_state = &state_at_incoming_event; let state_fetch = |k: StateEventType, s: StateKey| async move { @@ -84,6 +84,7 @@ where self.services.timeline.get_pdu(event_id).await.ok() }; + debug!("running auth check on {}", incoming_pdu.event_id); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -97,7 +98,7 @@ where return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } - debug!("Gathering auth events"); + debug!("Gathering auth events for {}", incoming_pdu.event_id); let auth_events = self .services .state @@ -115,6 +116,7 @@ where ready(auth_events.get(&key).map(ToOwned::to_owned)) }; + debug!("running auth check on {} with claimed state auth", incoming_pdu.event_id); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -125,8 +127,8 @@ where .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res - debug!("Performing soft-fail check"); - let soft_fail = match (auth_check, incoming_pdu.redacts_id(&room_version_id)) { + debug!("Performing soft-fail check on {}", incoming_pdu.event_id); + let mut soft_fail = match (auth_check, incoming_pdu.redacts_id(&room_version_id)) { | (false, _) => true, | (true, None) => false, | (true, Some(redact_id)) => @@ -219,10 +221,26 @@ where .await?; } + // 14-pre. If the event is not a state event, ask the policy server about it + if incoming_pdu.state_key.is_none() + && incoming_pdu.sender().server_name() != self.services.globals.server_name() + { + debug!("Checking policy server for event {}", incoming_pdu.event_id); + let policy = self.policyserv_check(&incoming_pdu, room_id); + if let Err(e) = policy.await { + warn!("Policy server check failed for event {}: {e}", incoming_pdu.event_id); + if !soft_fail { + soft_fail = true; + } + } + debug!("Policy server check passed for event {}", incoming_pdu.event_id); + } + // 14. 
Check if the event passes auth based on the "current state" of the room, // if not soft fail it if soft_fail { - debug!("Soft failing event"); + info!("Soft failing event {}", incoming_pdu.event_id); + // assert!(extremities.is_empty(), "soft_fail extremities empty"); let extremities = extremities.iter().map(Borrow::borrow); self.services From 1dc9abc00e4df4475c75e5cf84b8ab96cf32674a Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:34:34 +0100 Subject: [PATCH 02/31] chore: Update ruwuma & fix lints --- Cargo.lock | 22 +++++++++---------- Cargo.toml | 2 +- src/core/matrix/state_res/event_auth.rs | 2 +- .../rooms/event_handler/call_policyserv.rs | 7 ++---- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f711007..700c04f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3900,7 +3900,7 @@ checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3" [[package]] name = "ruma" version = "0.10.1" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "assign", "js_int", @@ -3920,7 +3920,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "ruma-common", @@ -3932,7 +3932,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "as_variant", "assign", @@ -3955,7 +3955,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "as_variant", "base64 0.22.1", @@ -3987,7 +3987,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "as_variant", "indexmap 2.9.0", @@ -4012,7 +4012,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "bytes", "headers", @@ -4034,7 +4034,7 @@ dependencies 
= [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "thiserror 2.0.12", @@ -4043,7 +4043,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "ruma-common", @@ -4053,7 +4053,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "cfg-if", "proc-macro-crate", @@ -4068,7 +4068,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "ruma-common", @@ -4080,7 +4080,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "base64 0.22.1", "ed25519-dalek", diff --git a/Cargo.toml b/Cargo.toml index ef917332..fb00d6d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -352,7 +352,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://forgejo.ellis.link/continuwuation/ruwuma" #branch = "conduwuit-changes" -rev = "a4b948b40417a65ab0282ae47cc50035dd455e02" +rev = "b753738047d1f443aca870896ef27ecaacf027da" features = [ "compat", "rand", diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 819d05e2..81c83431 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -5,7 +5,7 @@ use futures::{ future::{OptionFuture, join3}, }; use ruma::{ - EventId, Int, OwnedUserId, RoomVersionId, UserId, + Int, OwnedUserId, RoomVersionId, UserId, events::room::{ create::RoomCreateEventContent, join_rules::{JoinRule, RoomJoinRulesEventContent}, diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 4a52227d..804c77eb 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -1,10 +1,7 @@ -use conduwuit::{ - Err, Event, PduEvent, Result, debug, implement, utils::to_canonical_object, warn, -}; +use conduwuit::{Err, Event, PduEvent, Result, debug, implement, warn}; use ruma::{ 
RoomId, ServerName, api::federation::room::policy::v1::Request as PolicyRequest, - canonical_json::to_canonical_value, events::{StateEventType, room::policy::RoomPolicyEventContent}, }; @@ -65,7 +62,7 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result pdu.event_id().to_owned() ); return Err!(Request(Forbidden("Event was marked as spam by policy server"))); - }; + } Ok(()) } From be61ff1465f8b57dfe754acaaef6d2234c326771 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:47:02 +0100 Subject: [PATCH 03/31] fix(policy-server): Avoid unnecessary database lookup --- src/service/rooms/event_handler/call_policyserv.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 804c77eb..e7ae1d0f 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -26,20 +26,11 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result return Ok(()); }, }; - // TODO: dont do *this* - let pdu_json = self.services.timeline.get_pdu_json(pdu.event_id()).await?; let outgoing = self .services .sending - .convert_to_outgoing_federation_event(pdu_json) + .convert_to_outgoing_federation_event(pdu.to_canonical_object()) .await; - // let s = match serde_json::to_string(outgoing.as_ref()) { - // | Ok(s) => s, - // | Err(e) => { - // warn!("Failed to convert pdu {} to outgoing federation event: {e}", - // pdu.event_id()); return Err!(Request(InvalidParam("Failed to convert PDU - // to outgoing event."))); }, - // }; debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); let response = self .services From 964b23a4282b84d1ca10d9eaf6f056646e21a8d9 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:50:47 +0100 Subject: [PATCH 04/31] style(policy-server): Restructure logging --- src/service/rooms/event_handler/call_policyserv.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index e7ae1d0f..894e28af 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -43,14 +43,21 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result let response = match response { | Ok(response) => response, | Err(e) => { - warn!("Failed to contact policy server {via} for room {room_id}: {e}"); + warn!( + via = %via, + event_id = %pdu.event_id(), + room_id = %room_id, + "Failed to contact policy server: {e}" + ); return Ok(()); }, }; if response.recommendation == "spam" { warn!( - "Event {} in room {room_id} was marked as spam by policy server {via}", - pdu.event_id().to_owned() + via = %via, + event_id = %pdu.event_id(), + room_id = %room_id, + "Event was marked as spam by policy server", ); return Err!(Request(Forbidden("Event was marked as spam by policy server"))); } From 40d789dd72d7981fc4f819ebb5c39b33a8db9400 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:54:06 +0100 Subject: [PATCH 05/31] feat(policy-server): Soft-fail redactions for failed events --- .../event_handler/upgrade_outlier_pdu.rs | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs 
index abb5c116..e8e22fe9 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -222,9 +222,7 @@ where } // 14-pre. If the event is not a state event, ask the policy server about it - if incoming_pdu.state_key.is_none() - && incoming_pdu.sender().server_name() != self.services.globals.server_name() - { + if incoming_pdu.state_key.is_none() { debug!("Checking policy server for event {}", incoming_pdu.event_id); let policy = self.policyserv_check(&incoming_pdu, room_id); if let Err(e) = policy.await { @@ -236,6 +234,24 @@ where debug!("Policy server check passed for event {}", incoming_pdu.event_id); } + // Additionally, if this is a redaction for a soft-failed event, we soft-fail it + // also + if let Some(redact_id) = incoming_pdu.redacts_id(&room_version_id) { + debug!("Checking if redaction {} is for a soft-failed event", redact_id); + if self + .services + .pdu_metadata + .is_event_soft_failed(&redact_id) + .await + { + warn!( + "Redaction {} is for a soft-failed event, soft failing the redaction", + redact_id + ); + soft_fail = true; + } + } + // 14. Check if the event passes auth based on the "current state" of the room, // if not soft fail it if soft_fail { From efce67264e605ca6350e68ca5db47c4957121f07 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 21:09:23 +0100 Subject: [PATCH 06/31] feat(policy-server): Prevent local events that fail the policy check --- src/service/rooms/timeline/create.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/service/rooms/timeline/create.rs b/src/service/rooms/timeline/create.rs index 20ccaf56..6301d785 100644 --- a/src/service/rooms/timeline/create.rs +++ b/src/service/rooms/timeline/create.rs @@ -165,6 +165,17 @@ pub async fn create_hash_and_sign_event( return Err!(Request(Forbidden("Event is not authorized."))); } + // Check with the policy server + if self + .services + .event_handler + .policyserv_check(&pdu, room_id) + .await + .is_err() + { + return Err!(Request(Forbidden(debug_warn!("Policy server marked this event as spam")))); + } + // Hash and sign let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| { err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}")))) From fe1610ab1cd8618a0981e4279c61f761dfb3ca22 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 22:07:18 +0100 Subject: [PATCH 07/31] feat(policy-server): Limit policy server request timeout to 10 seconds --- .../rooms/event_handler/call_policyserv.rs | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 894e28af..0592186a 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use conduwuit::{Err, Event, PduEvent, Result, debug, implement, warn}; use ruma::{ RoomId, ServerName, @@ -32,17 +34,19 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result .convert_to_outgoing_federation_event(pdu.to_canonical_object()) .await; debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); - let response = self - .services - .sending - .send_federation_request(via, PolicyRequest { - event_id: pdu.event_id().to_owned(), - pdu: Some(outgoing), - }) - .await; + let response = tokio::time::timeout( + Duration::from_secs(10), + self.services 
+ .sending + .send_federation_request(via, PolicyRequest { + event_id: pdu.event_id().to_owned(), + pdu: Some(outgoing), + }), + ) + .await; let response = match response { - | Ok(response) => response, - | Err(e) => { + | Ok(Ok(response)) => response, + | Ok(Err(e)) => { warn!( via = %via, event_id = %pdu.event_id(), @@ -51,6 +55,15 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result ); return Ok(()); }, + | Err(_) => { + warn!( + via = %via, + event_id = %pdu.event_id(), + room_id = %room_id, + "Policy server request timed out after 10 seconds" + ); + return Ok(()); + }, }; if response.recommendation == "spam" { warn!( From 977fddf4c5f7561c48b40ffc855ddaa6ef90219a Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 23:50:32 +0100 Subject: [PATCH 08/31] feat(policy-server): Optimise policy server lookups --- src/service/rooms/event_handler/call_policyserv.rs | 12 ++++++++++++ src/service/rooms/event_handler/mod.rs | 2 ++ 2 files changed, 14 insertions(+) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 0592186a..331d4c8f 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -11,6 +11,10 @@ use ruma::{ #[implement(super::Service)] #[tracing::instrument(skip_all, level = "debug")] pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { + if pdu.event_type().to_owned() == StateEventType::RoomPolicy.into() { + debug!("Skipping spam check for policy server meta-event in room {room_id}"); + return Ok(()); + } let Ok(policyserver) = self .services .state_accessor @@ -28,6 +32,14 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result return Ok(()); }, }; + if via.is_empty() { + debug!("Policy server is empty for room {room_id}, skipping spam check"); + return Ok(()); + } + if !self.services.state_cache.server_in_room(via, room_id).await { + debug!("Policy server {via} is not in the room {room_id}, skipping spam check"); + return Ok(()); + } let outgoing = self .services .sending diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 5ed25c6e..4e948e95 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -43,6 +43,7 @@ struct Services { server_keys: Dep, short: Dep, state: Dep, + state_cache: Dep, state_accessor: Dep, state_compressor: Dep, timeline: Dep, @@ -68,6 +69,7 @@ impl crate::Service for Service { pdu_metadata: args.depend::("rooms::pdu_metadata"), short: args.depend::("rooms::short"), state: args.depend::("rooms::state"), + state_cache: args.depend::("rooms::state_cache"), state_accessor: args .depend::("rooms::state_accessor"), state_compressor: args From 5454c22b5bed70479610f19ab80ff981ab952003 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 23:54:07 +0100 Subject: [PATCH 09/31] style(policy-server): Run clippy --- src/service/rooms/event_handler/call_policyserv.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 331d4c8f..96e3f7cc 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -11,7 +11,7 @@ use ruma::{ #[implement(super::Service)] #[tracing::instrument(skip_all, level = "debug")] pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: 
&RoomId) -> Result { - if pdu.event_type().to_owned() == StateEventType::RoomPolicy.into() { + if *pdu.event_type() == StateEventType::RoomPolicy.into() { debug!("Skipping spam check for policy server meta-event in room {room_id}"); return Ok(()); } From 4f90379ac15503be293f68c2f77bfa99130d5904 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sun, 20 Jul 2025 01:03:18 +0100 Subject: [PATCH 10/31] style: Improve logging and comments --- src/core/matrix/state_res/event_auth.rs | 4 +- .../rooms/event_handler/call_policyserv.rs | 26 ++++++- .../event_handler/upgrade_outlier_pdu.rs | 73 +++++++++++++++---- 3 files changed, 82 insertions(+), 21 deletions(-) diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 81c83431..77a4a95c 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -149,8 +149,8 @@ where for<'a> &'a E: Event + Send, { debug!( - event_id = format!("{}", incoming_event.event_id()), - event_type = format!("{}", incoming_event.event_type()), + event_id = %incoming_event.event_id(), + event_type = ?incoming_event.event_type(), "auth_check beginning" ); diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 96e3f7cc..aef99dba 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -1,3 +1,8 @@ +//! Policy server integration for event spam checking in Matrix rooms. +//! +//! This module implements a check against a room-specific policy server, as +//! described in the relevant Matrix spec proposal (see: https://github.com/matrix-org/matrix-spec-proposals/pull/4284). + use std::time::Duration; use conduwuit::{Err, Event, PduEvent, Result, debug, implement, warn}; @@ -12,7 +17,11 @@ use ruma::{ #[tracing::instrument(skip_all, level = "debug")] pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { if *pdu.event_type() == StateEventType::RoomPolicy.into() { - debug!("Skipping spam check for policy server meta-event in room {room_id}"); + debug!( + room_id = %room_id, + event_type = ?pdu.event_type(), + "Skipping spam check for policy server meta-event" + ); return Ok(()); } let Ok(policyserver) = self @@ -37,7 +46,11 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result return Ok(()); } if !self.services.state_cache.server_in_room(via, room_id).await { - debug!("Policy server {via} is not in the room {room_id}, skipping spam check"); + debug!( + room_id = %room_id, + via = %via, + "Policy server is not in the room, skipping spam check" + ); return Ok(()); } let outgoing = self @@ -45,7 +58,12 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result .sending .convert_to_outgoing_federation_event(pdu.to_canonical_object()) .await; - debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); + debug!( + room_id = %room_id, + via = %via, + outgoing = ?outgoing, + "Checking event for spam with policy server" + ); let response = tokio::time::timeout( Duration::from_secs(10), self.services @@ -65,6 +83,8 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result room_id = %room_id, "Failed to contact policy server: {e}" ); + // Network or policy server errors are treated as non-fatal: event is allowed by + // default. 
return Ok(()); }, | Err(_) => { diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index e8e22fe9..d3dc32fb 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -47,7 +47,10 @@ where return Err!(Request(InvalidParam("Event has been soft failed"))); } - debug!("Upgrading pdu {} from outlier to timeline pdu", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Upgrading PDU from outlier to timeline" + ); let timer = Instant::now(); let room_version_id = get_room_version_id(create_event)?; @@ -55,7 +58,10 @@ where // backwards extremities doing all the checks in this list starting at 1. // These are not timeline events. - debug!("Resolving state at event {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Resolving state at event" + ); let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 { self.state_at_incoming_degree_one(&incoming_pdu).await? } else { @@ -74,7 +80,10 @@ where let room_version = to_room_version(&room_version_id); - debug!("Performing auth check to upgrade {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Performing auth check to upgrade" + ); // 11. Check the auth of the event passes based on the state of the event let state_fetch_state = &state_at_incoming_event; let state_fetch = |k: StateEventType, s: StateKey| async move { @@ -84,7 +93,10 @@ where self.services.timeline.get_pdu(event_id).await.ok() }; - debug!("running auth check on {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Running initial auth check" + ); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -98,7 +110,10 @@ where return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } - debug!("Gathering auth events for {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Gathering auth events" + ); let auth_events = self .services .state @@ -116,7 +131,10 @@ where ready(auth_events.get(&key).map(ToOwned::to_owned)) }; - debug!("running auth check on {} with claimed state auth", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Running auth check with claimed state auth" + ); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -127,7 +145,10 @@ where .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res - debug!("Performing soft-fail check on {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Performing soft-fail check" + ); let mut soft_fail = match (auth_check, incoming_pdu.redacts_id(&room_version_id)) { | (false, _) => true, | (true, None) => false, @@ -142,7 +163,10 @@ where // 13. Use state resolution to find new room state // We start looking at current room state now, so lets lock the room - trace!("Locking the room"); + trace!( + room_id = %room_id, + "Locking the room" + ); let state_lock = self.services.state.mutex.lock(room_id).await; // Now we calculate the set of extremities this room has after the incoming @@ -223,21 +247,32 @@ where // 14-pre. 
If the event is not a state event, ask the policy server about it if incoming_pdu.state_key.is_none() { - debug!("Checking policy server for event {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id,"Checking policy server for event"); let policy = self.policyserv_check(&incoming_pdu, room_id); if let Err(e) = policy.await { - warn!("Policy server check failed for event {}: {e}", incoming_pdu.event_id); + warn!( + event_id = %incoming_pdu.event_id, + error = ?e, + "Policy server check failed for event" + ); if !soft_fail { soft_fail = true; } } - debug!("Policy server check passed for event {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Policy server check passed for event" + ); } // Additionally, if this is a redaction for a soft-failed event, we soft-fail it // also if let Some(redact_id) = incoming_pdu.redacts_id(&room_version_id) { - debug!("Checking if redaction {} is for a soft-failed event", redact_id); + debug!( + redact_id = %redact_id, + "Checking if redaction is for a soft-failed event" + ); if self .services .pdu_metadata @@ -245,8 +280,8 @@ where .await { warn!( - "Redaction {} is for a soft-failed event, soft failing the redaction", - redact_id + redact_id = %redact_id, + "Redaction is for a soft-failed event, soft failing the redaction" ); soft_fail = true; } @@ -255,7 +290,10 @@ where // 14. Check if the event passes auth based on the "current state" of the room, // if not soft fail it if soft_fail { - info!("Soft failing event {}", incoming_pdu.event_id); + info!( + event_id = %incoming_pdu.event_id, + "Soft failing event" + ); // assert!(extremities.is_empty(), "soft_fail extremities empty"); let extremities = extremities.iter().map(Borrow::borrow); @@ -276,7 +314,10 @@ where .pdu_metadata .mark_event_soft_failed(incoming_pdu.event_id()); - warn!("Event was soft failed: {:?}", incoming_pdu.event_id()); + warn!( + event_id = %incoming_pdu.event_id, + "Event was soft failed" + ); return Err!(Request(InvalidParam("Event has been soft failed"))); } From 9051ce63f7c419f2eb14a1cfcfb849bf831723db Mon Sep 17 00:00:00 2001 From: rooot Date: Sun, 20 Jul 2025 03:14:35 +0200 Subject: [PATCH 11/31] feat(config): introduce federation connection timeout setting fixes #906 Signed-off-by: rooot --- conduwuit-example.toml | 8 ++++++++ src/core/config/mod.rs | 12 ++++++++++++ src/service/client/mod.rs | 3 +++ 3 files changed, 23 insertions(+) diff --git a/conduwuit-example.toml b/conduwuit-example.toml index bdc2f570..3b7bbbb8 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -325,6 +325,14 @@ # #well_known_timeout = 10 +# Federation client connection timeout (seconds). You should not set this +# to high values, as dead homeservers can significantly slow down +# federation, specifically key retrieval, which will take roughly the +# amount of time you configure here given that a homeserver doesn't +# respond. +# +#federation_conn_timeout = 10 + # Federation client request timeout (seconds). You most definitely want # this to be high to account for extremely large room joins, slow # homeservers, your own resources etc. diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index d93acd9b..515409be 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -412,6 +412,16 @@ pub struct Config { #[serde(default = "default_well_known_timeout")] pub well_known_timeout: u64, + /// Federation client connection timeout (seconds). 
You should not set this + /// to high values, as dead homeservers can significantly slow down + /// federation, specifically key retrieval, which will take roughly the + /// amount of time you configure here given that a homeserver doesn't + /// respond. + /// + /// default: 10 + #[serde(default = "default_federation_conn_timeout")] + pub federation_conn_timeout: u64, + /// Federation client request timeout (seconds). You most definitely want /// this to be high to account for extremely large room joins, slow /// homeservers, your own resources etc. @@ -2193,6 +2203,8 @@ fn default_well_known_conn_timeout() -> u64 { 6 } fn default_well_known_timeout() -> u64 { 10 } +fn default_federation_conn_timeout() -> u64 { 10 } + fn default_federation_timeout() -> u64 { 25 } fn default_federation_idle_timeout() -> u64 { 25 } diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index 1aeeb492..239340ba 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -66,6 +66,7 @@ impl crate::Service for Service { federation: base(config)? .dns_resolver(resolver.resolver.hooked.clone()) + .connect_timeout(Duration::from_secs(config.federation_conn_timeout)) .read_timeout(Duration::from_secs(config.federation_timeout)) .pool_max_idle_per_host(config.federation_idle_per_host.into()) .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) @@ -74,6 +75,7 @@ impl crate::Service for Service { synapse: base(config)? .dns_resolver(resolver.resolver.hooked.clone()) + .connect_timeout(Duration::from_secs(config.federation_conn_timeout)) .read_timeout(Duration::from_secs(305)) .pool_max_idle_per_host(0) .redirect(redirect::Policy::limited(3)) @@ -81,6 +83,7 @@ impl crate::Service for Service { sender: base(config)? .dns_resolver(resolver.resolver.hooked.clone()) + .connect_timeout(Duration::from_secs(config.federation_conn_timeout)) .read_timeout(Duration::from_secs(config.sender_timeout)) .timeout(Duration::from_secs(config.sender_timeout)) .pool_max_idle_per_host(1) From 0631094350bd07b35fbbc7aa9b70a0eb74cd3b28 Mon Sep 17 00:00:00 2001 From: rooot Date: Sun, 20 Jul 2025 16:46:18 +0200 Subject: [PATCH 12/31] docs(config): warn about federation key query timeout caveat Signed-off-by: rooot --- conduwuit-example.toml | 3 ++- src/core/config/mod.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 3b7bbbb8..2fab9cdf 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -329,7 +329,8 @@ # to high values, as dead homeservers can significantly slow down # federation, specifically key retrieval, which will take roughly the # amount of time you configure here given that a homeserver doesn't -# respond. +# respond. This will cause most clients to time out /keys/query, causing +# E2EE and device verification to fail. # #federation_conn_timeout = 10 diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 515409be..909462db 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -416,7 +416,8 @@ pub struct Config { /// to high values, as dead homeservers can significantly slow down /// federation, specifically key retrieval, which will take roughly the /// amount of time you configure here given that a homeserver doesn't - /// respond. + /// respond. This will cause most clients to time out /keys/query, causing + /// E2EE and device verification to fail. 
/// /// default: 10 #[serde(default = "default_federation_conn_timeout")] From 30a8c06fd9caa4276a4261a107fc84414e36ce6c Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 20:36:27 +0100 Subject: [PATCH 13/31] refactor: Replace std Mutex with parking_lot --- Cargo.lock | 1 + Cargo.toml | 7 +++++++ src/admin/processor.rs | 16 +++++----------- src/core/Cargo.toml | 1 + src/core/info/rustc.rs | 10 +++------- src/core/log/capture/layer.rs | 2 +- src/core/log/capture/mod.rs | 8 +++++--- src/core/log/capture/util.rs | 12 ++++++------ src/core/log/reload.rs | 16 ++++------------ src/core/mod.rs | 1 + src/core/utils/mutex_map.rs | 25 ++++++++----------------- src/macros/rustc.rs | 4 ++-- 12 files changed, 44 insertions(+), 59 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f711007..b084f72a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -967,6 +967,7 @@ dependencies = [ "maplit", "nix", "num-traits", + "parking_lot", "rand 0.8.5", "regex", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index ef917332..3e52c4b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -515,6 +515,13 @@ version = "1.0" [workspace.dependencies.proc-macro2] version = "1.0" +[workspace.dependencies.parking_lot] +version = "0.12.4" + +# Use this when extending with_lock::WithLock to parking_lot +# [workspace.dependencies.lock_api] +# version = "0.4.13" + [workspace.dependencies.bytesize] version = "2.0" diff --git a/src/admin/processor.rs b/src/admin/processor.rs index e80000c1..2c91efe1 100644 --- a/src/admin/processor.rs +++ b/src/admin/processor.rs @@ -1,14 +1,8 @@ -use std::{ - fmt::Write, - mem::take, - panic::AssertUnwindSafe, - sync::{Arc, Mutex}, - time::SystemTime, -}; +use std::{fmt::Write, mem::take, panic::AssertUnwindSafe, sync::Arc, time::SystemTime}; use clap::{CommandFactory, Parser}; use conduwuit::{ - Error, Result, debug, error, + Error, Result, SyncMutex, debug, error, log::{ capture, capture::Capture, @@ -123,7 +117,7 @@ async fn process( let mut output = String::new(); // Prepend the logs only if any were captured - let logs = logs.lock().expect("locked"); + let logs = logs.lock(); if logs.lines().count() > 2 { writeln!(&mut output, "{logs}").expect("failed to format logs to command output"); } @@ -132,7 +126,7 @@ async fn process( (result, output) } -fn capture_create(context: &Context<'_>) -> (Arc, Arc>) { +fn capture_create(context: &Context<'_>) -> (Arc, Arc>) { let env_config = &context.services.server.config.admin_log_capture; let env_filter = EnvFilter::try_new(env_config).unwrap_or_else(|e| { warn!("admin_log_capture filter invalid: {e:?}"); @@ -152,7 +146,7 @@ fn capture_create(context: &Context<'_>) -> (Arc, Arc>) { data.level() <= log_level && data.our_modules() && data.scope.contains(&"admin") }; - let logs = Arc::new(Mutex::new( + let logs = Arc::new(SyncMutex::new( collect_stream(|s| markdown_table_head(s)).expect("markdown table header"), )); diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 0c33c590..7a3721d6 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -110,6 +110,7 @@ tracing-core.workspace = true tracing-subscriber.workspace = true tracing.workspace = true url.workspace = true +parking_lot.workspace = true [target.'cfg(unix)'.dependencies] nix.workspace = true diff --git a/src/core/info/rustc.rs b/src/core/info/rustc.rs index 048c0cd5..60156301 100644 --- a/src/core/info/rustc.rs +++ b/src/core/info/rustc.rs @@ -3,18 +3,15 @@ //! several crates, lower-level information is supplied from each crate during //! static initialization. 
-use std::{ - collections::BTreeMap, - sync::{Mutex, OnceLock}, -}; +use std::{collections::BTreeMap, sync::OnceLock}; -use crate::utils::exchange; +use crate::{SyncMutex, utils::exchange}; /// Raw capture of rustc flags used to build each crate in the project. Informed /// by rustc_flags_capture macro (one in each crate's mod.rs). This is /// done during static initialization which is why it's mutex-protected and pub. /// Should not be written to by anything other than our macro. -pub static FLAGS: Mutex> = Mutex::new(BTreeMap::new()); +pub static FLAGS: SyncMutex> = SyncMutex::new(BTreeMap::new()); /// Processed list of enabled features across all project crates. This is /// generated from the data in FLAGS. @@ -27,7 +24,6 @@ fn init_features() -> Vec<&'static str> { let mut features = Vec::new(); FLAGS .lock() - .expect("locked") .iter() .for_each(|(_, flags)| append_features(&mut features, flags)); diff --git a/src/core/log/capture/layer.rs b/src/core/log/capture/layer.rs index 381a652f..b3235d91 100644 --- a/src/core/log/capture/layer.rs +++ b/src/core/log/capture/layer.rs @@ -55,7 +55,7 @@ where let mut visitor = Visitor { values: Values::new() }; event.record(&mut visitor); - let mut closure = capture.closure.lock().expect("exclusive lock"); + let mut closure = capture.closure.lock(); closure(Data { layer, event, diff --git a/src/core/log/capture/mod.rs b/src/core/log/capture/mod.rs index 20f70091..b7e5d2b5 100644 --- a/src/core/log/capture/mod.rs +++ b/src/core/log/capture/mod.rs @@ -4,7 +4,7 @@ pub mod layer; pub mod state; pub mod util; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; pub use data::Data; use guard::Guard; @@ -12,6 +12,8 @@ pub use layer::{Layer, Value}; pub use state::State; pub use util::*; +use crate::SyncMutex; + pub type Filter = dyn Fn(Data<'_>) -> bool + Send + Sync + 'static; pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static; @@ -19,7 +21,7 @@ pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static; pub struct Capture { state: Arc, filter: Option>, - closure: Mutex>, + closure: SyncMutex>, } impl Capture { @@ -34,7 +36,7 @@ impl Capture { Arc::new(Self { state: state.clone(), filter: filter.map(|p| -> Box { Box::new(p) }), - closure: Mutex::new(Box::new(closure)), + closure: SyncMutex::new(Box::new(closure)), }) } diff --git a/src/core/log/capture/util.rs b/src/core/log/capture/util.rs index 65524be5..21a416a9 100644 --- a/src/core/log/capture/util.rs +++ b/src/core/log/capture/util.rs @@ -1,31 +1,31 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use super::{ super::{Level, fmt}, Closure, Data, }; -use crate::Result; +use crate::{Result, SyncMutex}; -pub fn fmt_html(out: Arc>) -> Box +pub fn fmt_html(out: Arc>) -> Box where S: std::fmt::Write + Send + 'static, { fmt(fmt::html, out) } -pub fn fmt_markdown(out: Arc>) -> Box +pub fn fmt_markdown(out: Arc>) -> Box where S: std::fmt::Write + Send + 'static, { fmt(fmt::markdown, out) } -pub fn fmt(fun: F, out: Arc>) -> Box +pub fn fmt(fun: F, out: Arc>) -> Box where F: Fn(&mut S, &Level, &str, &str) -> Result<()> + Send + Sync + Copy + 'static, S: std::fmt::Write + Send + 'static, { - Box::new(move |data| call(fun, &mut *out.lock().expect("locked"), &data)) + Box::new(move |data| call(fun, &mut *out.lock(), &data)) } fn call(fun: F, out: &mut S, data: &Data<'_>) diff --git a/src/core/log/reload.rs b/src/core/log/reload.rs index f72fde47..356ee9f2 100644 --- a/src/core/log/reload.rs +++ b/src/core/log/reload.rs @@ -1,11 +1,8 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, 
-}; +use std::{collections::HashMap, sync::Arc}; use tracing_subscriber::{EnvFilter, reload}; -use crate::{Result, error}; +use crate::{Result, SyncMutex, error}; /// We need to store a reload::Handle value, but can't name it's type explicitly /// because the S type parameter depends on the subscriber's previous layers. In @@ -35,7 +32,7 @@ impl ReloadHandle for reload::Handle { #[derive(Clone)] pub struct LogLevelReloadHandles { - handles: Arc>, + handles: Arc>, } type HandleMap = HashMap; @@ -43,16 +40,12 @@ type Handle = Box + Send + Sync>; impl LogLevelReloadHandles { pub fn add(&self, name: &str, handle: Handle) { - self.handles - .lock() - .expect("locked") - .insert(name.into(), handle); + self.handles.lock().insert(name.into(), handle); } pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> { self.handles .lock() - .expect("locked") .iter() .filter(|(name, _)| names.is_some_and(|names| names.contains(&name.as_str()))) .for_each(|(_, handle)| { @@ -66,7 +59,6 @@ impl LogLevelReloadHandles { pub fn current(&self, name: &str) -> Option { self.handles .lock() - .expect("locked") .get(name) .map(|handle| handle.current())? } diff --git a/src/core/mod.rs b/src/core/mod.rs index d99139be..363fece8 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -28,6 +28,7 @@ pub use info::{ pub use matrix::{ Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res, }; +pub use parking_lot::{Mutex as SyncMutex, RwLock as SyncRwLock}; pub use server::Server; pub use utils::{ctor, dtor, implement, result, result::Result}; diff --git a/src/core/utils/mutex_map.rs b/src/core/utils/mutex_map.rs index 01504ce6..ddb361a4 100644 --- a/src/core/utils/mutex_map.rs +++ b/src/core/utils/mutex_map.rs @@ -1,12 +1,8 @@ -use std::{ - fmt::Debug, - hash::Hash, - sync::{Arc, TryLockError::WouldBlock}, -}; +use std::{fmt::Debug, hash::Hash, sync::Arc}; use tokio::sync::OwnedMutexGuard as Omg; -use crate::{Result, err}; +use crate::{Result, SyncMutex, err}; /// Map of Mutexes pub struct MutexMap { @@ -19,7 +15,7 @@ pub struct Guard { } type Map = Arc>; -type MapMutex = std::sync::Mutex>; +type MapMutex = SyncMutex>; type HashMap = std::collections::HashMap>; type Value = Arc>; @@ -45,7 +41,6 @@ where let val = self .map .lock() - .expect("locked") .entry(k.try_into().expect("failed to construct key")) .or_default() .clone(); @@ -66,7 +61,6 @@ where let val = self .map .lock() - .expect("locked") .entry(k.try_into().expect("failed to construct key")) .or_default() .clone(); @@ -87,10 +81,7 @@ where let val = self .map .try_lock() - .map_err(|e| match e { - | WouldBlock => err!("would block"), - | _ => panic!("{e:?}"), - })? + .ok_or_else(|| err!("would block"))? 
.entry(k.try_into().expect("failed to construct key")) .or_default() .clone(); @@ -102,13 +93,13 @@ where } #[must_use] - pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } + pub fn contains(&self, k: &Key) -> bool { self.map.lock().contains_key(k) } #[must_use] - pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } + pub fn is_empty(&self) -> bool { self.map.lock().is_empty() } #[must_use] - pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } + pub fn len(&self) -> usize { self.map.lock().len() } } impl Default for MutexMap @@ -123,7 +114,7 @@ impl Drop for Guard { #[tracing::instrument(name = "unlock", level = "trace", skip_all)] fn drop(&mut self) { if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { - self.map.lock().expect("locked").retain(|_, val| { + self.map.lock().retain(|_, val| { !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2 }); } diff --git a/src/macros/rustc.rs b/src/macros/rustc.rs index 1220c8d4..cf935fe5 100644 --- a/src/macros/rustc.rs +++ b/src/macros/rustc.rs @@ -15,13 +15,13 @@ pub(super) fn flags_capture(args: TokenStream) -> TokenStream { #[conduwuit_core::ctor] fn _set_rustc_flags() { - conduwuit_core::info::rustc::FLAGS.lock().expect("locked").insert(#crate_name, &RUSTC_FLAGS); + conduwuit_core::info::rustc::FLAGS.lock().insert(#crate_name, &RUSTC_FLAGS); } // static strings have to be yanked on module unload #[conduwuit_core::dtor] fn _unset_rustc_flags() { - conduwuit_core::info::rustc::FLAGS.lock().expect("locked").remove(#crate_name); + conduwuit_core::info::rustc::FLAGS.lock().remove(#crate_name); } }; From a1d616e3e3dfc6ad7bdd165d9435bfc733955b73 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 21:03:17 +0100 Subject: [PATCH 14/31] refactor: Replace std RwLock with parking_lot --- src/core/log/capture/layer.rs | 1 - src/core/log/capture/state.rs | 14 +++++----- src/database/watchers.rs | 10 +++---- src/service/globals/data.rs | 22 +++++++-------- src/service/manager.rs | 1 - src/service/service.rs | 50 ++++++++++++++++------------------- src/service/services.rs | 18 +++++-------- src/service/uiaa/mod.rs | 10 +++---- 8 files changed, 54 insertions(+), 72 deletions(-) diff --git a/src/core/log/capture/layer.rs b/src/core/log/capture/layer.rs index b3235d91..e3fe66df 100644 --- a/src/core/log/capture/layer.rs +++ b/src/core/log/capture/layer.rs @@ -40,7 +40,6 @@ where self.state .active .read() - .expect("shared lock") .iter() .filter(|capture| filter(self, capture, event, &ctx)) .for_each(|capture| handle(self, capture, event, &ctx)); diff --git a/src/core/log/capture/state.rs b/src/core/log/capture/state.rs index dad6c8d8..92a1608f 100644 --- a/src/core/log/capture/state.rs +++ b/src/core/log/capture/state.rs @@ -1,10 +1,11 @@ -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use super::Capture; +use crate::SyncRwLock; /// Capture layer state. 
pub struct State { - pub(super) active: RwLock>>, + pub(super) active: SyncRwLock>>, } impl Default for State { @@ -13,17 +14,14 @@ impl Default for State { impl State { #[must_use] - pub fn new() -> Self { Self { active: RwLock::new(Vec::new()) } } + pub fn new() -> Self { Self { active: SyncRwLock::new(Vec::new()) } } pub(super) fn add(&self, capture: &Arc) { - self.active - .write() - .expect("locked for writing") - .push(capture.clone()); + self.active.write().push(capture.clone()); } pub(super) fn del(&self, capture: &Arc) { - let mut vec = self.active.write().expect("locked for writing"); + let mut vec = self.active.write(); if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) { vec.swap_remove(pos); } diff --git a/src/database/watchers.rs b/src/database/watchers.rs index efb939d7..0e911c82 100644 --- a/src/database/watchers.rs +++ b/src/database/watchers.rs @@ -2,12 +2,12 @@ use std::{ collections::{HashMap, hash_map}, future::Future, pin::Pin, - sync::RwLock, }; +use conduwuit::SyncRwLock; use tokio::sync::watch; -type Watcher = RwLock, (watch::Sender<()>, watch::Receiver<()>)>>; +type Watcher = SyncRwLock, (watch::Sender<()>, watch::Receiver<()>)>>; #[derive(Default)] pub(crate) struct Watchers { @@ -19,7 +19,7 @@ impl Watchers { &'a self, prefix: &[u8], ) -> Pin + Send + 'a>> { - let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { + let mut rx = match self.watchers.write().entry(prefix.to_vec()) { | hash_map::Entry::Occupied(o) => o.get().1.clone(), | hash_map::Entry::Vacant(v) => { let (tx, rx) = watch::channel(()); @@ -35,7 +35,7 @@ impl Watchers { } pub(crate) fn wake(&self, key: &[u8]) { - let watchers = self.watchers.read().unwrap(); + let watchers = self.watchers.read(); let mut triggered = Vec::new(); for length in 0..=key.len() { if watchers.contains_key(&key[..length]) { @@ -46,7 +46,7 @@ impl Watchers { drop(watchers); if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); + let mut watchers = self.watchers.write(); for prefix in triggered { if let Some(tx) = watchers.remove(prefix) { tx.0.send(()).expect("channel should still be open"); diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 21c09252..07f1de5c 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,11 +1,11 @@ -use std::sync::{Arc, RwLock}; +use std::sync::Arc; -use conduwuit::{Result, utils}; +use conduwuit::{Result, SyncRwLock, utils}; use database::{Database, Deserialized, Map}; pub struct Data { global: Arc, - counter: RwLock, + counter: SyncRwLock, pub(super) db: Arc, } @@ -16,25 +16,21 @@ impl Data { let db = &args.db; Self { global: db["global"].clone(), - counter: RwLock::new( - Self::stored_count(&db["global"]).expect("initialized global counter"), - ), + counter: SyncRwLock::new(Self::stored_count(&db["global"]).unwrap_or_default()), db: args.db.clone(), } } pub fn next_count(&self) -> Result { let _cork = self.db.cork(); - let mut lock = self.counter.write().expect("locked"); + let mut lock = self.counter.write(); let counter: &mut u64 = &mut lock; debug_assert!( - *counter == Self::stored_count(&self.global).expect("database failure"), + *counter == Self::stored_count(&self.global).unwrap_or_default(), "counter mismatch" ); - *counter = counter - .checked_add(1) - .expect("counter must not overflow u64"); + *counter = counter.checked_add(1).unwrap_or(*counter); self.global.insert(COUNTER, counter.to_be_bytes()); @@ -43,10 +39,10 @@ impl Data { #[inline] pub fn current_count(&self) -> 
u64 { - let lock = self.counter.read().expect("locked"); + let lock = self.counter.read(); let counter: &u64 = &lock; debug_assert!( - *counter == Self::stored_count(&self.global).expect("database failure"), + *counter == Self::stored_count(&self.global).unwrap_or_default(), "counter mismatch" ); diff --git a/src/service/manager.rs b/src/service/manager.rs index 3cdf5945..7a2e50d5 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -58,7 +58,6 @@ impl Manager { let services: Vec> = self .service .read() - .expect("locked for reading") .values() .map(|val| val.0.upgrade()) .map(|arc| arc.expect("services available for manager startup")) diff --git a/src/service/service.rs b/src/service/service.rs index 574efd8f..3bc61aeb 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -3,11 +3,13 @@ use std::{ collections::BTreeMap, fmt::Write, ops::Deref, - sync::{Arc, OnceLock, RwLock, Weak}, + sync::{Arc, OnceLock, Weak}, }; use async_trait::async_trait; -use conduwuit::{Err, Result, Server, err, error::inspect_log, utils::string::SplitInfallible}; +use conduwuit::{ + Err, Result, Server, SyncRwLock, err, error::inspect_log, utils::string::SplitInfallible, +}; use database::Database; /// Abstract interface for a Service @@ -62,7 +64,7 @@ pub(crate) struct Dep { name: &'static str, } -pub(crate) type Map = RwLock; +pub(crate) type Map = SyncRwLock; pub(crate) type MapType = BTreeMap; pub(crate) type MapVal = (Weak, Weak); pub(crate) type MapKey = String; @@ -143,15 +145,12 @@ pub(crate) fn get(map: &Map, name: &str) -> Option> where T: Any + Send + Sync + Sized, { - map.read() - .expect("locked for reading") - .get(name) - .map(|(_, s)| { - s.upgrade().map(|s| { - s.downcast::() - .expect("Service must be correctly downcast.") - }) - })? + map.read().get(name).map(|(_, s)| { + s.upgrade().map(|s| { + s.downcast::() + .expect("Service must be correctly downcast.") + }) + })? } /// Reference a Service by name. Returns Err if the Service does not exist or @@ -160,21 +159,18 @@ pub(crate) fn try_get(map: &Map, name: &str) -> Result> where T: Any + Send + Sync + Sized, { - map.read() - .expect("locked for reading") - .get(name) - .map_or_else( - || Err!("Service {name:?} does not exist or has not been built yet."), - |(_, s)| { - s.upgrade().map_or_else( - || Err!("Service {name:?} no longer exists."), - |s| { - s.downcast::() - .map_err(|_| err!("Service {name:?} must be correctly downcast.")) - }, - ) - }, - ) + map.read().get(name).map_or_else( + || Err!("Service {name:?} does not exist or has not been built yet."), + |(_, s)| { + s.upgrade().map_or_else( + || Err!("Service {name:?} no longer exists."), + |s| { + s.downcast::() + .map_err(|_| err!("Service {name:?} must be correctly downcast.")) + }, + ) + }, + ) } /// Utility for service implementations; see Service::name() in the trait. 
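[Editor's note, not part of the patch series.] Patches 13 through 15 swap std::sync::{Mutex, RwLock} for the parking_lot equivalents, re-exported as SyncMutex/SyncRwLock from src/core/mod.rs. Below is a minimal sketch (assuming parking_lot 0.12, whose const Mutex::new permits the static initializer seen in rustc.rs) of why every `.expect("locked")` call disappears: parking_lot locks do not poison on panic, so lock() hands back the guard directly instead of a Result, and try_lock() yields an Option rather than a Result, which is why mutex_map.rs trades its WouldBlock match for ok_or_else.

	use std::collections::BTreeMap;

	use parking_lot::Mutex as SyncMutex; // same alias as src/core/mod.rs

	// Const-initialized static, mirroring info::rustc::FLAGS in the patch.
	static FLAGS: SyncMutex<BTreeMap<&'static str, &'static str>> =
		SyncMutex::new(BTreeMap::new());

	fn main() {
		// std::sync equivalent: FLAGS.lock().expect("locked").insert(..)
		FLAGS.lock().insert("example-crate", "--cfg example");

		// try_lock() yields Option<MutexGuard<_>>, not Result, hence the
		// `.ok_or_else(|| err!("would block"))?` rewrite in mutex_map.rs.
		if let Some(map) = FLAGS.try_lock() {
			assert_eq!(map.len(), 1);
		}
	}
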
diff --git a/src/service/services.rs b/src/service/services.rs index daece245..642f61c7 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,10 +1,8 @@ -use std::{ - any::Any, - collections::BTreeMap, - sync::{Arc, RwLock}, -}; +use std::{any::Any, collections::BTreeMap, sync::Arc}; -use conduwuit::{Result, Server, debug, debug_info, info, trace, utils::stream::IterStream}; +use conduwuit::{ + Result, Server, SyncRwLock, debug, debug_info, info, trace, utils::stream::IterStream, +}; use database::Database; use futures::{Stream, StreamExt, TryStreamExt}; use tokio::sync::Mutex; @@ -52,7 +50,7 @@ impl Services { #[allow(clippy::cognitive_complexity)] pub async fn build(server: Arc) -> Result> { let db = Database::open(&server).await?; - let service: Arc = Arc::new(RwLock::new(BTreeMap::new())); + let service: Arc = Arc::new(SyncRwLock::new(BTreeMap::new())); macro_rules! build { ($tyname:ty) => {{ let built = <$tyname>::build(Args { @@ -193,7 +191,7 @@ impl Services { fn interrupt(&self) { debug!("Interrupting services..."); - for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() { + for (name, (service, ..)) in self.service.read().iter() { if let Some(service) = service.upgrade() { trace!("Interrupting {name}"); service.interrupt(); @@ -205,7 +203,6 @@ impl Services { fn services(&self) -> impl Stream> + Send { self.service .read() - .expect("locked for reading") .values() .filter_map(|val| val.0.upgrade()) .collect::>() @@ -233,10 +230,9 @@ impl Services { #[allow(clippy::needless_pass_by_value)] fn add_service(map: &Arc, s: Arc, a: Arc) { let name = s.name(); - let len = map.read().expect("locked for reading").len(); + let len = map.read().len(); trace!("built service #{len}: {name:?}"); map.write() - .expect("locked for writing") .insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a))); } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 7735c87f..acd3dd86 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,10 +1,10 @@ use std::{ collections::{BTreeMap, HashSet}, - sync::{Arc, RwLock}, + sync::Arc, }; use conduwuit::{ - Err, Error, Result, err, error, implement, utils, + Err, Error, Result, SyncRwLock, err, error, implement, utils, utils::{hash, string::EMPTY}, }; use database::{Deserialized, Json, Map}; @@ -19,7 +19,7 @@ use ruma::{ use crate::{Dep, config, globals, users}; pub struct Service { - userdevicesessionid_uiaarequest: RwLock, + userdevicesessionid_uiaarequest: SyncRwLock, db: Data, services: Services, } @@ -42,7 +42,7 @@ pub const SESSION_ID_LENGTH: usize = 32; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), + userdevicesessionid_uiaarequest: SyncRwLock::new(RequestMap::new()), db: Data { userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), }, @@ -268,7 +268,6 @@ fn set_uiaa_request( let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); self.userdevicesessionid_uiaarequest .write() - .expect("locked for writing") .insert(key, request.to_owned()); } @@ -287,7 +286,6 @@ pub fn get_uiaa_request( self.userdevicesessionid_uiaarequest .read() - .expect("locked for reading") .get(&key) .cloned() } From 374fb2745c47e8b26e0b5daa435ace709f446498 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 22:05:43 +0100 Subject: [PATCH 15/31] refactor: Replace remaining std Mutexes --- src/database/engine/backup.rs | 2 +- 
src/database/engine/cf_opts.rs | 2 +- src/database/engine/context.rs | 15 +++--- src/database/engine/memory_usage.rs | 6 +-- src/database/engine/open.rs | 6 +-- src/database/pool.rs | 10 ++-- src/service/admin/console.rs | 56 ++++++++------------- src/service/rooms/auth_chain/data.rs | 24 +++------ src/service/rooms/auth_chain/mod.rs | 4 +- src/service/rooms/state_compressor/mod.rs | 14 +++--- src/service/sync/mod.rs | 61 +++++++++++------------ 11 files changed, 83 insertions(+), 117 deletions(-) diff --git a/src/database/engine/backup.rs b/src/database/engine/backup.rs index ac72e6d4..4cdb6172 100644 --- a/src/database/engine/backup.rs +++ b/src/database/engine/backup.rs @@ -71,7 +71,7 @@ pub fn backup_count(&self) -> Result { fn backup_engine(&self) -> Result { let path = self.backup_path()?; let options = BackupEngineOptions::new(path).map_err(map_err)?; - BackupEngine::open(&options, &*self.ctx.env.lock()?).map_err(map_err) + BackupEngine::open(&options, &self.ctx.env.lock()).map_err(map_err) } #[implement(Engine)] diff --git a/src/database/engine/cf_opts.rs b/src/database/engine/cf_opts.rs index cbbd1012..58358f02 100644 --- a/src/database/engine/cf_opts.rs +++ b/src/database/engine/cf_opts.rs @@ -232,7 +232,7 @@ fn get_cache(ctx: &Context, desc: &Descriptor) -> Option { cache_opts.set_num_shard_bits(shard_bits); cache_opts.set_capacity(size); - let mut caches = ctx.col_cache.lock().expect("locked"); + let mut caches = ctx.col_cache.lock(); match desc.cache_disp { | CacheDisp::Unique if desc.cache_size == 0 => None, | CacheDisp::Unique => { diff --git a/src/database/engine/context.rs b/src/database/engine/context.rs index 380e37af..3b9238bd 100644 --- a/src/database/engine/context.rs +++ b/src/database/engine/context.rs @@ -1,9 +1,6 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, Mutex}, -}; +use std::{collections::BTreeMap, sync::Arc}; -use conduwuit::{Result, Server, debug, utils::math::usize_from_f64}; +use conduwuit::{Result, Server, SyncMutex, debug, utils::math::usize_from_f64}; use rocksdb::{Cache, Env, LruCacheOptions}; use crate::{or_else, pool::Pool}; @@ -14,9 +11,9 @@ use crate::{or_else, pool::Pool}; /// These assets are housed in the shared Context. 
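/// The column caches, the row cache, and the RocksDB `Env` live behind
/// `SyncMutex` guards, so accessors lock them directly with no poisoning
/// to handle.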
pub(crate) struct Context { pub(crate) pool: Arc, - pub(crate) col_cache: Mutex>, - pub(crate) row_cache: Mutex, - pub(crate) env: Mutex, + pub(crate) col_cache: SyncMutex>, + pub(crate) row_cache: SyncMutex, + pub(crate) env: SyncMutex, pub(crate) server: Arc, } @@ -68,7 +65,7 @@ impl Drop for Context { debug!("Closing frontend pool"); self.pool.close(); - let mut env = self.env.lock().expect("locked"); + let mut env = self.env.lock(); debug!("Shutting down background threads"); env.set_high_priority_background_threads(0); diff --git a/src/database/engine/memory_usage.rs b/src/database/engine/memory_usage.rs index 9bb5c535..21af35c8 100644 --- a/src/database/engine/memory_usage.rs +++ b/src/database/engine/memory_usage.rs @@ -9,7 +9,7 @@ use crate::or_else; #[implement(Engine)] pub fn memory_usage(&self) -> Result { let mut res = String::new(); - let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()?])) + let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()])) .or_else(or_else)?; let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0; writeln!( @@ -19,10 +19,10 @@ pub fn memory_usage(&self) -> Result { mibs(stats.mem_table_total), mibs(stats.mem_table_unflushed), mibs(stats.mem_table_readers_total), - mibs(u64::try_from(self.ctx.row_cache.lock()?.get_usage())?), + mibs(u64::try_from(self.ctx.row_cache.lock().get_usage())?), )?; - for (name, cache) in &*self.ctx.col_cache.lock()? { + for (name, cache) in &*self.ctx.col_cache.lock() { writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?; } diff --git a/src/database/engine/open.rs b/src/database/engine/open.rs index 84e59a6a..7b9d93c2 100644 --- a/src/database/engine/open.rs +++ b/src/database/engine/open.rs @@ -23,11 +23,7 @@ pub(crate) async fn open(ctx: Arc, desc: &[Descriptor]) -> Result, queues: Vec>, - workers: Mutex>>, + workers: SyncMutex>>, topology: Vec, busy: AtomicUsize, queued_max: AtomicUsize, @@ -115,7 +115,7 @@ impl Drop for Pool { #[implement(Pool)] #[tracing::instrument(skip_all)] pub(crate) fn close(&self) { - let workers = take(&mut *self.workers.lock().expect("locked")); + let workers = take(&mut *self.workers.lock()); let senders = self.queues.iter().map(Sender::sender_count).sum::(); @@ -154,7 +154,7 @@ pub(crate) fn close(&self) { #[implement(Pool)] fn spawn_until(self: &Arc, recv: &[Receiver], count: usize) -> Result { - let mut workers = self.workers.lock().expect("locked"); + let mut workers = self.workers.lock(); while workers.len() < count { self.clone().spawn_one(&mut workers, recv)?; } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 02f41303..931bb719 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -1,11 +1,8 @@ #![cfg(feature = "console")] -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; +use std::{collections::VecDeque, sync::Arc}; -use conduwuit::{Server, debug, defer, error, log, log::is_systemd_mode}; +use conduwuit::{Server, SyncMutex, debug, defer, error, log, log::is_systemd_mode}; use futures::future::{AbortHandle, Abortable}; use ruma::events::room::message::RoomMessageEventContent; use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; @@ -17,10 +14,10 @@ use crate::{Dep, admin}; pub struct Console { server: Arc, admin: Dep, - worker_join: Mutex>>, - input_abort: Mutex>, - command_abort: Mutex>, - history: Mutex>, + worker_join: SyncMutex>>, + input_abort: SyncMutex>, + command_abort: 
SyncMutex>, + history: SyncMutex>, output: MadSkin, } @@ -50,7 +47,7 @@ impl Console { } pub async fn start(self: &Arc) { - let mut worker_join = self.worker_join.lock().expect("locked"); + let mut worker_join = self.worker_join.lock(); if worker_join.is_none() { let self_ = Arc::clone(self); _ = worker_join.insert(self.server.runtime().spawn(self_.worker())); @@ -60,7 +57,7 @@ impl Console { pub async fn close(self: &Arc) { self.interrupt(); - let Some(worker_join) = self.worker_join.lock().expect("locked").take() else { + let Some(worker_join) = self.worker_join.lock().take() else { return; }; @@ -70,22 +67,18 @@ impl Console { pub fn interrupt(self: &Arc) { self.interrupt_command(); self.interrupt_readline(); - self.worker_join - .lock() - .expect("locked") - .as_ref() - .map(JoinHandle::abort); + self.worker_join.lock().as_ref().map(JoinHandle::abort); } pub fn interrupt_readline(self: &Arc) { - if let Some(input_abort) = self.input_abort.lock().expect("locked").take() { + if let Some(input_abort) = self.input_abort.lock().take() { debug!("Interrupting console readline..."); input_abort.abort(); } } pub fn interrupt_command(self: &Arc) { - if let Some(command_abort) = self.command_abort.lock().expect("locked").take() { + if let Some(command_abort) = self.command_abort.lock().take() { debug!("Interrupting console command..."); command_abort.abort(); } @@ -120,7 +113,7 @@ impl Console { } debug!("session ending"); - self.worker_join.lock().expect("locked").take(); + self.worker_join.lock().take(); } async fn readline(self: &Arc) -> Result { @@ -135,9 +128,9 @@ impl Console { let (abort, abort_reg) = AbortHandle::new_pair(); let future = Abortable::new(future, abort_reg); - _ = self.input_abort.lock().expect("locked").insert(abort); + _ = self.input_abort.lock().insert(abort); defer! {{ - _ = self.input_abort.lock().expect("locked").take(); + _ = self.input_abort.lock().take(); }} let Ok(result) = future.await else { @@ -158,9 +151,9 @@ impl Console { let (abort, abort_reg) = AbortHandle::new_pair(); let future = Abortable::new(future, abort_reg); - _ = self.command_abort.lock().expect("locked").insert(abort); + _ = self.command_abort.lock().insert(abort); defer! 
{{ - _ = self.command_abort.lock().expect("locked").take(); + _ = self.command_abort.lock().take(); }} _ = future.await; @@ -184,20 +177,15 @@ impl Console { } fn set_history(&self, readline: &mut Readline) { - self.history - .lock() - .expect("locked") - .iter() - .rev() - .for_each(|entry| { - readline - .add_history_entry(entry.clone()) - .expect("added history entry"); - }); + self.history.lock().iter().rev().for_each(|entry| { + readline + .add_history_entry(entry.clone()) + .expect("added history entry"); + }); } fn add_history(&self, line: String) { - let mut history = self.history.lock().expect("locked"); + let mut history = self.history.lock(); history.push_front(line); history.truncate(HISTORY_LIMIT); } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 8c3588cc..e9e40979 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,9 +1,6 @@ -use std::{ - mem::size_of, - sync::{Arc, Mutex}, -}; +use std::{mem::size_of, sync::Arc}; -use conduwuit::{Err, Result, err, utils, utils::math::usize_from_f64}; +use conduwuit::{Err, Result, SyncMutex, err, utils, utils::math::usize_from_f64}; use database::Map; use lru_cache::LruCache; @@ -11,7 +8,7 @@ use crate::rooms::short::ShortEventId; pub(super) struct Data { shorteventid_authchain: Arc, - pub(super) auth_chain_cache: Mutex, Arc<[ShortEventId]>>>, + pub(super) auth_chain_cache: SyncMutex, Arc<[ShortEventId]>>>, } impl Data { @@ -23,7 +20,7 @@ impl Data { .expect("valid cache size"); Self { shorteventid_authchain: db["shorteventid_authchain"].clone(), - auth_chain_cache: Mutex::new(LruCache::new(cache_size)), + auth_chain_cache: SyncMutex::new(LruCache::new(cache_size)), } } @@ -34,12 +31,7 @@ impl Data { debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); // Check RAM cache - if let Some(result) = self - .auth_chain_cache - .lock() - .expect("cache locked") - .get_mut(key) - { + if let Some(result) = self.auth_chain_cache.lock().get_mut(key) { return Ok(Arc::clone(result)); } @@ -63,7 +55,6 @@ impl Data { // Cache in RAM self.auth_chain_cache .lock() - .expect("cache locked") .insert(vec![key[0]], Arc::clone(&chain)); Ok(chain) @@ -84,9 +75,6 @@ impl Data { } // Cache in RAM - self.auth_chain_cache - .lock() - .expect("cache locked") - .insert(key, auth_chain); + self.auth_chain_cache.lock().insert(key, auth_chain); } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 0903ea75..79d4d070 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -248,10 +248,10 @@ pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &[ShortEventId]) { #[implement(Service)] pub fn get_cache_usage(&self) -> (usize, usize) { - let cache = self.db.auth_chain_cache.lock().expect("locked"); + let cache = self.db.auth_chain_cache.lock(); (cache.len(), cache.capacity()) } #[implement(Service)] -pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); } +pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().clear(); } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index a33fb342..f7f7d043 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -2,12 +2,12 @@ use std::{ collections::{BTreeSet, HashMap}, fmt::{Debug, Write}, mem::size_of, - sync::{Arc, Mutex}, + sync::Arc, }; use async_trait::async_trait; use conduwuit::{ - Result, + Result, 
SyncMutex, arrayvec::ArrayVec, at, checked, err, expected, implement, utils, utils::{bytes, math::usize_from_f64, stream::IterStream}, @@ -23,7 +23,7 @@ use crate::{ }; pub struct Service { - pub stateinfo_cache: Mutex, + pub stateinfo_cache: SyncMutex, db: Data, services: Services, } @@ -86,7 +86,7 @@ impl crate::Service for Service { async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result { let (cache_len, ents) = { - let cache = self.stateinfo_cache.lock().expect("locked"); + let cache = self.stateinfo_cache.lock(); let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold( HashMap::new(), |mut ents, ssi| { @@ -110,7 +110,7 @@ impl crate::Service for Service { Ok(()) } - async fn clear_cache(&self) { self.stateinfo_cache.lock().expect("locked").clear(); } + async fn clear_cache(&self) { self.stateinfo_cache.lock().clear(); } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } @@ -123,7 +123,7 @@ pub async fn load_shortstatehash_info( &self, shortstatehash: ShortStateHash, ) -> Result { - if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) { + if let Some(r) = self.stateinfo_cache.lock().get_mut(&shortstatehash) { return Ok(r.clone()); } @@ -152,7 +152,7 @@ async fn cache_shortstatehash_info( shortstatehash: ShortStateHash, stack: ShortStateInfoVec, ) -> Result { - self.stateinfo_cache.lock()?.insert(shortstatehash, stack); + self.stateinfo_cache.lock().insert(shortstatehash, stack); Ok(()) } diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index b095d2c1..6ac579f4 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -2,10 +2,10 @@ mod watch; use std::{ collections::{BTreeMap, BTreeSet}, - sync::{Arc, Mutex, Mutex as StdMutex}, + sync::Arc, }; -use conduwuit::{Result, Server}; +use conduwuit::{Result, Server, SyncMutex}; use database::Map; use ruma::{ OwnedDeviceId, OwnedRoomId, OwnedUserId, @@ -62,11 +62,11 @@ struct SnakeSyncCache { extensions: v5::request::Extensions, } -type DbConnections = Mutex>; +type DbConnections = SyncMutex>; type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); -type DbConnectionsVal = Arc>; +type DbConnectionsVal = Arc>; type SnakeConnectionsKey = (OwnedUserId, OwnedDeviceId, Option); -type SnakeConnectionsVal = Arc>; +type SnakeConnectionsVal = Arc>; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -90,8 +90,8 @@ impl crate::Service for Service { state_cache: args.depend::("rooms::state_cache"), typing: args.depend::("rooms::typing"), }, - connections: StdMutex::new(BTreeMap::new()), - snake_connections: StdMutex::new(BTreeMap::new()), + connections: SyncMutex::new(BTreeMap::new()), + snake_connections: SyncMutex::new(BTreeMap::new()), })) } @@ -100,22 +100,19 @@ impl crate::Service for Service { impl Service { pub fn snake_connection_cached(&self, key: &SnakeConnectionsKey) -> bool { - self.snake_connections - .lock() - .expect("locked") - .contains_key(key) + self.snake_connections.lock().contains_key(key) } pub fn forget_snake_sync_connection(&self, key: &SnakeConnectionsKey) { - self.snake_connections.lock().expect("locked").remove(key); + self.snake_connections.lock().remove(key); } pub fn remembered(&self, key: &DbConnectionsKey) -> bool { - self.connections.lock().expect("locked").contains_key(key) + self.connections.lock().contains_key(key) } pub fn forget_sync_request_connection(&self, key: &DbConnectionsKey) { - self.connections.lock().expect("locked").remove(key); + self.connections.lock().remove(key); } pub fn 
update_snake_sync_request_with_cache( @@ -123,13 +120,13 @@ impl Service { snake_key: &SnakeConnectionsKey, request: &mut v5::Request, ) -> BTreeMap> { - let mut cache = self.snake_connections.lock().expect("locked"); + let mut cache = self.snake_connections.lock(); let cached = Arc::clone( cache .entry(snake_key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), + .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))), ); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); //v5::Request::try_from_http_request(req, path_args); @@ -232,16 +229,16 @@ impl Service { }; let key = into_db_key(key.0.clone(), key.1.clone(), conn_id); - let mut cache = self.connections.lock().expect("locked"); + let mut cache = self.connections.lock(); let cached = Arc::clone(cache.entry(key).or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { + Arc::new(SyncMutex::new(SlidingSyncCache { lists: BTreeMap::new(), subscriptions: BTreeMap::new(), known_rooms: BTreeMap::new(), extensions: ExtensionsConfig::default(), })) })); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); for (list_id, list) in &mut request.lists { @@ -328,16 +325,16 @@ impl Service { key: &DbConnectionsKey, subscriptions: BTreeMap, ) { - let mut cache = self.connections.lock().expect("locked"); + let mut cache = self.connections.lock(); let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { + Arc::new(SyncMutex::new(SlidingSyncCache { lists: BTreeMap::new(), subscriptions: BTreeMap::new(), known_rooms: BTreeMap::new(), extensions: ExtensionsConfig::default(), })) })); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); cached.subscriptions = subscriptions; @@ -350,16 +347,16 @@ impl Service { new_cached_rooms: BTreeSet, globalsince: u64, ) { - let mut cache = self.connections.lock().expect("locked"); + let mut cache = self.connections.lock(); let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { + Arc::new(SyncMutex::new(SlidingSyncCache { lists: BTreeMap::new(), subscriptions: BTreeMap::new(), known_rooms: BTreeMap::new(), extensions: ExtensionsConfig::default(), })) })); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); for (room_id, lastsince) in cached @@ -386,13 +383,13 @@ impl Service { globalsince: u64, ) { assert!(key.2.is_some(), "Some(conn_id) required for this call"); - let mut cache = self.snake_connections.lock().expect("locked"); + let mut cache = self.snake_connections.lock(); let cached = Arc::clone( cache .entry(key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), + .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))), ); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); for (room_id, lastsince) in cached @@ -416,13 +413,13 @@ impl Service { key: &SnakeConnectionsKey, subscriptions: BTreeMap, ) { - let mut cache = self.snake_connections.lock().expect("locked"); + let mut cache = self.snake_connections.lock(); let cached = Arc::clone( cache .entry(key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), + .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))), ); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut 
cached.lock(); drop(cache); cached.subscriptions = subscriptions; From 6d29098d1af955d98ec57da4593836e67e8df090 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 22:20:26 +0100 Subject: [PATCH 16/31] refactor: Replace remaining std RwLocks --- src/admin/federation/commands.rs | 3 +-- src/admin/mod.rs | 12 ++------- src/core/alloc/je.rs | 7 +++--- src/service/admin/mod.rs | 25 +++++++------------ src/service/globals/mod.rs | 22 +++++----------- .../fetch_and_handle_outliers.rs | 3 --- .../event_handler/handle_incoming_pdu.rs | 3 --- .../rooms/event_handler/handle_prev_pdu.rs | 3 --- src/service/rooms/event_handler/mod.rs | 17 +++---------- src/service/rooms/state_cache/mod.rs | 22 +++++----------- src/service/rooms/state_cache/update.rs | 5 +--- 11 files changed, 32 insertions(+), 90 deletions(-) diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 545dcbca..f77dadab 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -26,8 +26,7 @@ pub(super) async fn incoming_federation(&self) -> Result { .rooms .event_handler .federation_handletime - .read() - .expect("locked"); + .read(); let mut msg = format!("Handling {} incoming pdus:\n", map.len()); for (r, (e, i)) in map.iter() { diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 732b8ce0..1d46590b 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -37,11 +37,7 @@ pub use crate::admin::AdminCommand; /// Install the admin command processor pub async fn init(admin_service: &service::admin::Service) { - _ = admin_service - .complete - .write() - .expect("locked for writing") - .insert(processor::complete); + _ = admin_service.complete.write().insert(processor::complete); _ = admin_service .handle .write() @@ -52,9 +48,5 @@ pub async fn init(admin_service: &service::admin::Service) { /// Uninstall the admin command handler pub async fn fini(admin_service: &service::admin::Service) { _ = admin_service.handle.write().await.take(); - _ = admin_service - .complete - .write() - .expect("locked for writing") - .take(); + _ = admin_service.complete.write().take(); } diff --git a/src/core/alloc/je.rs b/src/core/alloc/je.rs index e138233e..77deebc5 100644 --- a/src/core/alloc/je.rs +++ b/src/core/alloc/je.rs @@ -4,7 +4,6 @@ use std::{ cell::OnceCell, ffi::{CStr, c_char, c_void}, fmt::Debug, - sync::RwLock, }; use arrayvec::ArrayVec; @@ -13,7 +12,7 @@ use tikv_jemalloc_sys as ffi; use tikv_jemallocator as jemalloc; use crate::{ - Result, err, is_equal_to, is_nonzero, + Result, SyncRwLock, err, is_equal_to, is_nonzero, utils::{math, math::Tried}, }; @@ -40,7 +39,7 @@ const MALLOC_CONF_PROF: &str = ""; #[global_allocator] static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc; -static CONTROL: RwLock<()> = RwLock::new(()); +static CONTROL: SyncRwLock<()> = SyncRwLock::new(()); type Name = ArrayVec; type Key = ArrayVec; @@ -332,7 +331,7 @@ fn set(key: &Key, val: T) -> Result where T: Copy + Debug, { - let _lock = CONTROL.write()?; + let _lock = CONTROL.write(); let res = xchg(key, val)?; inc_epoch()?; diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index f496c414..c052198c 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -5,11 +5,11 @@ mod grant; use std::{ pin::Pin, - sync::{Arc, RwLock as StdRwLock, Weak}, + sync::{Arc, Weak}, }; use async_trait::async_trait; -use conduwuit::{Err, utils}; +use conduwuit::{Err, SyncRwLock, utils}; use conduwuit_core::{ Error, Event, Result, Server, debug, err, error, error::default_log, 
pdu::PduBuilder, }; @@ -36,7 +36,7 @@ pub struct Service { services: Services, channel: (Sender, Receiver), pub handle: RwLock>, - pub complete: StdRwLock>, + pub complete: SyncRwLock>, #[cfg(feature = "console")] pub console: Arc, } @@ -50,7 +50,7 @@ struct Services { state_cache: Dep, state_accessor: Dep, account_data: Dep, - services: StdRwLock>>, + services: SyncRwLock>>, media: Dep, } @@ -105,7 +105,7 @@ impl crate::Service for Service { }, channel: loole::bounded(COMMAND_QUEUE_LIMIT), handle: RwLock::new(None), - complete: StdRwLock::new(None), + complete: SyncRwLock::new(None), #[cfg(feature = "console")] console: console::Console::new(&args), })) @@ -312,10 +312,7 @@ impl Service { /// Invokes the tab-completer to complete the command. When unavailable, /// None is returned. pub fn complete_command(&self, command: &str) -> Option { - self.complete - .read() - .expect("locked for reading") - .map(|complete| complete(command)) + self.complete.read().map(|complete| complete(command)) } async fn handle_signal(&self, sig: &'static str) { @@ -338,17 +335,13 @@ impl Service { } async fn process_command(&self, command: CommandInput) -> ProcessorResult { - let handle = &self - .handle - .read() - .await - .expect("Admin module is not loaded"); + let handle_guard = self.handle.read().await; + let handle = handle_guard.as_ref().expect("Admin module is not loaded"); let services = self .services .services .read() - .expect("locked") .as_ref() .and_then(Weak::upgrade) .expect("Services self-reference not initialized."); @@ -523,7 +516,7 @@ impl Service { /// Sets the self-reference to crate::Services which will provide context to /// the admin commands. pub(super) fn set_services(&self, services: Option<&Arc>) { - let receiver = &mut *self.services.services.write().expect("locked for writing"); + let receiver = &mut *self.services.services.write(); let weak = services.map(Arc::downgrade); *receiver = weak; } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index a23a4c21..12f2ec78 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,14 +1,9 @@ mod data; -use std::{ - collections::HashMap, - fmt::Write, - sync::{Arc, RwLock}, - time::Instant, -}; +use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant}; use async_trait::async_trait; -use conduwuit::{Result, Server, error, utils::bytes::pretty}; +use conduwuit::{Result, Server, SyncRwLock, error, utils::bytes::pretty}; use data::Data; use regex::RegexSet; use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, ServerName, UserId}; @@ -19,7 +14,7 @@ pub struct Service { pub db: Data, server: Arc, - pub bad_event_ratelimiter: Arc>>, + pub bad_event_ratelimiter: Arc>>, pub server_user: OwnedUserId, pub admin_alias: OwnedRoomAliasId, pub turn_secret: String, @@ -62,7 +57,7 @@ impl crate::Service for Service { Ok(Arc::new(Self { db, server: args.server.clone(), - bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + bad_event_ratelimiter: Arc::new(SyncRwLock::new(HashMap::new())), admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &args.server.name)) .expect("#admins:server_name is valid alias name"), server_user: UserId::parse_with_server_name( @@ -76,7 +71,7 @@ impl crate::Service for Service { } async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result { - let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read()?.iter().fold( + let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read().iter().fold( (0_usize, 0_usize), |(mut count, 
mut bytes), (event_id, _)| { bytes = bytes.saturating_add(event_id.capacity()); @@ -91,12 +86,7 @@ impl crate::Service for Service { Ok(()) } - async fn clear_cache(&self) { - self.bad_event_ratelimiter - .write() - .expect("locked for writing") - .clear(); - } + async fn clear_cache(&self) { self.bad_event_ratelimiter.write().clear(); } fn name(&self) -> &str { service::make_name(std::module_path!()) } } diff --git a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs index 44027e04..59b768f2 100644 --- a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs +++ b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs @@ -41,7 +41,6 @@ where .globals .bad_event_ratelimiter .write() - .expect("locked") .entry(id) { | hash_map::Entry::Vacant(e) => { @@ -76,7 +75,6 @@ where .globals .bad_event_ratelimiter .read() - .expect("locked") .get(&*next_id) { // Exponential backoff @@ -187,7 +185,6 @@ where .globals .bad_event_ratelimiter .read() - .expect("locked") .get(&*next_id) { // Exponential backoff diff --git a/src/service/rooms/event_handler/handle_incoming_pdu.rs b/src/service/rooms/event_handler/handle_incoming_pdu.rs index 86a05e0a..5299e8d4 100644 --- a/src/service/rooms/event_handler/handle_incoming_pdu.rs +++ b/src/service/rooms/event_handler/handle_incoming_pdu.rs @@ -160,7 +160,6 @@ pub async fn handle_incoming_pdu<'a>( .globals .bad_event_ratelimiter .write() - .expect("locked") .entry(prev_id.into()) { | hash_map::Entry::Vacant(e) => { @@ -181,13 +180,11 @@ pub async fn handle_incoming_pdu<'a>( let start_time = Instant::now(); self.federation_handletime .write() - .expect("locked") .insert(room_id.into(), (event_id.to_owned(), start_time)); defer! {{ self.federation_handletime .write() - .expect("locked") .remove(room_id); }}; diff --git a/src/service/rooms/event_handler/handle_prev_pdu.rs b/src/service/rooms/event_handler/handle_prev_pdu.rs index cd46310a..cb4978d9 100644 --- a/src/service/rooms/event_handler/handle_prev_pdu.rs +++ b/src/service/rooms/event_handler/handle_prev_pdu.rs @@ -42,7 +42,6 @@ where .globals .bad_event_ratelimiter .read() - .expect("locked") .get(prev_id) { // Exponential backoff @@ -70,13 +69,11 @@ where let start_time = Instant::now(); self.federation_handletime .write() - .expect("locked") .insert(room_id.into(), ((*prev_id).to_owned(), start_time)); defer! 
{{ self.federation_handletime .write() - .expect("locked") .remove(room_id); }}; diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index aed38e1e..4e59c207 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -10,15 +10,10 @@ mod resolve_state; mod state_at_incoming; mod upgrade_outlier_pdu; -use std::{ - collections::HashMap, - fmt::Write, - sync::{Arc, RwLock as StdRwLock}, - time::Instant, -}; +use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant}; use async_trait::async_trait; -use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, utils::MutexMap}; +use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, SyncRwLock, utils::MutexMap}; use ruma::{ OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, events::room::create::RoomCreateEventContent, @@ -28,7 +23,7 @@ use crate::{Dep, globals, rooms, sending, server_keys}; pub struct Service { pub mutex_federation: RoomMutexMap, - pub federation_handletime: StdRwLock, + pub federation_handletime: SyncRwLock, services: Services, } @@ -81,11 +76,7 @@ impl crate::Service for Service { let mutex_federation = self.mutex_federation.len(); writeln!(out, "federation_mutex: {mutex_federation}")?; - let federation_handletime = self - .federation_handletime - .read() - .expect("locked for reading") - .len(); + let federation_handletime = self.federation_handletime.read().len(); writeln!(out, "federation_handletime: {federation_handletime}")?; Ok(()) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 9429be79..e9845fbf 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,13 +1,10 @@ mod update; mod via; -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; +use std::{collections::HashMap, sync::Arc}; use conduwuit::{ - Result, implement, + Result, SyncRwLock, implement, result::LogErr, utils::{ReadyExt, stream::TryIgnore}, warn, @@ -54,14 +51,14 @@ struct Data { userroomid_knockedstate: Arc, } -type AppServiceInRoomCache = RwLock>>; +type AppServiceInRoomCache = SyncRwLock>>; type StrippedStateEventItem = (OwnedRoomId, Vec>); type SyncStateEventItem = (OwnedRoomId, Vec>); impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - appservice_in_room_cache: RwLock::new(HashMap::new()), + appservice_in_room_cache: SyncRwLock::new(HashMap::new()), services: Services { account_data: args.depend::("account_data"), config: args.depend::("config"), @@ -99,7 +96,6 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati if let Some(cached) = self .appservice_in_room_cache .read() - .expect("locked") .get(room_id) .and_then(|map| map.get(&appservice.registration.id)) .copied() @@ -124,7 +120,6 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati self.appservice_in_room_cache .write() - .expect("locked") .entry(room_id.into()) .or_default() .insert(appservice.registration.id.clone(), in_room); @@ -134,19 +129,14 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati #[implement(Service)] pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) { - let cache = self.appservice_in_room_cache.read().expect("locked"); + let cache = self.appservice_in_room_cache.read(); (cache.len(), cache.capacity()) } #[implement(Service)] #[tracing::instrument(level = "debug", skip_all)] -pub fn 
clear_appservice_in_room_cache(&self) { - self.appservice_in_room_cache - .write() - .expect("locked") - .clear(); -} +pub fn clear_appservice_in_room_cache(&self) { self.appservice_in_room_cache.write().clear(); } /// Returns an iterator of all servers participating in this room. #[implement(Service)] diff --git a/src/service/rooms/state_cache/update.rs b/src/service/rooms/state_cache/update.rs index 02c6bec6..32c67947 100644 --- a/src/service/rooms/state_cache/update.rs +++ b/src/service/rooms/state_cache/update.rs @@ -211,10 +211,7 @@ pub async fn update_joined_count(&self, room_id: &RoomId) { self.db.serverroomids.put_raw(serverroom_id, []); } - self.appservice_in_room_cache - .write() - .expect("locked") - .remove(room_id); + self.appservice_in_room_cache.write().remove(room_id); } /// Direct DB function to directly mark a user as joined. It is not From b635e825d2ba002170ff1c25a26270f875699ca5 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 22:30:41 +0100 Subject: [PATCH 17/31] refactor: Implement with_lock for lock_api --- Cargo.lock | 1 + Cargo.toml | 4 ++-- src/core/Cargo.toml | 1 + src/core/utils/with_lock.rs | 26 +++++++++++++++++++++++++- 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b084f72a..ed9be6d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -963,6 +963,7 @@ dependencies = [ "itertools 0.14.0", "libc", "libloading", + "lock_api", "log", "maplit", "nix", diff --git a/Cargo.toml b/Cargo.toml index 3e52c4b2..54f7ae82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -519,8 +519,8 @@ version = "1.0" version = "0.12.4" # Use this when extending with_lock::WithLock to parking_lot -# [workspace.dependencies.lock_api] -# version = "0.4.13" +[workspace.dependencies.lock_api] +version = "0.4.13" [workspace.dependencies.bytesize] version = "2.0" diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 7a3721d6..462b8e54 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -111,6 +111,7 @@ tracing-subscriber.workspace = true tracing.workspace = true url.workspace = true parking_lot.workspace = true +lock_api.workspace = true [target.'cfg(unix)'.dependencies] nix.workspace = true diff --git a/src/core/utils/with_lock.rs b/src/core/utils/with_lock.rs index 76f014d1..914749de 100644 --- a/src/core/utils/with_lock.rs +++ b/src/core/utils/with_lock.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex}; -pub trait WithLock { +pub trait WithLock { /// Acquires a lock and executes the given closure with the locked data. fn with_lock(&self, f: F) where @@ -33,6 +33,30 @@ impl WithLock for Arc> { } } +impl WithLock for lock_api::Mutex { + fn with_lock(&self, mut f: F) + where + F: FnMut(&mut T), + { + // The locking and unlocking logic is hidden inside this function. + let mut data_guard = self.lock(); + f(&mut data_guard); + // Lock is released here when `data_guard` goes out of scope. + } +} + +impl WithLock for Arc> { + fn with_lock(&self, mut f: F) + where + F: FnMut(&mut T), + { + // The locking and unlocking logic is hidden inside this function. + let mut data_guard = self.lock(); + f(&mut data_guard); + // Lock is released here when `data_guard` goes out of scope. + } +} + pub trait WithLockAsync { /// Acquires a lock and executes the given closure with the locked data. 
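	/// The future resolves once the closure has run; the guard is dropped
	/// before the future completes, keeping the lock's scope explicit at
	/// the call site.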
fn with_lock(&self, f: F) -> impl Future From 1c985c59f57579b0be1437e711bc939954da99c1 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 23:30:31 +0100 Subject: [PATCH 18/31] refactor: Allow with_lock to return data and take an async closure --- src/core/utils/with_lock.rs | 173 ++++++++++++++++++++++++++++++------ 1 file changed, 148 insertions(+), 25 deletions(-) diff --git a/src/core/utils/with_lock.rs b/src/core/utils/with_lock.rs index 914749de..91e8e8d1 100644 --- a/src/core/utils/with_lock.rs +++ b/src/core/utils/with_lock.rs @@ -1,89 +1,212 @@ //! Traits for explicitly scoping the lifetime of locks. -use std::sync::{Arc, Mutex}; +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; pub trait WithLock { - /// Acquires a lock and executes the given closure with the locked data. - fn with_lock(&self, f: F) + /// Acquires a lock and executes the given closure with the locked data, + /// returning the result. + fn with_lock(&self, f: F) -> R where - F: FnMut(&mut T); + F: FnMut(&mut T) -> R; } impl WithLock for Mutex { - fn with_lock(&self, mut f: F) + fn with_lock(&self, mut f: F) -> R where - F: FnMut(&mut T), + F: FnMut(&mut T) -> R, { // The locking and unlocking logic is hidden inside this function. let mut data_guard = self.lock().unwrap(); - f(&mut data_guard); + f(&mut data_guard) // Lock is released here when `data_guard` goes out of scope. } } impl WithLock for Arc> { - fn with_lock(&self, mut f: F) + fn with_lock(&self, mut f: F) -> R where - F: FnMut(&mut T), + F: FnMut(&mut T) -> R, { // The locking and unlocking logic is hidden inside this function. let mut data_guard = self.lock().unwrap(); - f(&mut data_guard); + f(&mut data_guard) // Lock is released here when `data_guard` goes out of scope. } } impl WithLock for lock_api::Mutex { - fn with_lock(&self, mut f: F) + fn with_lock(&self, mut f: F) -> Ret where - F: FnMut(&mut T), + F: FnMut(&mut T) -> Ret, { // The locking and unlocking logic is hidden inside this function. let mut data_guard = self.lock(); - f(&mut data_guard); + f(&mut data_guard) // Lock is released here when `data_guard` goes out of scope. } } impl WithLock for Arc> { - fn with_lock(&self, mut f: F) + fn with_lock(&self, mut f: F) -> Ret where - F: FnMut(&mut T), + F: FnMut(&mut T) -> Ret, { // The locking and unlocking logic is hidden inside this function. let mut data_guard = self.lock(); - f(&mut data_guard); + f(&mut data_guard) // Lock is released here when `data_guard` goes out of scope. } } pub trait WithLockAsync { - /// Acquires a lock and executes the given closure with the locked data. - fn with_lock(&self, f: F) -> impl Future + /// Acquires a lock and executes the given closure with the locked data, + /// returning the result. + fn with_lock(&self, f: F) -> impl Future where - F: FnMut(&mut T); + F: FnMut(&mut T) -> R; + + /// Acquires a lock and executes the given async closure with the locked + /// data. + fn with_lock_async(&self, f: F) -> impl std::future::Future + where + F: AsyncFnMut(&mut T) -> R; } impl WithLockAsync for futures::lock::Mutex { - async fn with_lock(&self, mut f: F) + async fn with_lock(&self, mut f: F) -> R where - F: FnMut(&mut T), + F: FnMut(&mut T) -> R, { // The locking and unlocking logic is hidden inside this function. let mut data_guard = self.lock().await; - f(&mut data_guard); + f(&mut data_guard) + // Lock is released here when `data_guard` goes out of scope. 
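+		// (Note: `futures::lock::Mutex` does not poison; a panic in `f`
+		// simply releases the lock for the next waiter.)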
+ } + + async fn with_lock_async(&self, mut f: F) -> R + where + F: AsyncFnMut(&mut T) -> R, + { + // The locking and unlocking logic is hidden inside this function. + let mut data_guard = self.lock().await; + f(&mut data_guard).await // Lock is released here when `data_guard` goes out of scope. } } impl WithLockAsync for Arc> { - async fn with_lock(&self, mut f: F) + async fn with_lock(&self, mut f: F) -> R where - F: FnMut(&mut T), + F: FnMut(&mut T) -> R, { // The locking and unlocking logic is hidden inside this function. let mut data_guard = self.lock().await; - f(&mut data_guard); + f(&mut data_guard) + // Lock is released here when `data_guard` goes out of scope. + } + + async fn with_lock_async(&self, mut f: F) -> R + where + F: AsyncFnMut(&mut T) -> R, + { + // The locking and unlocking logic is hidden inside this function. + let mut data_guard = self.lock().await; + f(&mut data_guard).await // Lock is released here when `data_guard` goes out of scope. } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_with_lock_return_value() { + let mutex = Mutex::new(5); + let result = mutex.with_lock(|v| { + *v += 1; + *v * 2 + }); + assert_eq!(result, 12); + let value = mutex.lock().unwrap(); + assert_eq!(*value, 6); + } + + #[test] + fn test_with_lock_unit_return() { + let mutex = Mutex::new(10); + mutex.with_lock(|v| { + *v += 2; + }); + let value = mutex.lock().unwrap(); + assert_eq!(*value, 12); + } + + #[test] + fn test_with_lock_arc_mutex() { + let mutex = Arc::new(Mutex::new(1)); + let result = mutex.with_lock(|v| { + *v *= 10; + *v + }); + assert_eq!(result, 10); + assert_eq!(*mutex.lock().unwrap(), 10); + } + + #[tokio::test] + async fn test_with_lock_async_return_value() { + use futures::lock::Mutex as AsyncMutex; + let mutex = AsyncMutex::new(7); + let result = mutex + .with_lock(|v| { + *v += 3; + *v * 2 + }) + .await; + assert_eq!(result, 20); + let value = mutex.lock().await; + assert_eq!(*value, 10); + } + + #[tokio::test] + async fn test_with_lock_async_unit_return() { + use futures::lock::Mutex as AsyncMutex; + let mutex = AsyncMutex::new(100); + mutex + .with_lock(|v| { + *v -= 50; + }) + .await; + let value = mutex.lock().await; + assert_eq!(*value, 50); + } + + #[tokio::test] + async fn test_with_lock_async_closure() { + use futures::lock::Mutex as AsyncMutex; + let mutex = AsyncMutex::new(1); + mutex + .with_lock_async(async |v| { + *v += 9; + }) + .await; + let value = mutex.lock().await; + assert_eq!(*value, 10); + } + + #[tokio::test] + async fn test_with_lock_async_arc_mutex() { + use futures::lock::Mutex as AsyncMutex; + let mutex = Arc::new(AsyncMutex::new(2)); + mutex + .with_lock_async(async |v: &mut i32| { + *v *= 5; + }) + .await; + let value = mutex.lock().await; + assert_eq!(*value, 10); + } +} From f593cac58aa48a7be73a87d5f53a0b7a5e41dcd8 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 23:32:18 +0100 Subject: [PATCH 19/31] feat: Enable hardware-lock-elision and deadlock_detection --- Cargo.lock | 29 +++++++++++++++++++++++++++++ Cargo.toml | 1 + 2 files changed, 30 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index ed9be6d9..5dce9c59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1659,6 +1659,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.1.2" @@ -3220,10 +3226,13 @@ version = "0.9.11" source 
= "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ + "backtrace", "cfg-if", "libc", + "petgraph", "redox_syscall", "smallvec", + "thread-id", "windows-targets 0.52.6", ] @@ -3273,6 +3282,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.9.0", +] + [[package]] name = "phf" version = "0.11.3" @@ -4894,6 +4913,16 @@ dependencies = [ "syn", ] +[[package]] +name = "thread-id" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe8f25bbdd100db7e1d34acf7fd2dc59c4bf8f7483f505eaa7d4f12f76cc0ea" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "thread_local" version = "1.1.9" diff --git a/Cargo.toml b/Cargo.toml index 54f7ae82..ab6a9e8a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -517,6 +517,7 @@ version = "1.0" [workspace.dependencies.parking_lot] version = "0.12.4" +features = ["hardware-lock-elision", "deadlock_detection"] # TODO: Check if deadlock_detection has a perf impact, if it does only enable with debug_assertions # Use this when extending with_lock::WithLock to parking_lot [workspace.dependencies.lock_api] From 95610499c7df2d6d1ebe16650431673369975a54 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sat, 19 Jul 2025 23:32:53 +0100 Subject: [PATCH 20/31] chore: Disable direnv's nix flake interfering with cargo cache --- .envrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.envrc b/.envrc index 952ec2f8..bad73b75 100644 --- a/.envrc +++ b/.envrc @@ -2,6 +2,6 @@ dotenv_if_exists -use flake ".#${DIRENV_DEVSHELL:-default}" +# use flake ".#${DIRENV_DEVSHELL:-default}" PATH_add bin From d74514f305f49637115b07a56b73f2520f97a3fe Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sun, 20 Jul 2025 20:58:58 +0100 Subject: [PATCH 21/31] ci: Fix inverted latest tag --- .forgejo/workflows/release-image.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.forgejo/workflows/release-image.yml b/.forgejo/workflows/release-image.yml index 5ac5ddfa..04fc9de9 100644 --- a/.forgejo/workflows/release-image.yml +++ b/.forgejo/workflows/release-image.yml @@ -262,7 +262,7 @@ jobs: type=ref,event=branch,prefix=${{ format('refs/heads/{0}', github.event.repository.default_branch) != github.ref && 'branch-' || '' }} type=ref,event=pr type=sha,format=long - type=raw,value=latest,enable=${{ !startsWith(github.ref, 'refs/tags/v') }} + type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/v') }} images: ${{needs.define-variables.outputs.images}} # default labels & annotations: https://github.com/docker/metadata-action/blob/master/src/meta.ts#L509 env: From 195f0a7bba51d99640fb0418d7fa78ceb077f964 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:22:29 +0100 Subject: [PATCH 22/31] feat(policy-server): Policy server following --- src/core/matrix/state_res/event_auth.rs | 7 +- .../rooms/event_handler/call_policyserv.rs | 71 +++++++++++++++++++ src/service/rooms/event_handler/mod.rs | 1 + .../event_handler/upgrade_outlier_pdu.rs | 34 ++++++--- 4 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 
src/service/rooms/event_handler/call_policyserv.rs diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 5c36ce03..819d05e2 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -5,7 +5,7 @@ use futures::{ future::{OptionFuture, join3}, }; use ruma::{ - Int, OwnedUserId, RoomVersionId, UserId, + EventId, Int, OwnedUserId, RoomVersionId, UserId, events::room::{ create::RoomCreateEventContent, join_rules::{JoinRule, RoomJoinRulesEventContent}, @@ -217,8 +217,9 @@ where } /* - // TODO: In the past this code caused problems federating with synapse, maybe this has been - // resolved already. Needs testing. + // TODO: In the past this code was commented as it caused problems with Synapse. This is no + // longer the case. This needs to be implemented. + // See also: https://github.com/ruma/ruma/pull/2064 // // 2. Reject if auth_events // a. auth_events cannot have duplicate keys since it's a BTree diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs new file mode 100644 index 00000000..4a52227d --- /dev/null +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -0,0 +1,71 @@ +use conduwuit::{ + Err, Event, PduEvent, Result, debug, implement, utils::to_canonical_object, warn, +}; +use ruma::{ + RoomId, ServerName, + api::federation::room::policy::v1::Request as PolicyRequest, + canonical_json::to_canonical_value, + events::{StateEventType, room::policy::RoomPolicyEventContent}, +}; + +/// Returns Ok if the policy server allows the event +#[implement(super::Service)] +#[tracing::instrument(skip_all, level = "debug")] +pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { + let Ok(policyserver) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomPolicy, "") + .await + .map(|c: RoomPolicyEventContent| c) + else { + return Ok(()); + }; + + let via = match policyserver.via { + | Some(ref via) => ServerName::parse(via)?, + | None => { + debug!("No policy server configured for room {room_id}"); + return Ok(()); + }, + }; + // TODO: dont do *this* + let pdu_json = self.services.timeline.get_pdu_json(pdu.event_id()).await?; + let outgoing = self + .services + .sending + .convert_to_outgoing_federation_event(pdu_json) + .await; + // let s = match serde_json::to_string(outgoing.as_ref()) { + // | Ok(s) => s, + // | Err(e) => { + // warn!("Failed to convert pdu {} to outgoing federation event: {e}", + // pdu.event_id()); return Err!(Request(InvalidParam("Failed to convert PDU + // to outgoing event."))); }, + // }; + debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); + let response = self + .services + .sending + .send_federation_request(via, PolicyRequest { + event_id: pdu.event_id().to_owned(), + pdu: Some(outgoing), + }) + .await; + let response = match response { + | Ok(response) => response, + | Err(e) => { + warn!("Failed to contact policy server {via} for room {room_id}: {e}"); + return Ok(()); + }, + }; + if response.recommendation == "spam" { + warn!( + "Event {} in room {room_id} was marked as spam by policy server {via}", + pdu.event_id().to_owned() + ); + return Err!(Request(Forbidden("Event was marked as spam by policy server"))); + }; + + Ok(()) +} diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 4e59c207..ba5ad7e2 100644 --- a/src/service/rooms/event_handler/mod.rs +++ 
b/src/service/rooms/event_handler/mod.rs @@ -1,4 +1,5 @@ mod acl_check; +mod call_policyserv; mod fetch_and_handle_outliers; mod fetch_prev; mod fetch_state; diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index 4093cb05..abb5c116 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant}; use conduwuit::{ - Err, Result, debug, debug_info, err, implement, is_equal_to, + Err, Result, debug, debug_info, err, implement, info, is_equal_to, matrix::{Event, EventTypeExt, PduEvent, StateKey, state_res}, trace, utils::stream::{BroadbandExt, ReadyExt}, @@ -47,7 +47,7 @@ where return Err!(Request(InvalidParam("Event has been soft failed"))); } - debug!("Upgrading to timeline pdu"); + debug!("Upgrading pdu {} from outlier to timeline pdu", incoming_pdu.event_id); let timer = Instant::now(); let room_version_id = get_room_version_id(create_event)?; @@ -55,7 +55,7 @@ where // backwards extremities doing all the checks in this list starting at 1. // These are not timeline events. - debug!("Resolving state at event"); + debug!("Resolving state at event {}", incoming_pdu.event_id); let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 { self.state_at_incoming_degree_one(&incoming_pdu).await? } else { @@ -74,7 +74,7 @@ where let room_version = to_room_version(&room_version_id); - debug!("Performing auth check"); + debug!("Performing auth check to upgrade {}", incoming_pdu.event_id); // 11. Check the auth of the event passes based on the state of the event let state_fetch_state = &state_at_incoming_event; let state_fetch = |k: StateEventType, s: StateKey| async move { @@ -84,6 +84,7 @@ where self.services.timeline.get_pdu(event_id).await.ok() }; + debug!("running auth check on {}", incoming_pdu.event_id); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -97,7 +98,7 @@ where return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } - debug!("Gathering auth events"); + debug!("Gathering auth events for {}", incoming_pdu.event_id); let auth_events = self .services .state @@ -115,6 +116,7 @@ where ready(auth_events.get(&key).map(ToOwned::to_owned)) }; + debug!("running auth check on {} with claimed state auth", incoming_pdu.event_id); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -125,8 +127,8 @@ where .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res - debug!("Performing soft-fail check"); - let soft_fail = match (auth_check, incoming_pdu.redacts_id(&room_version_id)) { + debug!("Performing soft-fail check on {}", incoming_pdu.event_id); + let mut soft_fail = match (auth_check, incoming_pdu.redacts_id(&room_version_id)) { | (false, _) => true, | (true, None) => false, | (true, Some(redact_id)) => @@ -219,10 +221,26 @@ where .await?; } + // 14-pre. 
If the event is not a state event, ask the policy server about it + if incoming_pdu.state_key.is_none() + && incoming_pdu.sender().server_name() != self.services.globals.server_name() + { + debug!("Checking policy server for event {}", incoming_pdu.event_id); + let policy = self.policyserv_check(&incoming_pdu, room_id); + if let Err(e) = policy.await { + warn!("Policy server check failed for event {}: {e}", incoming_pdu.event_id); + if !soft_fail { + soft_fail = true; + } + } + debug!("Policy server check passed for event {}", incoming_pdu.event_id); + } + // 14. Check if the event passes auth based on the "current state" of the room, // if not soft fail it if soft_fail { - debug!("Soft failing event"); + info!("Soft failing event {}", incoming_pdu.event_id); + // assert!(extremities.is_empty(), "soft_fail extremities empty"); let extremities = extremities.iter().map(Borrow::borrow); self.services From 80cf98210442a4cdd107b61f2a01540c52ec94d3 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:34:34 +0100 Subject: [PATCH 23/31] chore: Update ruwuma & fix lints --- Cargo.lock | 22 +++++++++---------- Cargo.toml | 2 +- src/core/matrix/state_res/event_auth.rs | 2 +- .../rooms/event_handler/call_policyserv.rs | 7 ++---- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5dce9c59..22c90e17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3921,7 +3921,7 @@ checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3" [[package]] name = "ruma" version = "0.10.1" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "assign", "js_int", @@ -3941,7 +3941,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "ruma-common", @@ -3953,7 +3953,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "as_variant", "assign", @@ -3976,7 +3976,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "as_variant", "base64 0.22.1", @@ -4008,7 +4008,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = 
"git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "as_variant", "indexmap 2.9.0", @@ -4033,7 +4033,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "bytes", "headers", @@ -4055,7 +4055,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "thiserror 2.0.12", @@ -4064,7 +4064,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "ruma-common", @@ -4074,7 +4074,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "cfg-if", "proc-macro-crate", @@ -4089,7 +4089,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "js_int", "ruma-common", @@ -4101,7 +4101,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=b753738047d1f443aca870896ef27ecaacf027da#b753738047d1f443aca870896ef27ecaacf027da" dependencies = [ "base64 0.22.1", "ed25519-dalek", diff --git a/Cargo.toml b/Cargo.toml index ab6a9e8a..9cb5ff84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -352,7 +352,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://forgejo.ellis.link/continuwuation/ruwuma" #branch = "conduwuit-changes" -rev = "a4b948b40417a65ab0282ae47cc50035dd455e02" +rev = "b753738047d1f443aca870896ef27ecaacf027da" features = [ "compat", "rand", diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 819d05e2..81c83431 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -5,7 +5,7 @@ use futures::{ future::{OptionFuture, join3}, }; use ruma::{ - 
EventId, Int, OwnedUserId, RoomVersionId, UserId, + Int, OwnedUserId, RoomVersionId, UserId, events::room::{ create::RoomCreateEventContent, join_rules::{JoinRule, RoomJoinRulesEventContent}, diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 4a52227d..804c77eb 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -1,10 +1,7 @@ -use conduwuit::{ - Err, Event, PduEvent, Result, debug, implement, utils::to_canonical_object, warn, -}; +use conduwuit::{Err, Event, PduEvent, Result, debug, implement, warn}; use ruma::{ RoomId, ServerName, api::federation::room::policy::v1::Request as PolicyRequest, - canonical_json::to_canonical_value, events::{StateEventType, room::policy::RoomPolicyEventContent}, }; @@ -65,7 +62,7 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result pdu.event_id().to_owned() ); return Err!(Request(Forbidden("Event was marked as spam by policy server"))); - }; + } Ok(()) } From 05ea20d943c22f37a473fc05814ea3fbe2226cee Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:47:02 +0100 Subject: [PATCH 24/31] fix(policy-server): Avoid unnecessary database lookup --- src/service/rooms/event_handler/call_policyserv.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 804c77eb..e7ae1d0f 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -26,20 +26,11 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result return Ok(()); }, }; - // TODO: dont do *this* - let pdu_json = self.services.timeline.get_pdu_json(pdu.event_id()).await?; let outgoing = self .services .sending - .convert_to_outgoing_federation_event(pdu_json) + .convert_to_outgoing_federation_event(pdu.to_canonical_object()) .await; - // let s = match serde_json::to_string(outgoing.as_ref()) { - // | Ok(s) => s, - // | Err(e) => { - // warn!("Failed to convert pdu {} to outgoing federation event: {e}", - // pdu.event_id()); return Err!(Request(InvalidParam("Failed to convert PDU - // to outgoing event."))); }, - // }; debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); let response = self .services From b64ba58163cdc7fc404e5a98862d9bbca6e865bb Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:50:47 +0100 Subject: [PATCH 25/31] style(policy-server): Restructure logging --- src/service/rooms/event_handler/call_policyserv.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index e7ae1d0f..894e28af 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -43,14 +43,21 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result let response = match response { | Ok(response) => response, | Err(e) => { - warn!("Failed to contact policy server {via} for room {room_id}: {e}"); + warn!( + via = %via, + event_id = %pdu.event_id(), + room_id = %room_id, + "Failed to contact policy server: {e}" + ); return Ok(()); }, }; if response.recommendation == "spam" { warn!( - "Event {} in room {room_id} was marked as spam by policy server {via}", - 
pdu.event_id().to_owned() + via = %via, + event_id = %pdu.event_id(), + room_id = %room_id, + "Event was marked as spam by policy server", ); return Err!(Request(Forbidden("Event was marked as spam by policy server"))); } From 0e19dce31c7f58af6a8945dbcbed6faf3d3dea3b Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 20:54:06 +0100 Subject: [PATCH 26/31] feat(policy-server): Soft-fail redactions for failed events --- .../event_handler/upgrade_outlier_pdu.rs | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index abb5c116..e8e22fe9 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -222,9 +222,7 @@ where } // 14-pre. If the event is not a state event, ask the policy server about it - if incoming_pdu.state_key.is_none() - && incoming_pdu.sender().server_name() != self.services.globals.server_name() - { + if incoming_pdu.state_key.is_none() { debug!("Checking policy server for event {}", incoming_pdu.event_id); let policy = self.policyserv_check(&incoming_pdu, room_id); if let Err(e) = policy.await { @@ -236,6 +234,24 @@ where debug!("Policy server check passed for event {}", incoming_pdu.event_id); } + // Additionally, if this is a redaction for a soft-failed event, we soft-fail it + // also + if let Some(redact_id) = incoming_pdu.redacts_id(&room_version_id) { + debug!("Checking if redaction {} is for a soft-failed event", redact_id); + if self + .services + .pdu_metadata + .is_event_soft_failed(&redact_id) + .await + { + warn!( + "Redaction {} is for a soft-failed event, soft failing the redaction", + redact_id + ); + soft_fail = true; + } + } + // 14. 
Check if the event passes auth based on the "current state" of the room, // if not soft fail it if soft_fail { From 99eb50433e80e2416da14ea42c13fd976a645f7f Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 21:09:23 +0100 Subject: [PATCH 27/31] feat(policy-server): Prevent local events that fail the policy check --- src/service/rooms/timeline/create.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/service/rooms/timeline/create.rs b/src/service/rooms/timeline/create.rs index 20ccaf56..6301d785 100644 --- a/src/service/rooms/timeline/create.rs +++ b/src/service/rooms/timeline/create.rs @@ -165,6 +165,17 @@ pub async fn create_hash_and_sign_event( return Err!(Request(Forbidden("Event is not authorized."))); } + // Check with the policy server + if self + .services + .event_handler + .policyserv_check(&pdu, room_id) + .await + .is_err() + { + return Err!(Request(Forbidden(debug_warn!("Policy server marked this event as spam")))); + } + // Hash and sign let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| { err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}")))) From 6e620a66bf349ebd0cdefb8a8d5537d51496c730 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 22:07:18 +0100 Subject: [PATCH 28/31] feat(policy-server): Limit policy server request timeout to 10 seconds --- .../rooms/event_handler/call_policyserv.rs | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 894e28af..0592186a 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use conduwuit::{Err, Event, PduEvent, Result, debug, implement, warn}; use ruma::{ RoomId, ServerName, @@ -32,17 +34,19 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result .convert_to_outgoing_federation_event(pdu.to_canonical_object()) .await; debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); - let response = self - .services - .sending - .send_federation_request(via, PolicyRequest { - event_id: pdu.event_id().to_owned(), - pdu: Some(outgoing), - }) - .await; + let response = tokio::time::timeout( + Duration::from_secs(10), + self.services + .sending + .send_federation_request(via, PolicyRequest { + event_id: pdu.event_id().to_owned(), + pdu: Some(outgoing), + }), + ) + .await; let response = match response { - | Ok(response) => response, - | Err(e) => { + | Ok(Ok(response)) => response, + | Ok(Err(e)) => { warn!( via = %via, event_id = %pdu.event_id(), @@ -51,6 +55,15 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result ); return Ok(()); }, + | Err(_) => { + warn!( + via = %via, + event_id = %pdu.event_id(), + room_id = %room_id, + "Policy server request timed out after 10 seconds" + ); + return Ok(()); + }, }; if response.recommendation == "spam" { warn!( From bd7bb4a1305046415df9cefba599a460d0492579 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 23:50:32 +0100 Subject: [PATCH 29/31] feat(policy-server): Optimise policy server lookups --- src/service/rooms/event_handler/call_policyserv.rs | 12 ++++++++++++ src/service/rooms/event_handler/mod.rs | 2 ++ 2 files changed, 14 insertions(+) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 
0592186a..331d4c8f 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -11,6 +11,10 @@ use ruma::{ #[implement(super::Service)] #[tracing::instrument(skip_all, level = "debug")] pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { + if pdu.event_type().to_owned() == StateEventType::RoomPolicy.into() { + debug!("Skipping spam check for policy server meta-event in room {room_id}"); + return Ok(()); + } let Ok(policyserver) = self .services .state_accessor @@ -28,6 +32,14 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result return Ok(()); }, }; + if via.is_empty() { + debug!("Policy server is empty for room {room_id}, skipping spam check"); + return Ok(()); + } + if !self.services.state_cache.server_in_room(via, room_id).await { + debug!("Policy server {via} is not in the room {room_id}, skipping spam check"); + return Ok(()); + } let outgoing = self .services .sending diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index ba5ad7e2..a0a1b20b 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -38,6 +38,7 @@ struct Services { server_keys: Dep, short: Dep, state: Dep, + state_cache: Dep, state_accessor: Dep, state_compressor: Dep, timeline: Dep, @@ -63,6 +64,7 @@ impl crate::Service for Service { pdu_metadata: args.depend::("rooms::pdu_metadata"), short: args.depend::("rooms::short"), state: args.depend::("rooms::state"), + state_cache: args.depend::("rooms::state_cache"), state_accessor: args .depend::("rooms::state_accessor"), state_compressor: args From df06e5dd678f0e1e8141a4e1a9ffe0999ec42b89 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sat, 19 Jul 2025 23:54:07 +0100 Subject: [PATCH 30/31] style(policy-server): Run clippy --- src/service/rooms/event_handler/call_policyserv.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 331d4c8f..96e3f7cc 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -11,7 +11,7 @@ use ruma::{ #[implement(super::Service)] #[tracing::instrument(skip_all, level = "debug")] pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { - if pdu.event_type().to_owned() == StateEventType::RoomPolicy.into() { + if *pdu.event_type() == StateEventType::RoomPolicy.into() { debug!("Skipping spam check for policy server meta-event in room {room_id}"); return Ok(()); } From dd3ace92e94f09ea9eb182da71dc2a50afacfa82 Mon Sep 17 00:00:00 2001 From: Jade Ellis Date: Sun, 20 Jul 2025 01:03:18 +0100 Subject: [PATCH 31/31] style: Improve logging and comments --- src/core/matrix/state_res/event_auth.rs | 4 +- .../rooms/event_handler/call_policyserv.rs | 26 ++++++- .../event_handler/upgrade_outlier_pdu.rs | 73 +++++++++++++++---- 3 files changed, 82 insertions(+), 21 deletions(-) diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 81c83431..77a4a95c 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -149,8 +149,8 @@ where for<'a> &'a E: Event + Send, { debug!( - event_id = format!("{}", incoming_event.event_id()), - event_type = format!("{}", incoming_event.event_type()), + event_id = %incoming_event.event_id(), + event_type = 
?incoming_event.event_type(), "auth_check beginning" ); diff --git a/src/service/rooms/event_handler/call_policyserv.rs b/src/service/rooms/event_handler/call_policyserv.rs index 96e3f7cc..aef99dba 100644 --- a/src/service/rooms/event_handler/call_policyserv.rs +++ b/src/service/rooms/event_handler/call_policyserv.rs @@ -1,3 +1,8 @@ +//! Policy server integration for event spam checking in Matrix rooms. +//! +//! This module implements a check against a room-specific policy server, as +//! described in the relevant Matrix spec proposal (see: https://github.com/matrix-org/matrix-spec-proposals/pull/4284). + use std::time::Duration; use conduwuit::{Err, Event, PduEvent, Result, debug, implement, warn}; @@ -12,7 +17,11 @@ use ruma::{ #[tracing::instrument(skip_all, level = "debug")] pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result { if *pdu.event_type() == StateEventType::RoomPolicy.into() { - debug!("Skipping spam check for policy server meta-event in room {room_id}"); + debug!( + room_id = %room_id, + event_type = ?pdu.event_type(), + "Skipping spam check for policy server meta-event" + ); return Ok(()); } let Ok(policyserver) = self @@ -37,7 +46,11 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result return Ok(()); } if !self.services.state_cache.server_in_room(via, room_id).await { - debug!("Policy server {via} is not in the room {room_id}, skipping spam check"); + debug!( + room_id = %room_id, + via = %via, + "Policy server is not in the room, skipping spam check" + ); return Ok(()); } let outgoing = self @@ -45,7 +58,12 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result .sending .convert_to_outgoing_federation_event(pdu.to_canonical_object()) .await; - debug!("Checking pdu {outgoing:?} for spam with policy server {via} for room {room_id}"); + debug!( + room_id = %room_id, + via = %via, + outgoing = ?outgoing, + "Checking event for spam with policy server" + ); let response = tokio::time::timeout( Duration::from_secs(10), self.services @@ -65,6 +83,8 @@ pub async fn policyserv_check(&self, pdu: &PduEvent, room_id: &RoomId) -> Result room_id = %room_id, "Failed to contact policy server: {e}" ); + // Network or policy server errors are treated as non-fatal: event is allowed by + // default. return Ok(()); }, | Err(_) => { diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index e8e22fe9..d3dc32fb 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -47,7 +47,10 @@ where return Err!(Request(InvalidParam("Event has been soft failed"))); } - debug!("Upgrading pdu {} from outlier to timeline pdu", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Upgrading PDU from outlier to timeline" + ); let timer = Instant::now(); let room_version_id = get_room_version_id(create_event)?; @@ -55,7 +58,10 @@ where // backwards extremities doing all the checks in this list starting at 1. // These are not timeline events. - debug!("Resolving state at event {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Resolving state at event" + ); let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 { self.state_at_incoming_degree_one(&incoming_pdu).await? 
} else { @@ -74,7 +80,10 @@ where let room_version = to_room_version(&room_version_id); - debug!("Performing auth check to upgrade {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Performing auth check to upgrade" + ); // 11. Check the auth of the event passes based on the state of the event let state_fetch_state = &state_at_incoming_event; let state_fetch = |k: StateEventType, s: StateKey| async move { @@ -84,7 +93,10 @@ where self.services.timeline.get_pdu(event_id).await.ok() }; - debug!("running auth check on {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Running initial auth check" + ); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -98,7 +110,10 @@ where return Err!(Request(Forbidden("Event has failed auth check with state at the event."))); } - debug!("Gathering auth events for {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Gathering auth events" + ); let auth_events = self .services .state @@ -116,7 +131,10 @@ where ready(auth_events.get(&key).map(ToOwned::to_owned)) }; - debug!("running auth check on {} with claimed state auth", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Running auth check with claimed state auth" + ); let auth_check = state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -127,7 +145,10 @@ where .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; // Soft fail check before doing state res - debug!("Performing soft-fail check on {}", incoming_pdu.event_id); + debug!( + event_id = %incoming_pdu.event_id, + "Performing soft-fail check" + ); let mut soft_fail = match (auth_check, incoming_pdu.redacts_id(&room_version_id)) { | (false, _) => true, | (true, None) => false, @@ -142,7 +163,10 @@ where // 13. Use state resolution to find new room state // We start looking at current room state now, so lets lock the room - trace!("Locking the room"); + trace!( + room_id = %room_id, + "Locking the room" + ); let state_lock = self.services.state.mutex.lock(room_id).await; // Now we calculate the set of extremities this room has after the incoming @@ -223,21 +247,32 @@ where // 14-pre. 
If the event is not a state event, ask the policy server about it
 	if incoming_pdu.state_key.is_none() {
-		debug!("Checking policy server for event {}", incoming_pdu.event_id);
+		debug!(
+			event_id = %incoming_pdu.event_id,
+			"Checking policy server for event"
+		);
 		let policy = self.policyserv_check(&incoming_pdu, room_id);
 		if let Err(e) = policy.await {
-			warn!("Policy server check failed for event {}: {e}", incoming_pdu.event_id);
+			warn!(
+				event_id = %incoming_pdu.event_id,
+				error = ?e,
+				"Policy server check failed for event"
+			);
 			if !soft_fail {
 				soft_fail = true;
 			}
 		}
-		debug!("Policy server check passed for event {}", incoming_pdu.event_id);
+		debug!(
+			event_id = %incoming_pdu.event_id,
+			"Policy server check completed for event"
+		);
 	}
 
 	// Additionally, if this is a redaction for a soft-failed event, we soft-fail it
 	// also
 	if let Some(redact_id) = incoming_pdu.redacts_id(&room_version_id) {
-		debug!("Checking if redaction {} is for a soft-failed event", redact_id);
+		debug!(
+			redact_id = %redact_id,
+			"Checking if redaction is for a soft-failed event"
+		);
 		if self
 			.services
 			.pdu_metadata
 			.is_event_soft_failed(&redact_id)
 			.await
 		{
 			warn!(
-				"Redaction {} is for a soft-failed event, soft failing the redaction",
-				redact_id
+				redact_id = %redact_id,
+				"Redaction is for a soft-failed event, soft failing the redaction"
 			);
 			soft_fail = true;
 		}
 	}
@@ -255,7 +290,10 @@ where
 	// 14. Check if the event passes auth based on the "current state" of the room,
 	// if not soft fail it
 	if soft_fail {
-		info!("Soft failing event {}", incoming_pdu.event_id);
+		info!(
+			event_id = %incoming_pdu.event_id,
+			"Soft failing event"
+		);
 		// assert!(extremities.is_empty(), "soft_fail extremities empty");
 
 		let extremities = extremities.iter().map(Borrow::borrow);
@@ -276,7 +314,10 @@
 		.pdu_metadata
 		.mark_event_soft_failed(incoming_pdu.event_id());
 
-	warn!("Event was soft failed: {:?}", incoming_pdu.event_id());
+	warn!(
+		event_id = %incoming_pdu.event_id,
+		"Event was soft failed"
+	);
 	return Err!(Request(InvalidParam("Event has been soft failed")));
 }
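
For reference, the policy check in this series is driven by a single piece of
room state that names the policy server. A minimal sketch of that state event,
assuming the unstable org.matrix.msc4284.policy event type and "via" content
field from the MSC draft (both may change while the proposal remains unstable):

    // Sketch only: the event type and schema follow the MSC draft, not a
    // stable API. A missing or empty "via" disables the check, as handled
    // at the top of policyserv_check.
    use serde_json::json;

    fn example_policy_state_event() -> serde_json::Value {
        json!({
            "type": "org.matrix.msc4284.policy",
            "state_key": "",
            "content": {
                "via": "policyserver.example.com"
            }
        })
    }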
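
The net behaviour of policyserv_check is fail-open: only an explicit "spam"
verdict rejects an event, while a missing policy server, transport errors, and
the 10-second timeout all allow it. A self-contained sketch of that logic,
where query_policy_server is a hypothetical stand-in for the federation
request made in the real code:

    use std::time::Duration;

    struct PolicyResponse {
        recommendation: String,
    }

    // Stand-in for services.sending.send_federation_request(via, PolicyRequest { .. }).
    async fn query_policy_server(_event_id: &str) -> Result<PolicyResponse, std::io::Error> {
        Ok(PolicyResponse { recommendation: "ok".to_owned() })
    }

    async fn policy_allows(event_id: &str) -> bool {
        match tokio::time::timeout(Duration::from_secs(10), query_policy_server(event_id)).await {
            // Only an explicit spam verdict rejects the event.
            Ok(Ok(resp)) => resp.recommendation != "spam",
            // Transport failure: fail open and allow the event.
            Ok(Err(_)) => true,
            // Timed out after 10 seconds: fail open and allow the event.
            Err(_) => true,
        }
    }

Note that the two call sites enforce the verdict differently: inbound
federated events that fail the check are soft-failed and kept out of the
timeline, while locally created events are rejected outright before they are
hashed and signed.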
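
By the end of the series, the soft-fail flag for an inbound non-state event
accumulates three conditions. A condensed, illustrative view (the real code
additionally evaluates redaction auth rules against the current state):

    fn soft_fail(
        current_state_auth_ok: bool,
        policy_rejected: bool,
        redacts_soft_failed: bool,
    ) -> bool {
        !current_state_auth_ok || policy_rejected || redacts_soft_failed
    }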