diff --git a/Cargo.lock b/Cargo.lock index 82e7a20d..92044b92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3798,7 +3798,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "assign", "js_int", @@ -3818,7 +3818,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "js_int", "ruma-common", @@ -3830,7 +3830,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "as_variant", "assign", @@ -3853,7 +3853,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "as_variant", "base64 0.22.1", @@ -3885,7 +3885,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "as_variant", "indexmap 2.9.0", @@ -3910,7 +3910,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "bytes", "headers", @@ -3932,7 +3932,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "js_int", "thiserror 2.0.12", @@ -3941,7 +3941,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "js_int", "ruma-common", @@ -3951,7 +3951,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "cfg-if", "proc-macro-crate", @@ -3966,7 +3966,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "js_int", "ruma-common", @@ -3978,7 +3978,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=a4b948b40417a65ab0282ae47cc50035dd455e02#a4b948b40417a65ab0282ae47cc50035dd455e02" +source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=9b65f83981f6f53d185ce77da37aaef9dfd764a9#9b65f83981f6f53d185ce77da37aaef9dfd764a9" dependencies = [ "base64 0.22.1", "ed25519-dalek", diff --git a/Cargo.toml b/Cargo.toml index b815e2b8..5c289adf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -350,7 +350,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://forgejo.ellis.link/continuwuation/ruwuma" #branch = "conduwuit-changes" -rev = "a4b948b40417a65ab0282ae47cc50035dd455e02" +rev = "9b65f83981f6f53d185ce77da37aaef9dfd764a9" features = [ "compat", "rand", diff --git a/committed.toml b/committed.toml index 59750fa5..64f7f18a 100644 --- a/committed.toml +++ b/committed.toml @@ -1,3 +1,2 @@ style = "conventional" -subject_length = 72 allowed_types = ["ci", "build", "fix", "feat", "chore", "docs", "style", "refactor", "perf", "test"] diff --git a/conduwuit-example.toml b/conduwuit-example.toml index f744a07d..287bf65f 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -398,17 +398,6 @@ # #allow_registration = false -# If registration is enabled, and this setting is true, new users -# registered after the first admin user will be automatically suspended -# and will require an admin to run `!admin users unsuspend `. -# -# Suspended users are still able to read messages, make profile updates, -# leave rooms, and deactivate their account, however cannot send messages, -# invites, or create/join or otherwise modify rooms. -# They are effectively read-only. -# -#suspend_on_register = false - # Enabling this setting opens registration to anyone without restrictions. # This makes your server vulnerable to abuse # diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 81b0e9da..a397e0fc 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -7,10 +7,7 @@ use std::{ use conduwuit::{ Err, Result, debug_error, err, info, - matrix::{ - Event, - pdu::{PduEvent, PduId, RawPduId}, - }, + matrix::pdu::{PduEvent, PduId, RawPduId}, trace, utils, utils::{ stream::{IterStream, ReadyExt}, @@ -22,7 +19,7 @@ use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedServerName, RoomId, RoomVersionId, - api::federation::event::get_room_state, events::AnyStateEvent, serde::Raw, + api::federation::event::get_room_state, }; use service::rooms::{ short::{ShortEventId, ShortRoomId}, @@ -299,12 +296,12 @@ pub(super) async fn get_remote_pdu( #[admin_command] pub(super) async fn get_room_state(&self, room: OwnedRoomOrAliasId) -> Result { let room_id = self.services.rooms.alias.resolve(&room).await?; - let room_state: Vec> = self + let room_state: Vec<_> = self .services .rooms .state_accessor .room_state_full_pdus(&room_id) - .map_ok(Event::into_format) + .map_ok(PduEvent::into_state_event) .try_collect() .await?; @@ -412,9 +409,7 @@ pub(super) async fn change_log_level(&self, filter: Option, reset: bool) .reload .reload(&new_filter_layer, Some(handles)) { - | Ok(()) => { - return self.write_str("Successfully changed log level").await; - }, + | Ok(()) => return self.write_str("Successfully changed log level").await, | Err(e) => { return Err!("Failed to modify and reload the global tracing log level: {e}"); }, @@ -558,8 +553,8 @@ pub(super) async fn force_set_room_state_from_server( .latest_pdu_in_room(&room_id) .await .map_err(|_| err!(Database("Failed to find the latest PDU in database")))? - .event_id() - .to_owned(), + .event_id + .clone(), }; let room_version = self.services.rooms.state.get_room_version(&room_id).await?; diff --git a/src/admin/processor.rs b/src/admin/processor.rs index e80000c1..8d1fe89c 100644 --- a/src/admin/processor.rs +++ b/src/admin/processor.rs @@ -94,7 +94,8 @@ async fn process_command(services: Arc, input: &CommandInput) -> Proce #[allow(clippy::result_large_err)] fn handle_panic(error: &Error, command: &CommandInput) -> ProcessorResult { - let link = "Please submit a [bug report](https://forgejo.ellis.link/continuwuation/continuwuity/issues/new). 🥺"; + let link = + "Please submit a [bug report](https://forgejo.ellis.link/continuwuation/continuwuity/issues/new). 🥺"; let msg = format!("Panic occurred while processing command:\n```\n{error:#?}\n```\n{link}"); let content = RoomMessageEventContent::notice_markdown(msg); error!("Panic while processing command: {error:?}"); diff --git a/src/admin/query/room_timeline.rs b/src/admin/query/room_timeline.rs index 0fd22ca7..58f75cb9 100644 --- a/src/admin/query/room_timeline.rs +++ b/src/admin/query/room_timeline.rs @@ -31,7 +31,7 @@ pub(super) async fn last(&self, room_id: OwnedRoomOrAliasId) -> Result { .services .rooms .timeline - .last_timeline_count(None, &room_id) + .last_timeline_count(&room_id) .await?; self.write_str(&format!("{result:#?}")).await @@ -52,7 +52,7 @@ pub(super) async fn pdus( .services .rooms .timeline - .pdus_rev(None, &room_id, from) + .pdus_rev(&room_id, from) .try_take(limit.unwrap_or(3)) .try_collect() .await?; diff --git a/src/admin/room/moderation.rs b/src/admin/room/moderation.rs index 5fb5bb3e..4c19da5c 100644 --- a/src/admin/room/moderation.rs +++ b/src/admin/room/moderation.rs @@ -1,11 +1,12 @@ use api::client::leave_room; use clap::Subcommand; use conduwuit::{ - Err, Result, debug, info, + Err, Result, debug, utils::{IterStream, ReadyExt}, warn, }; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; +use futures::FutureExt; use ruma::{OwnedRoomId, OwnedRoomOrAliasId, RoomAliasId, RoomId, RoomOrAliasId}; use crate::{admin_command, admin_command_dispatch, get_room_info}; @@ -70,6 +71,7 @@ async fn ban_room(&self, room: OwnedRoomOrAliasId) -> Result { }; debug!("Room specified is a room ID, banning room ID"); + self.services.rooms.metadata.ban_room(room_id, true); room_id.to_owned() } else if room.is_room_alias_id() { @@ -89,25 +91,47 @@ async fn ban_room(&self, room: OwnedRoomOrAliasId) -> Result { locally, if not using get_alias_helper to fetch room ID remotely" ); - match self + let room_id = match self .services .rooms .alias - .resolve_alias(room_alias, None) + .resolve_local_alias(room_alias) .await { - | Ok((room_id, servers)) => { + | Ok(room_id) => room_id, + | _ => { debug!( - ?room_id, - ?servers, - "Got federation response fetching room ID for room {room}" + "We don't have this room alias to a room ID locally, attempting to fetch \ + room ID over federation" ); - room_id + + match self + .services + .rooms + .alias + .resolve_alias(room_alias, None) + .await + { + | Ok((room_id, servers)) => { + debug!( + ?room_id, + ?servers, + "Got federation response fetching room ID for {room_id}" + ); + room_id + }, + | Err(e) => { + return Err!( + "Failed to resolve room alias {room_alias} to a room ID: {e}" + ); + }, + } }, - | Err(e) => { - return Err!("Failed to resolve room alias {room} to a room ID: {e}"); - }, - } + }; + + self.services.rooms.metadata.ban_room(&room_id, true); + + room_id } else { return Err!( "Room specified is not a room ID or room alias. Please note that this requires a \ @@ -116,7 +140,7 @@ async fn ban_room(&self, room: OwnedRoomOrAliasId) -> Result { ); }; - info!("Making all users leave the room {room_id} and forgetting it"); + debug!("Making all users leave the room {room_id} and forgetting it"); let mut users = self .services .rooms @@ -127,7 +151,7 @@ async fn ban_room(&self, room: OwnedRoomOrAliasId) -> Result { .boxed(); while let Some(ref user_id) = users.next().await { - info!( + debug!( "Attempting leave for user {user_id} in room {room_id} (ignoring all errors, \ evicting admins too)", ); @@ -157,9 +181,10 @@ async fn ban_room(&self, room: OwnedRoomOrAliasId) -> Result { }) .await; - self.services.rooms.directory.set_not_public(&room_id); // remove from the room directory - self.services.rooms.metadata.ban_room(&room_id, true); // prevent further joins - self.services.rooms.metadata.disable_room(&room_id, true); // disable federation + // unpublish from room directory + self.services.rooms.directory.set_not_public(&room_id); + + self.services.rooms.metadata.disable_room(&room_id, true); self.write_str( "Room banned, removed all our local users, and disabled incoming federation with room.", @@ -281,6 +306,8 @@ async fn ban_list_of_rooms(&self) -> Result { } for room_id in room_ids { + self.services.rooms.metadata.ban_room(&room_id, true); + debug!("Banned {room_id} successfully"); room_ban_count = room_ban_count.saturating_add(1); @@ -326,9 +353,9 @@ async fn ban_list_of_rooms(&self) -> Result { }) .await; - self.services.rooms.metadata.ban_room(&room_id, true); // unpublish from room directory, ignore errors self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.metadata.disable_room(&room_id, true); } diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 86206c2b..5b02da54 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,16 +1,15 @@ use std::{collections::BTreeMap, fmt::Write as _}; -use api::client::{ - full_user_deactivate, join_room_by_id_helper, leave_all_rooms, leave_room, update_avatar_url, - update_displayname, -}; +use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; use conduwuit::{ Err, Result, debug, debug_warn, error, info, is_equal_to, - matrix::{Event, pdu::PduBuilder}, + matrix::pdu::PduBuilder, utils::{self, ReadyExt}, warn, }; -use futures::{FutureExt, StreamExt}; +use conduwuit_api::client::{leave_all_rooms, update_avatar_url, update_displayname}; +use futures::StreamExt; +use futures::FutureExt; use ruma::{ OwnedEventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, UserId, events::{ @@ -286,9 +285,8 @@ pub(super) async fn reset_password(&self, username: String, password: Option return Err!("Couldn't reset the password for user {user_id}: {e}"), - | Ok(()) => { - write!(self, "Successfully reset the password for user {user_id}: `{new_password}`") - }, + | Ok(()) => + write!(self, "Successfully reset the password for user {user_id}: `{new_password}`"), } .await } @@ -738,7 +736,7 @@ pub(super) async fn force_demote(&self, user_id: String, room_id: OwnedRoomOrAli .state_accessor .room_state_get(&room_id, &StateEventType::RoomCreate, "") .await - .is_ok_and(|event| event.sender() == user_id); + .is_ok_and(|event| event.sender == user_id); if !user_can_demote_self { return Err!("User is not allowed to modify their own power levels in the room.",); @@ -889,7 +887,10 @@ pub(super) async fn redact_event(&self, event_id: OwnedEventId) -> Result { return Err!("Event is already redacted."); } - if !self.services.globals.user_is_local(event.sender()) { + let room_id = event.room_id; + let sender_user = event.sender; + + if !self.services.globals.user_is_local(&sender_user) { return Err!("This command only works on local users."); } @@ -899,21 +900,21 @@ pub(super) async fn redact_event(&self, event_id: OwnedEventId) -> Result { ); let redaction_event_id = { - let state_lock = self.services.rooms.state.mutex.lock(event.room_id()).await; + let state_lock = self.services.rooms.state.mutex.lock(&room_id).await; self.services .rooms .timeline .build_and_append_pdu( PduBuilder { - redacts: Some(event.event_id().to_owned()), + redacts: Some(event.event_id.clone()), ..PduBuilder::timeline(&RoomRedactionEventContent { - redacts: Some(event.event_id().to_owned()), + redacts: Some(event.event_id.clone()), reason: Some(reason), }) }, - event.sender(), - event.room_id(), + &sender_user, + &room_id, &state_lock, ) .await? diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 12801e7d..05dfa8b7 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -3,9 +3,10 @@ use std::fmt::Write; use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduwuit::{ - Err, Error, Event, Result, debug_info, err, error, info, is_equal_to, + Err, Error, Result, debug_info, err, error, info, is_equal_to, matrix::pdu::PduBuilder, - utils::{self, ReadyExt, stream::BroadbandExt}, + utils, + utils::{ReadyExt, stream::BroadbandExt}, warn, }; use conduwuit_service::Services; @@ -150,32 +151,16 @@ pub(crate) async fn register_route( if !services.config.allow_registration && body.appservice_info.is_none() { match (body.username.as_ref(), body.initial_device_display_name.as_ref()) { | (Some(username), Some(device_display_name)) => { - info!( - %is_guest, - user = %username, - device_name = %device_display_name, - "Rejecting registration attempt as registration is disabled" - ); + info!(%is_guest, user = %username, device_name = %device_display_name, "Rejecting registration attempt as registration is disabled"); }, | (Some(username), _) => { - info!( - %is_guest, - user = %username, - "Rejecting registration attempt as registration is disabled" - ); + info!(%is_guest, user = %username, "Rejecting registration attempt as registration is disabled"); }, | (_, Some(device_display_name)) => { - info!( - %is_guest, - device_name = %device_display_name, - "Rejecting registration attempt as registration is disabled" - ); + info!(%is_guest, device_name = %device_display_name, "Rejecting registration attempt as registration is disabled"); }, | (None, _) => { - info!( - %is_guest, - "Rejecting registration attempt as registration is disabled" - ); + info!(%is_guest, "Rejecting registration attempt as registration is disabled"); }, } @@ -366,7 +351,8 @@ pub(crate) async fn register_route( if !services.globals.new_user_displayname_suffix().is_empty() && body.appservice_info.is_none() { - write!(displayname, " {}", services.server.config.new_user_displayname_suffix)?; + write!(displayname, " {}", services.server.config.new_user_displayname_suffix) + .expect("should be able to write to string buffer"); } services @@ -384,7 +370,8 @@ pub(crate) async fn register_route( content: ruma::events::push_rules::PushRulesEventContent { global: push::Ruleset::server_default(&user_id), }, - })?, + }) + .expect("to json always works"), ) .await?; @@ -429,21 +416,32 @@ pub(crate) async fn register_route( // log in conduit admin channel if a non-guest user registered if body.appservice_info.is_none() && !is_guest { if !device_display_name.is_empty() { - let notice = format!( - "New user \"{user_id}\" registered on this server from IP {client} and device \ - display name \"{device_display_name}\"" + info!( + "New user \"{user_id}\" registered on this server with device display name: \ + \"{device_display_name}\"" ); - info!("{notice}"); if services.server.config.admin_room_notices { - services.admin.notice(¬ice).await; + services + .admin + .send_message(RoomMessageEventContent::notice_plain(format!( + "New user \"{user_id}\" registered on this server from IP {client} and \ + device display name \"{device_display_name}\"" + ))) + .await + .ok(); } } else { - let notice = format!("New user \"{user_id}\" registered on this server."); + info!("New user \"{user_id}\" registered on this server."); - info!("{notice}"); if services.server.config.admin_room_notices { - services.admin.notice(¬ice).await; + services + .admin + .send_message(RoomMessageEventContent::notice_plain(format!( + "New user \"{user_id}\" registered on this server from IP {client}" + ))) + .await + .ok(); } } } @@ -456,22 +454,24 @@ pub(crate) async fn register_route( if services.server.config.admin_room_notices { services .admin - .notice(&format!( + .send_message(RoomMessageEventContent::notice_plain(format!( "Guest user \"{user_id}\" with device display name \ \"{device_display_name}\" registered on this server from IP {client}" - )) - .await; + ))) + .await + .ok(); } } else { #[allow(clippy::collapsible_else_if)] if services.server.config.admin_room_notices { services .admin - .notice(&format!( + .send_message(RoomMessageEventContent::notice_plain(format!( "Guest user \"{user_id}\" with no device display name registered on \ this server from IP {client}", - )) - .await; + ))) + .await + .ok(); } } } @@ -490,25 +490,6 @@ pub(crate) async fn register_route( { services.admin.make_user_admin(&user_id).await?; warn!("Granting {user_id} admin privileges as the first user"); - } else if services.config.suspend_on_register { - // This is not an admin, suspend them. - // Note that we can still do auto joins for suspended users - services - .users - .suspend_account(&user_id, &services.globals.server_user) - .await; - // And send an @room notice to the admin room, to prompt admins to review the - // new user and ideally unsuspend them if deemed appropriate. - if services.server.config.admin_room_notices { - services - .admin - .send_loud_message(RoomMessageEventContent::text_plain(format!( - "User {user_id} has been suspended as they are not the first user \ - on this server. Please review and unsuspend them if appropriate." - ))) - .await - .ok(); - } } } } @@ -603,6 +584,7 @@ pub(crate) async fn change_password_route( .sender_user .as_ref() .ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?; + let sender_device = body.sender_device(); let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Password] }], @@ -616,7 +598,7 @@ pub(crate) async fn change_password_route( | Some(auth) => { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, body.sender_device(), auth, &uiaainfo) + .try_auth(sender_user, sender_device, auth, &uiaainfo) .await?; if !worked { @@ -630,7 +612,7 @@ pub(crate) async fn change_password_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, body.sender_device(), &uiaainfo, json); + .create(sender_user, sender_device, &uiaainfo, json); return Err(Error::Uiaa(uiaainfo)); }, @@ -649,7 +631,7 @@ pub(crate) async fn change_password_route( services .users .all_device_ids(sender_user) - .ready_filter(|id| *id != body.sender_device()) + .ready_filter(|id| *id != sender_device) .for_each(|id| services.users.remove_device(sender_user, id)) .await; @@ -658,17 +640,17 @@ pub(crate) async fn change_password_route( .pusher .get_pushkeys(sender_user) .map(ToOwned::to_owned) - .broad_filter_map(async |pushkey| { + .broad_filter_map(|pushkey| async move { services .pusher .get_pusher_device(&pushkey) .await .ok() - .filter(|pusher_device| pusher_device != body.sender_device()) + .filter(|pusher_device| pusher_device != sender_device) .is_some() .then_some(pushkey) }) - .for_each(async |pushkey| { + .for_each(|pushkey| async move { services.pusher.delete_pusher(sender_user, &pushkey).await; }) .await; @@ -679,8 +661,11 @@ pub(crate) async fn change_password_route( if services.server.config.admin_room_notices { services .admin - .notice(&format!("User {sender_user} changed their password.")) - .await; + .send_message(RoomMessageEventContent::notice_plain(format!( + "User {sender_user} changed their password." + ))) + .await + .ok(); } Ok(change_password::v3::Response {}) @@ -695,10 +680,13 @@ pub(crate) async fn whoami_route( State(services): State, body: Ruma, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let device_id = body.sender_device.clone(); + Ok(whoami::v3::Response { - user_id: body.sender_user().to_owned(), - device_id: body.sender_device.clone(), - is_guest: services.users.is_deactivated(body.sender_user()).await? + user_id: sender_user.clone(), + device_id, + is_guest: services.users.is_deactivated(sender_user).await? && body.appservice_info.is_none(), }) } @@ -726,6 +714,7 @@ pub(crate) async fn deactivate_route( .sender_user .as_ref() .ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?; + let sender_device = body.sender_device(); let mut uiaainfo = UiaaInfo { flows: vec![AuthFlow { stages: vec![AuthType::Password] }], @@ -739,7 +728,7 @@ pub(crate) async fn deactivate_route( | Some(auth) => { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, body.sender_device(), auth, &uiaainfo) + .try_auth(sender_user, sender_device, auth, &uiaainfo) .await?; if !worked { @@ -752,7 +741,7 @@ pub(crate) async fn deactivate_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, body.sender_device(), &uiaainfo, json); + .create(sender_user, sender_device, &uiaainfo, json); return Err(Error::Uiaa(uiaainfo)); }, @@ -783,8 +772,11 @@ pub(crate) async fn deactivate_route( if services.server.config.admin_room_notices { services .admin - .notice(&format!("User {sender_user} deactivated their account.")) - .await; + .send_message(RoomMessageEventContent::notice_plain(format!( + "User {sender_user} deactivated their account." + ))) + .await + .ok(); } Ok(deactivate::v3::Response { @@ -861,7 +853,6 @@ pub async fn full_user_deactivate( all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { services.users.deactivate_account(user_id).await.ok(); - super::update_displayname(services, user_id, None, all_joined_rooms).await; super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await; @@ -898,7 +889,7 @@ pub async fn full_user_deactivate( .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "") .await - .is_ok_and(|event| event.sender() == user_id); + .is_ok_and(|event| event.sender == user_id); if user_can_demote_self { let mut power_levels_content = room_power_levels.unwrap_or_default(); @@ -926,7 +917,9 @@ pub async fn full_user_deactivate( } } - super::leave_all_rooms(services, user_id).boxed().await; + super::leave_all_rooms(services, user_id) + .boxed() + .await; Ok(()) } diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 97c1a1bd..dc7aad44 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -17,7 +17,7 @@ pub(crate) async fn create_alias_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if services.users.is_suspended(sender_user).await? { return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); } @@ -65,7 +65,7 @@ pub(crate) async fn delete_alias_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if services.users.is_suspended(sender_user).await? { return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); } diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index a3038f26..2ad37cf3 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -2,10 +2,8 @@ use std::cmp::Ordering; use axum::extract::State; use conduwuit::{Err, Result, err}; -use conduwuit_service::Services; -use futures::{FutureExt, future::try_join}; use ruma::{ - UInt, UserId, + UInt, api::client::backup::{ add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, delete_backup_keys, delete_backup_keys_for_room, @@ -60,9 +58,21 @@ pub(crate) async fn get_latest_backup_info_route( .await .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; - let (count, etag) = get_count_etag(&services, body.sender_user(), &version).await?; - - Ok(get_latest_backup_info::v3::Response { algorithm, count, etag, version }) + Ok(get_latest_backup_info::v3::Response { + algorithm, + count: (UInt::try_from( + services + .key_backups + .count_keys(body.sender_user(), &version) + .await, + ) + .expect("user backup keys count should not be that high")), + etag: services + .key_backups + .get_etag(body.sender_user(), &version) + .await, + version, + }) } /// # `GET /_matrix/client/v3/room_keys/version/{version}` @@ -80,12 +90,17 @@ pub(crate) async fn get_backup_info_route( err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))) })?; - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - Ok(get_backup_info::v3::Response { algorithm, - count, - etag, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, version: body.version.clone(), }) } @@ -140,9 +155,17 @@ pub(crate) async fn add_backup_keys_route( } } - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - - Ok(add_backup_keys::v3::Response { count, etag }) + Ok(add_backup_keys::v3::Response { + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, + }) } /// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}` @@ -175,9 +198,17 @@ pub(crate) async fn add_backup_keys_for_room_route( .await?; } - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - - Ok(add_backup_keys_for_room::v3::Response { count, etag }) + Ok(add_backup_keys_for_room::v3::Response { + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, + }) } /// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` @@ -275,9 +306,17 @@ pub(crate) async fn add_backup_keys_for_session_route( .await?; } - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - - Ok(add_backup_keys_for_session::v3::Response { count, etag }) + Ok(add_backup_keys_for_session::v3::Response { + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, + }) } /// # `GET /_matrix/client/r0/room_keys/keys` @@ -340,9 +379,17 @@ pub(crate) async fn delete_backup_keys_route( .delete_all_keys(body.sender_user(), &body.version) .await; - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - - Ok(delete_backup_keys::v3::Response { count, etag }) + Ok(delete_backup_keys::v3::Response { + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, + }) } /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}` @@ -357,9 +404,17 @@ pub(crate) async fn delete_backup_keys_for_room_route( .delete_room_keys(body.sender_user(), &body.version, &body.room_id) .await; - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - - Ok(delete_backup_keys_for_room::v3::Response { count, etag }) + Ok(delete_backup_keys_for_room::v3::Response { + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, + }) } /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` @@ -374,22 +429,15 @@ pub(crate) async fn delete_backup_keys_for_session_route( .delete_room_key(body.sender_user(), &body.version, &body.room_id, &body.session_id) .await; - let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; - - Ok(delete_backup_keys_for_session::v3::Response { count, etag }) -} - -async fn get_count_etag( - services: &Services, - sender_user: &UserId, - version: &str, -) -> Result<(UInt, String)> { - let count = services - .key_backups - .count_keys(sender_user, version) - .map(TryInto::try_into); - - let etag = services.key_backups.get_etag(sender_user, version).map(Ok); - - Ok(try_join(count, etag).await?) + Ok(delete_backup_keys_for_session::v3::Response { + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, + }) } diff --git a/src/api/client/capabilities.rs b/src/api/client/capabilities.rs index c42c6dfd..7362c4f9 100644 --- a/src/api/client/capabilities.rs +++ b/src/api/client/capabilities.rs @@ -26,8 +26,8 @@ pub(crate) async fn get_capabilities_route( let mut capabilities = Capabilities::default(); capabilities.room_versions = RoomVersionsCapability { - available, default: services.server.config.default_room_version.clone(), + available, }; // we do not implement 3PID stuff @@ -38,12 +38,16 @@ pub(crate) async fn get_capabilities_route( }; // MSC4133 capability - capabilities.set("uk.tcpip.msc4133.profile_fields", json!({"enabled": true}))?; + capabilities + .set("uk.tcpip.msc4133.profile_fields", json!({"enabled": true})) + .expect("this is valid JSON we created"); - capabilities.set( - "org.matrix.msc4267.forget_forced_upon_leave", - json!({"enabled": services.config.forget_forced_upon_leave}), - )?; + capabilities + .set( + "org.matrix.msc4267.forget_forced_upon_leave", + json!({"enabled": services.config.forget_forced_upon_leave}), + ) + .expect("valid JSON we created"); Ok(get_capabilities::v3::Response { capabilities }) } diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 4a7d34d2..ee3a458c 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,6 +1,8 @@ use axum::extract::State; use conduwuit::{ - Err, Event, Result, at, debug_warn, err, ref_at, + Err, Result, at, debug_warn, err, + matrix::pdu::PduEvent, + ref_at, utils::{ IterStream, future::TryExtExt, @@ -82,11 +84,25 @@ pub(crate) async fn get_context_route( let base_event = ignored_filter(&services, (base_count, base_pdu), sender_user); + // PDUs are used to get seen user IDs and then returned in response. + let events_before = services .rooms .timeline - .pdus_rev(Some(sender_user), room_id, Some(base_count)) + .pdus_rev(room_id, Some(base_count)) .ignore_err() + .then(async |mut pdu| { + pdu.1.set_unsigned(Some(sender_user)); + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + pdu + }) .ready_filter_map(|item| event_filter(item, filter)) .wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user)) @@ -96,8 +112,20 @@ pub(crate) async fn get_context_route( let events_after = services .rooms .timeline - .pdus(Some(sender_user), room_id, Some(base_count)) + .pdus(room_id, Some(base_count)) .ignore_err() + .then(async |mut pdu| { + pdu.1.set_unsigned(Some(sender_user)); + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + pdu + }) .ready_filter_map(|item| event_filter(item, filter)) .wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user)) @@ -109,7 +137,7 @@ pub(crate) async fn get_context_route( let lazy_loading_context = lazy_loading::Context { user_id: sender_user, - device_id: Some(sender_device), + device_id: sender_device, room_id, token: Some(base_count.into_unsigned()), options: Some(&filter.lazy_load_options), @@ -177,12 +205,12 @@ pub(crate) async fn get_context_route( .broad_filter_map(|event_id: &OwnedEventId| { services.rooms.timeline.get_pdu(event_id.as_ref()).ok() }) - .map(Event::into_format) + .map(PduEvent::into_state_event) .collect() .await; Ok(get_context::v3::Response { - event: base_event.map(at!(1)).map(Event::into_format), + event: base_event.map(at!(1)).map(PduEvent::into_room_event), start: events_before .last() @@ -201,13 +229,13 @@ pub(crate) async fn get_context_route( events_before: events_before .into_iter() .map(at!(1)) - .map(Event::into_format) + .map(PduEvent::into_room_event) .collect(), events_after: events_after .into_iter() .map(at!(1)) - .map(Event::into_format) + .map(PduEvent::into_room_event) .collect(), state, diff --git a/src/api/client/device.rs b/src/api/client/device.rs index b0a7e142..5519a1a5 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -21,9 +21,11 @@ pub(crate) async fn get_devices_route( State(services): State, body: Ruma, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let devices: Vec = services .users - .all_devices_metadata(body.sender_user()) + .all_devices_metadata(sender_user) .collect() .await; @@ -37,9 +39,11 @@ pub(crate) async fn get_device_route( State(services): State, body: Ruma, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let device = services .users - .get_device_metadata(body.sender_user(), &body.body.device_id) + .get_device_metadata(sender_user, &body.body.device_id) .await .map_err(|_| err!(Request(NotFound("Device not found."))))?; diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 00879274..2e219fd9 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,7 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduwuit::{ - Err, Event, Result, err, info, + Err, Result, err, info, utils::{ TryFutureExtExt, math::Expected, @@ -352,7 +352,7 @@ async fn user_can_publish_room( .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") .await { - | Ok(event) => serde_json::from_str(event.content().get()) + | Ok(event) => serde_json::from_str(event.content.get()) .map_err(|_| err!(Database("Invalid event content for m.room.power_levels"))) .map(|content: RoomPowerLevelsEventContent| { RoomPowerLevels::from(content) @@ -365,7 +365,7 @@ async fn user_can_publish_room( .room_state_get(room_id, &StateEventType::RoomCreate, "") .await { - | Ok(event) => Ok(event.sender() == user_id), + | Ok(event) => Ok(event.sender == user_id), | _ => Err!(Request(Forbidden("User is not allowed to publish this room"))), } }, diff --git a/src/api/client/filter.rs b/src/api/client/filter.rs index 9814d366..97044ffc 100644 --- a/src/api/client/filter.rs +++ b/src/api/client/filter.rs @@ -13,9 +13,11 @@ pub(crate) async fn get_filter_route( State(services): State, body: Ruma, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + services .users - .get_filter(body.sender_user(), &body.filter_id) + .get_filter(sender_user, &body.filter_id) .await .map(get_filter::v3::Response::new) .map_err(|_| err!(Request(NotFound("Filter not found.")))) @@ -28,9 +30,9 @@ pub(crate) async fn create_filter_route( State(services): State, body: Ruma, ) -> Result { - let filter_id = services - .users - .create_filter(body.sender_user(), &body.filter); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let filter_id = services.users.create_filter(sender_user, &body.filter); Ok(create_filter::v3::Response::new(filter_id)) } diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index d2bd46a0..650c573f 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -126,7 +126,7 @@ pub(crate) async fn get_keys_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); get_keys_helper( &services, @@ -157,7 +157,8 @@ pub(crate) async fn upload_signing_keys_route( State(services): State, body: Ruma, ) -> Result { - let (sender_user, sender_device) = body.sender(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); // UIAA let mut uiaainfo = UiaaInfo { @@ -202,12 +203,12 @@ pub(crate) async fn upload_signing_keys_route( } // Success! }, - | _ => match body.json_body.as_ref() { + | _ => match body.json_body { | Some(json) => { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, json); + .create(sender_user, sender_device, &uiaainfo, &json); return Err(Error::Uiaa(uiaainfo)); }, @@ -372,7 +373,7 @@ pub(crate) async fn get_key_changes_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut device_list_updates = HashSet::new(); diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 3f491d54..11d5450c 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -51,7 +51,7 @@ pub(crate) async fn create_content_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let user = body.sender_user(); + let user = body.sender_user.as_ref().expect("user is authenticated"); if services.users.is_suspended(user).await? { return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); } @@ -97,7 +97,7 @@ pub(crate) async fn get_content_thumbnail_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let user = body.sender_user(); + let user = body.sender_user.as_ref().expect("user is authenticated"); let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?; let mxc = Mxc { @@ -134,7 +134,7 @@ pub(crate) async fn get_content_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let user = body.sender_user(); + let user = body.sender_user.as_ref().expect("user is authenticated"); let mxc = Mxc { server_name: &body.server_name, @@ -170,7 +170,7 @@ pub(crate) async fn get_content_as_filename_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let user = body.sender_user(); + let user = body.sender_user.as_ref().expect("user is authenticated"); let mxc = Mxc { server_name: &body.server_name, @@ -206,7 +206,7 @@ pub(crate) async fn get_media_preview_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; let url = Url::parse(&body.url).map_err(|e| { diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index 930daab4..d9f24f77 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -55,7 +55,7 @@ pub(crate) async fn get_media_preview_legacy_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; let url = Url::parse(&body.url).map_err(|e| { diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 28768fee..300b6b93 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -114,7 +114,9 @@ async fn banned_room_check( .collect() .await; - full_user_deactivate(services, user_id, &all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms) + .boxed() + .await?; } return Err!(Request(Forbidden("This room is banned on this homeserver."))); @@ -153,7 +155,9 @@ async fn banned_room_check( .collect() .await; - full_user_deactivate(services, user_id, &all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms) + .boxed() + .await?; } return Err!(Request(Forbidden("This remote server is banned on this homeserver."))); @@ -265,6 +269,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( room_id.server_name(), client, ) + .boxed() .await?; let mut servers = body.via.clone(); @@ -487,6 +492,7 @@ pub(crate) async fn leave_room_route( body: Ruma, ) -> Result { leave_room(&services, body.sender_user(), &body.room_id, body.reason.clone()) + .boxed() .await .map(|()| leave_room::v3::Response::new()) } @@ -1826,7 +1832,10 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { for room_id in all_rooms { // ignore errors - if let Err(e) = leave_room(services, user_id, &room_id, None).await { + if let Err(e) = leave_room(services, user_id, &room_id, None) + .boxed() + .await + { warn!(%user_id, "Failed to leave {room_id} remotely: {e}"); } diff --git a/src/api/client/membership/ban.rs b/src/api/client/membership/ban.rs deleted file mode 100644 index 339dcf2e..00000000 --- a/src/api/client/membership/ban.rs +++ /dev/null @@ -1,60 +0,0 @@ -use axum::extract::State; -use conduwuit::{Err, Result, matrix::pdu::PduBuilder}; -use ruma::{ - api::client::membership::ban_user, - events::room::member::{MembershipState, RoomMemberEventContent}, -}; - -use crate::Ruma; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/ban` -/// -/// Tries to send a ban event into the room. -pub(crate) async fn ban_user_route( - State(services): State, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - - if sender_user == body.user_id { - return Err!(Request(Forbidden("You cannot ban yourself."))); - } - - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - - let current_member_content = services - .rooms - .state_accessor - .get_member(&body.room_id, &body.user_id) - .await - .unwrap_or_else(|_| RoomMemberEventContent::new(MembershipState::Ban)); - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent { - membership: MembershipState::Ban, - reason: body.reason.clone(), - displayname: None, // display name may be offensive - avatar_url: None, // avatar may be offensive - is_direct: None, - join_authorized_via_users_server: None, - third_party_invite: None, - redact_events: body.redact_events, - ..current_member_content - }), - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - drop(state_lock); - - Ok(ban_user::v3::Response::new()) -} diff --git a/src/api/client/membership/forget.rs b/src/api/client/membership/forget.rs deleted file mode 100644 index 7f3a1a57..00000000 --- a/src/api/client/membership/forget.rs +++ /dev/null @@ -1,52 +0,0 @@ -use axum::extract::State; -use conduwuit::{Err, Result, is_matching, result::NotFound, utils::FutureBoolExt}; -use futures::pin_mut; -use ruma::{api::client::membership::forget_room, events::room::member::MembershipState}; - -use crate::Ruma; - -/// # `POST /_matrix/client/v3/rooms/{roomId}/forget` -/// -/// Forgets about a room. -/// -/// - If the sender user currently left the room: Stops sender user from -/// receiving information about the room -/// -/// Note: Other devices of the user have no way of knowing the room was -/// forgotten, so this has to be called from every device -pub(crate) async fn forget_room_route( - State(services): State, - body: Ruma, -) -> Result { - let user_id = body.sender_user(); - let room_id = &body.room_id; - - let joined = services.rooms.state_cache.is_joined(user_id, room_id); - let knocked = services.rooms.state_cache.is_knocked(user_id, room_id); - let invited = services.rooms.state_cache.is_invited(user_id, room_id); - - pin_mut!(joined, knocked, invited); - if joined.or(knocked).or(invited).await { - return Err!(Request(Unknown("You must leave the room before forgetting it"))); - } - - let membership = services - .rooms - .state_accessor - .get_member(room_id, user_id) - .await; - - if membership.is_not_found() { - return Err!(Request(Unknown("No membership event was found, room was never joined"))); - } - - let non_membership = membership - .map(|member| member.membership) - .is_ok_and(is_matching!(MembershipState::Leave | MembershipState::Ban)); - - if non_membership || services.rooms.state_cache.is_left(user_id, room_id).await { - services.rooms.state_cache.forget(room_id, user_id); - } - - Ok(forget_room::v3::Response::new()) -} diff --git a/src/api/client/membership/invite.rs b/src/api/client/membership/invite.rs deleted file mode 100644 index 018fb774..00000000 --- a/src/api/client/membership/invite.rs +++ /dev/null @@ -1,238 +0,0 @@ -use axum::extract::State; -use axum_client_ip::InsecureClientIp; -use conduwuit::{ - Err, Result, debug_error, err, info, - matrix::{event::gen_event_id_canonical_json, pdu::PduBuilder}, -}; -use futures::{FutureExt, join}; -use ruma::{ - OwnedServerName, RoomId, UserId, - api::{client::membership::invite_user, federation::membership::create_invite}, - events::room::member::{MembershipState, RoomMemberEventContent}, -}; -use service::Services; - -use super::banned_room_check; -use crate::Ruma; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/invite` -/// -/// Tries to send an invite event into the room. -#[tracing::instrument(skip_all, fields(%client), name = "invite")] -pub(crate) async fn invite_user_route( - State(services): State, - InsecureClientIp(client): InsecureClientIp, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - - if !services.users.is_admin(sender_user).await && services.config.block_non_admin_invites { - debug_error!( - "User {sender_user} is not an admin and attempted to send an invite to room {}", - &body.room_id - ); - return Err!(Request(Forbidden("Invites are not allowed on this server."))); - } - - banned_room_check( - &services, - sender_user, - Some(&body.room_id), - body.room_id.server_name(), - client, - ) - .await?; - - match &body.recipient { - | invite_user::v3::InvitationRecipient::UserId { user_id } => { - let sender_ignored_recipient = services.users.user_is_ignored(sender_user, user_id); - let recipient_ignored_by_sender = - services.users.user_is_ignored(user_id, sender_user); - - let (sender_ignored_recipient, recipient_ignored_by_sender) = - join!(sender_ignored_recipient, recipient_ignored_by_sender); - - if sender_ignored_recipient { - return Ok(invite_user::v3::Response {}); - } - - if let Ok(target_user_membership) = services - .rooms - .state_accessor - .get_member(&body.room_id, user_id) - .await - { - if target_user_membership.membership == MembershipState::Ban { - return Err!(Request(Forbidden("User is banned from this room."))); - } - } - - if recipient_ignored_by_sender { - // silently drop the invite to the recipient if they've been ignored by the - // sender, pretend it worked - return Ok(invite_user::v3::Response {}); - } - - invite_helper( - &services, - sender_user, - user_id, - &body.room_id, - body.reason.clone(), - false, - ) - .boxed() - .await?; - - Ok(invite_user::v3::Response {}) - }, - | _ => { - Err!(Request(NotFound("User not found."))) - }, - } -} - -pub(crate) async fn invite_helper( - services: &Services, - sender_user: &UserId, - user_id: &UserId, - room_id: &RoomId, - reason: Option, - is_direct: bool, -) -> Result { - if !services.users.is_admin(sender_user).await && services.config.block_non_admin_invites { - info!( - "User {sender_user} is not an admin and attempted to send an invite to room \ - {room_id}" - ); - return Err!(Request(Forbidden("Invites are not allowed on this server."))); - } - - if !services.globals.user_is_local(user_id) { - let (pdu, pdu_json, invite_room_state) = { - let state_lock = services.rooms.state.mutex.lock(room_id).await; - - let content = RoomMemberEventContent { - avatar_url: services.users.avatar_url(user_id).await.ok(), - is_direct: Some(is_direct), - reason, - ..RoomMemberEventContent::new(MembershipState::Invite) - }; - - let (pdu, pdu_json) = services - .rooms - .timeline - .create_hash_and_sign_event( - PduBuilder::state(user_id.to_string(), &content), - sender_user, - room_id, - &state_lock, - ) - .await?; - - let invite_room_state = services.rooms.state.summary_stripped(&pdu).await; - - drop(state_lock); - - (pdu, pdu_json, invite_room_state) - }; - - let room_version_id = services.rooms.state.get_room_version(room_id).await?; - - let response = services - .sending - .send_federation_request(user_id.server_name(), create_invite::v2::Request { - room_id: room_id.to_owned(), - event_id: (*pdu.event_id).to_owned(), - room_version: room_version_id.clone(), - event: services - .sending - .convert_to_outgoing_federation_event(pdu_json.clone()) - .await, - invite_room_state, - via: services - .rooms - .state_cache - .servers_route_via(room_id) - .await - .ok(), - }) - .await?; - - // We do not add the event_id field to the pdu here because of signature and - // hashes checks - let (event_id, value) = gen_event_id_canonical_json(&response.event, &room_version_id) - .map_err(|e| { - err!(Request(BadJson(warn!("Could not convert event to canonical JSON: {e}")))) - })?; - - if pdu.event_id != event_id { - return Err!(Request(BadJson(warn!( - %pdu.event_id, %event_id, - "Server {} sent event with wrong event ID", - user_id.server_name() - )))); - } - - let origin: OwnedServerName = serde_json::from_value(serde_json::to_value( - value - .get("origin") - .ok_or_else(|| err!(Request(BadJson("Event missing origin field."))))?, - )?) - .map_err(|e| { - err!(Request(BadJson(warn!("Origin field in event is not a valid server name: {e}")))) - })?; - - let pdu_id = services - .rooms - .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value, true) - .boxed() - .await? - .ok_or_else(|| { - err!(Request(InvalidParam("Could not accept incoming PDU as timeline event."))) - })?; - - return services.sending.send_pdu_room(room_id, &pdu_id).await; - } - - if !services - .rooms - .state_cache - .is_joined(sender_user, room_id) - .await - { - return Err!(Request(Forbidden( - "You must be joined in the room you are trying to invite from." - ))); - } - - let state_lock = services.rooms.state.mutex.lock(room_id).await; - - let content = RoomMemberEventContent { - displayname: services.users.displayname(user_id).await.ok(), - avatar_url: services.users.avatar_url(user_id).await.ok(), - blurhash: services.users.blurhash(user_id).await.ok(), - is_direct: Some(is_direct), - reason, - ..RoomMemberEventContent::new(MembershipState::Invite) - }; - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(user_id.to_string(), &content), - sender_user, - room_id, - &state_lock, - ) - .await?; - - drop(state_lock); - - Ok(()) -} diff --git a/src/api/client/membership/join.rs b/src/api/client/membership/join.rs deleted file mode 100644 index dc170cbf..00000000 --- a/src/api/client/membership/join.rs +++ /dev/null @@ -1,989 +0,0 @@ -use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc}; - -use axum::extract::State; -use axum_client_ip::InsecureClientIp; -use conduwuit::{ - Err, Result, debug, debug_info, debug_warn, err, error, info, - matrix::{ - StateKey, - event::{gen_event_id, gen_event_id_canonical_json}, - pdu::{PduBuilder, PduEvent}, - state_res, - }, - result::FlatOk, - trace, - utils::{ - self, shuffle, - stream::{IterStream, ReadyExt}, - }, - warn, -}; -use futures::{FutureExt, StreamExt}; -use ruma::{ - CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, - RoomVersionId, UserId, - api::{ - client::{ - error::ErrorKind, - membership::{ThirdPartySigned, join_room_by_id, join_room_by_id_or_alias}, - }, - federation::{self}, - }, - canonical_json::to_canonical_value, - events::{ - StateEventType, - room::{ - join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - }, - }, -}; -use service::{ - Services, - appservice::RegistrationInfo, - rooms::{ - state::RoomMutexGuard, - state_compressor::{CompressedState, HashSetCompressStateEvent}, - }, -}; - -use super::banned_room_check; -use crate::Ruma; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/join` -/// -/// Tries to join the sender user into a room. -/// -/// - If the server knowns about this room: creates the join event and does auth -/// rules locally -/// - If the server does not know about the room: asks other servers over -/// federation -#[tracing::instrument(skip_all, fields(%client), name = "join")] -pub(crate) async fn join_room_by_id_route( - State(services): State, - InsecureClientIp(client): InsecureClientIp, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - - banned_room_check( - &services, - sender_user, - Some(&body.room_id), - body.room_id.server_name(), - client, - ) - .await?; - - // There is no body.server_name for /roomId/join - let mut servers: Vec<_> = services - .rooms - .state_cache - .servers_invite_via(&body.room_id) - .map(ToOwned::to_owned) - .collect() - .await; - - servers.extend( - services - .rooms - .state_cache - .invite_state(sender_user, &body.room_id) - .await - .unwrap_or_default() - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - - if let Some(server) = body.room_id.server_name() { - servers.push(server.into()); - } - - servers.sort_unstable(); - servers.dedup(); - shuffle(&mut servers); - - join_room_by_id_helper( - &services, - sender_user, - &body.room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - &body.appservice_info, - ) - .boxed() - .await -} - -/// # `POST /_matrix/client/r0/join/{roomIdOrAlias}` -/// -/// Tries to join the sender user into a room. -/// -/// - If the server knowns about this room: creates the join event and does auth -/// rules locally -/// - If the server does not know about the room: use the server name query -/// param if specified. if not specified, asks other servers over federation -/// via room alias server name and room ID server name -#[tracing::instrument(skip_all, fields(%client), name = "join")] -pub(crate) async fn join_room_by_id_or_alias_route( - State(services): State, - InsecureClientIp(client): InsecureClientIp, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - let appservice_info = &body.appservice_info; - let body = &body.body; - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - - let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias.clone()) { - | Ok(room_id) => { - banned_room_check( - &services, - sender_user, - Some(&room_id), - room_id.server_name(), - client, - ) - .boxed() - .await?; - - let mut servers = body.via.clone(); - servers.extend( - services - .rooms - .state_cache - .servers_invite_via(&room_id) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - - servers.extend( - services - .rooms - .state_cache - .invite_state(sender_user, &room_id) - .await - .unwrap_or_default() - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - - if let Some(server) = room_id.server_name() { - servers.push(server.to_owned()); - } - - servers.sort_unstable(); - servers.dedup(); - shuffle(&mut servers); - - (servers, room_id) - }, - | Err(room_alias) => { - let (room_id, mut servers) = services - .rooms - .alias - .resolve_alias(&room_alias, Some(body.via.clone())) - .await?; - - banned_room_check( - &services, - sender_user, - Some(&room_id), - Some(room_alias.server_name()), - client, - ) - .await?; - - let addl_via_servers = services - .rooms - .state_cache - .servers_invite_via(&room_id) - .map(ToOwned::to_owned); - - let addl_state_servers = services - .rooms - .state_cache - .invite_state(sender_user, &room_id) - .await - .unwrap_or_default(); - - let mut addl_servers: Vec<_> = addl_state_servers - .iter() - .map(|event| event.get_field("sender")) - .filter_map(FlatOk::flat_ok) - .map(|user: &UserId| user.server_name().to_owned()) - .stream() - .chain(addl_via_servers) - .collect() - .await; - - addl_servers.sort_unstable(); - addl_servers.dedup(); - shuffle(&mut addl_servers); - servers.append(&mut addl_servers); - - (servers, room_id) - }, - }; - - let join_room_response = join_room_by_id_helper( - &services, - sender_user, - &room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - appservice_info, - ) - .boxed() - .await?; - - Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id }) -} - -pub async fn join_room_by_id_helper( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - third_party_signed: Option<&ThirdPartySigned>, - appservice_info: &Option, -) -> Result { - let state_lock = services.rooms.state.mutex.lock(room_id).await; - - let user_is_guest = services - .users - .is_deactivated(sender_user) - .await - .unwrap_or(false) - && appservice_info.is_none(); - - if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await { - return Err!(Request(Forbidden("Guests are not allowed to join this room"))); - } - - if services - .rooms - .state_cache - .is_joined(sender_user, room_id) - .await - { - debug_warn!("{sender_user} is already joined in {room_id}"); - return Ok(join_room_by_id::v3::Response { room_id: room_id.into() }); - } - - let server_in_room = services - .rooms - .state_cache - .server_in_room(services.globals.server_name(), room_id) - .await; - - // Only check our known membership if we're already in the room. - // See: https://forgejo.ellis.link/continuwuation/continuwuity/issues/855 - let membership = if server_in_room { - services - .rooms - .state_accessor - .get_member(room_id, sender_user) - .await - } else { - debug!("Ignoring local state for join {room_id}, we aren't in the room yet."); - Ok(RoomMemberEventContent::new(MembershipState::Leave)) - }; - if let Ok(m) = membership { - if m.membership == MembershipState::Ban { - debug_warn!("{sender_user} is banned from {room_id} but attempted to join"); - // TODO: return reason - return Err!(Request(Forbidden("You are banned from the room."))); - } - } - - let local_join = server_in_room - || servers.is_empty() - || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); - - if local_join { - join_room_by_id_helper_local( - services, - sender_user, - room_id, - reason, - servers, - third_party_signed, - state_lock, - ) - .boxed() - .await?; - } else { - // Ask a remote server if we are not participating in this room - join_room_by_id_helper_remote( - services, - sender_user, - room_id, - reason, - servers, - third_party_signed, - state_lock, - ) - .boxed() - .await?; - } - - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) -} - -#[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")] -async fn join_room_by_id_helper_remote( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, - state_lock: RoomMutexGuard, -) -> Result { - info!("Joining {room_id} over federation."); - - let (make_join_response, remote_server) = - make_join_request(services, sender_user, room_id, servers).await?; - - info!("make_join finished"); - - let Some(room_version_id) = make_join_response.room_version else { - return Err!(BadServerResponse("Remote room version is not supported by conduwuit")); - }; - - if !services.server.supported_room_version(&room_version_id) { - return Err!(BadServerResponse( - "Remote room version {room_version_id} is not supported by conduwuit" - )); - } - - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|e| { - err!(BadServerResponse(warn!( - "Invalid make_join event json received from server: {e:?}" - ))) - })?; - - let join_authorized_via_users_server = { - use RoomVersionId::*; - if !matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) { - join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()) - } else { - None - } - }; - - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - blurhash: services.users.blurhash(sender_user).await.ok(), - reason, - join_authorized_via_users_server: join_authorized_via_users_server.clone(), - ..RoomMemberEventContent::new(MembershipState::Join) - }) - .expect("event is valid, we just created it"), - ); - - // We keep the "event_id" in the pdu only in v1 or - // v2 rooms - match room_version_id { - | RoomVersionId::V1 | RoomVersionId::V2 => {}, - | _ => { - join_event_stub.remove("event_id"); - }, - } - - // In order to create a compatible ref hash (EventID) the `hashes` field needs - // to be present - services - .server_keys - .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; - - // Generate event id - let event_id = gen_event_id(&join_event_stub, &room_version_id)?; - - // Add event_id back - join_event_stub - .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); - - // It has enough fields to be called a proper event now - let mut join_event = join_event_stub; - - info!("Asking {remote_server} for send_join in room {room_id}"); - let send_join_request = federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.clone(), - omit_members: false, - pdu: services - .sending - .convert_to_outgoing_federation_event(join_event.clone()) - .await, - }; - - let send_join_response = match services - .sending - .send_synapse_request(&remote_server, send_join_request) - .await - { - | Ok(response) => response, - | Err(e) => { - error!("send_join failed: {e}"); - return Err(e); - }, - }; - - info!("send_join finished"); - - if join_authorized_via_users_server.is_some() { - if let Some(signed_raw) = &send_join_response.room_state.event { - debug_info!( - "There is a signed event with join_authorized_via_users_server. This room is \ - probably using restricted joins. Adding signature to our event" - ); - - let (signed_event_id, signed_value) = - gen_event_id_canonical_json(signed_raw, &room_version_id).map_err(|e| { - err!(Request(BadJson(warn!( - "Could not convert event to canonical JSON: {e}" - )))) - })?; - - if signed_event_id != event_id { - return Err!(Request(BadJson(warn!( - %signed_event_id, %event_id, - "Server {remote_server} sent event with wrong event ID" - )))); - } - - match signed_value["signatures"] - .as_object() - .ok_or_else(|| { - err!(BadServerResponse(warn!( - "Server {remote_server} sent invalid signatures type" - ))) - }) - .and_then(|e| { - e.get(remote_server.as_str()).ok_or_else(|| { - err!(BadServerResponse(warn!( - "Server {remote_server} did not send its signature for a restricted \ - room" - ))) - }) - }) { - | Ok(signature) => { - join_event - .get_mut("signatures") - .expect("we created a valid pdu") - .as_object_mut() - .expect("we created a valid pdu") - .insert(remote_server.to_string(), signature.clone()); - }, - | Err(e) => { - warn!( - "Server {remote_server} sent invalid signature in send_join signatures \ - for event {signed_value:?}: {e:?}", - ); - }, - } - } - } - - services - .rooms - .short - .get_or_create_shortroomid(room_id) - .await; - - info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) - .map_err(|e| err!(BadServerResponse("Invalid join event PDU: {e:?}")))?; - - info!("Acquiring server signing keys for response events"); - let resp_events = &send_join_response.room_state; - let resp_state = &resp_events.state; - let resp_auth = &resp_events.auth_chain; - services - .server_keys - .acquire_events_pubkeys(resp_auth.iter().chain(resp_state.iter())) - .await; - - info!("Going through send_join response room_state"); - let cork = services.db.cork_and_flush(); - let state = send_join_response - .room_state - .state - .iter() - .stream() - .then(|pdu| { - services - .server_keys - .validate_and_add_event_id_no_fetch(pdu, &room_version_id) - }) - .ready_filter_map(Result::ok) - .fold(HashMap::new(), |mut state, (event_id, value)| async move { - let pdu = match PduEvent::from_id_val(&event_id, value.clone()) { - | Ok(pdu) => pdu, - | Err(e) => { - debug_warn!("Invalid PDU in send_join response: {e:?}: {value:#?}"); - return state; - }, - }; - - services.rooms.outlier.add_pdu_outlier(&event_id, &value); - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) - .await; - - state.insert(shortstatekey, pdu.event_id.clone()); - } - - state - }) - .await; - - drop(cork); - - info!("Going through send_join response auth_chain"); - let cork = services.db.cork_and_flush(); - send_join_response - .room_state - .auth_chain - .iter() - .stream() - .then(|pdu| { - services - .server_keys - .validate_and_add_event_id_no_fetch(pdu, &room_version_id) - }) - .ready_filter_map(Result::ok) - .ready_for_each(|(event_id, value)| { - services.rooms.outlier.add_pdu_outlier(&event_id, &value); - }) - .await; - - drop(cork); - - debug!("Running send_join auth check"); - let fetch_state = &state; - let state_fetch = |k: StateEventType, s: StateKey| async move { - let shortstatekey = services.rooms.short.get_shortstatekey(&k, &s).await.ok()?; - - let event_id = fetch_state.get(&shortstatekey)?; - services.rooms.timeline.get_pdu(event_id).await.ok() - }; - - let auth_check = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id)?, - &parsed_join_pdu, - None, // TODO: third party invite - |k, s| state_fetch(k.clone(), s.into()), - ) - .await - .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; - - if !auth_check { - return Err!(Request(Forbidden("Auth check failed"))); - } - - info!("Compressing state from send_join"); - let compressed: CompressedState = services - .rooms - .state_compressor - .compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) - .collect() - .await; - - debug!("Saving compressed state"); - let HashSetCompressStateEvent { - shortstatehash: statehash_before_join, - added, - removed, - } = services - .rooms - .state_compressor - .save_state(room_id, Arc::new(compressed)) - .await?; - - debug!("Forcing state for new room"); - services - .rooms - .state - .force_state(room_id, statehash_before_join, added, removed, &state_lock) - .await?; - - info!("Updating joined counts for new room"); - services - .rooms - .state_cache - .update_joined_count(room_id) - .await; - - // We append to state before appending the pdu, so we don't have a moment in - // time with the pdu without it's state. This is okay because append_pdu can't - // fail. - let statehash_after_join = services - .rooms - .state - .append_to_state(&parsed_join_pdu) - .await?; - - info!("Appending new room join event"); - services - .rooms - .timeline - .append_pdu( - &parsed_join_pdu, - join_event, - once(parsed_join_pdu.event_id.borrow()), - &state_lock, - ) - .await?; - - info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment - // in time where events in the current room state do not exist - services - .rooms - .state - .set_room_state(room_id, statehash_after_join, &state_lock); - - Ok(()) -} - -#[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")] -async fn join_room_by_id_helper_local( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, - state_lock: RoomMutexGuard, -) -> Result { - debug_info!("We can join locally"); - - let join_rules_event_content = services - .rooms - .state_accessor - .room_state_get_content::( - room_id, - &StateEventType::RoomJoinRules, - "", - ) - .await; - - let restriction_rooms = match join_rules_event_content { - | Ok(RoomJoinRulesEventContent { - join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), - }) => restricted - .allow - .into_iter() - .filter_map(|a| match a { - | AllowRule::RoomMembership(r) => Some(r.room_id), - | _ => None, - }) - .collect(), - | _ => Vec::new(), - }; - - let join_authorized_via_users_server: Option = { - if restriction_rooms - .iter() - .stream() - .any(|restriction_room_id| { - services - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - }) - .await - { - services - .rooms - .state_cache - .local_users_in_room(room_id) - .filter(|user| { - services.rooms.state_accessor.user_can_invite( - room_id, - user, - sender_user, - &state_lock, - ) - }) - .boxed() - .next() - .await - .map(ToOwned::to_owned) - } else { - None - } - }; - - let content = RoomMemberEventContent { - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - blurhash: services.users.blurhash(sender_user).await.ok(), - reason: reason.clone(), - join_authorized_via_users_server, - ..RoomMemberEventContent::new(MembershipState::Join) - }; - - // Try normal join first - let Err(error) = services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(sender_user.to_string(), &content), - sender_user, - room_id, - &state_lock, - ) - .await - else { - return Ok(()); - }; - - if restriction_rooms.is_empty() - && (servers.is_empty() - || servers.len() == 1 && services.globals.server_is_ours(&servers[0])) - { - return Err(error); - } - - warn!( - "We couldn't do the join locally, maybe federation can help to satisfy the restricted \ - join requirements" - ); - let Ok((make_join_response, remote_server)) = - make_join_request(services, sender_user, room_id, servers).await - else { - return Err(error); - }; - - let Some(room_version_id) = make_join_response.room_version else { - return Err!(BadServerResponse("Remote room version is not supported by conduwuit")); - }; - - if !services.server.supported_room_version(&room_version_id) { - return Err!(BadServerResponse( - "Remote room version {room_version_id} is not supported by conduwuit" - )); - } - - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|e| { - err!(BadServerResponse("Invalid make_join event json received from server: {e:?}")) - })?; - - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - blurhash: services.users.blurhash(sender_user).await.ok(), - reason, - join_authorized_via_users_server, - ..RoomMemberEventContent::new(MembershipState::Join) - }) - .expect("event is valid, we just created it"), - ); - - // We keep the "event_id" in the pdu only in v1 or - // v2 rooms - match room_version_id { - | RoomVersionId::V1 | RoomVersionId::V2 => {}, - | _ => { - join_event_stub.remove("event_id"); - }, - } - - // In order to create a compatible ref hash (EventID) the `hashes` field needs - // to be present - services - .server_keys - .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; - - // Generate event id - let event_id = gen_event_id(&join_event_stub, &room_version_id)?; - - // Add event_id back - join_event_stub - .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); - - // It has enough fields to be called a proper event now - let join_event = join_event_stub; - - let send_join_response = services - .sending - .send_synapse_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.clone(), - omit_members: false, - pdu: services - .sending - .convert_to_outgoing_federation_event(join_event.clone()) - .await, - }, - ) - .await?; - - if let Some(signed_raw) = send_join_response.room_state.event { - let (signed_event_id, signed_value) = - gen_event_id_canonical_json(&signed_raw, &room_version_id).map_err(|e| { - err!(Request(BadJson(warn!("Could not convert event to canonical JSON: {e}")))) - })?; - - if signed_event_id != event_id { - return Err!(Request(BadJson( - warn!(%signed_event_id, %event_id, "Server {remote_server} sent event with wrong event ID") - ))); - } - - drop(state_lock); - services - .rooms - .event_handler - .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true) - .boxed() - .await?; - } else { - return Err(error); - } - - Ok(()) -} - -async fn make_join_request( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - servers: &[OwnedServerName], -) -> Result<(federation::membership::prepare_join_event::v1::Response, OwnedServerName)> { - let mut make_join_response_and_server = - Err!(BadServerResponse("No server available to assist in joining.")); - - let mut make_join_counter: usize = 0; - let mut incompatible_room_version_count: usize = 0; - - for remote_server in servers { - if services.globals.server_is_ours(remote_server) { - continue; - } - info!("Asking {remote_server} for make_join ({make_join_counter})"); - let make_join_response = services - .sending - .send_federation_request( - remote_server, - federation::membership::prepare_join_event::v1::Request { - room_id: room_id.to_owned(), - user_id: sender_user.to_owned(), - ver: services.server.supported_room_versions().collect(), - }, - ) - .await; - - trace!("make_join response: {:?}", make_join_response); - make_join_counter = make_join_counter.saturating_add(1); - - if let Err(ref e) = make_join_response { - if matches!( - e.kind(), - ErrorKind::IncompatibleRoomVersion { .. } | ErrorKind::UnsupportedRoomVersion - ) { - incompatible_room_version_count = - incompatible_room_version_count.saturating_add(1); - } - - if incompatible_room_version_count > 15 { - info!( - "15 servers have responded with M_INCOMPATIBLE_ROOM_VERSION or \ - M_UNSUPPORTED_ROOM_VERSION, assuming that conduwuit does not support the \ - room version {room_id}: {e}" - ); - make_join_response_and_server = - Err!(BadServerResponse("Room version is not supported by Conduwuit")); - return make_join_response_and_server; - } - - if make_join_counter > 40 { - warn!( - "40 servers failed to provide valid make_join response, assuming no server \ - can assist in joining." - ); - make_join_response_and_server = - Err!(BadServerResponse("No server available to assist in joining.")); - - return make_join_response_and_server; - } - } - - make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); - - if make_join_response_and_server.is_ok() { - break; - } - } - - make_join_response_and_server -} diff --git a/src/api/client/membership/kick.rs b/src/api/client/membership/kick.rs deleted file mode 100644 index 5e0e86e2..00000000 --- a/src/api/client/membership/kick.rs +++ /dev/null @@ -1,65 +0,0 @@ -use axum::extract::State; -use conduwuit::{Err, Result, matrix::pdu::PduBuilder}; -use ruma::{ - api::client::membership::kick_user, - events::room::member::{MembershipState, RoomMemberEventContent}, -}; - -use crate::Ruma; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/kick` -/// -/// Tries to send a kick event into the room. -pub(crate) async fn kick_user_route( - State(services): State, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - - let Ok(event) = services - .rooms - .state_accessor - .get_member(&body.room_id, &body.user_id) - .await - else { - // copy synapse's behaviour of returning 200 without any change to the state - // instead of erroring on left users - return Ok(kick_user::v3::Response::new()); - }; - - if !matches!( - event.membership, - MembershipState::Invite | MembershipState::Knock | MembershipState::Join, - ) { - return Err!(Request(Forbidden( - "Cannot kick a user who is not apart of the room (current membership: {})", - event.membership - ))); - } - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent { - membership: MembershipState::Leave, - reason: body.reason.clone(), - is_direct: None, - join_authorized_via_users_server: None, - third_party_invite: None, - ..event - }), - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - drop(state_lock); - - Ok(kick_user::v3::Response::new()) -} diff --git a/src/api/client/membership/knock.rs b/src/api/client/membership/knock.rs deleted file mode 100644 index 79f16631..00000000 --- a/src/api/client/membership/knock.rs +++ /dev/null @@ -1,770 +0,0 @@ -use std::{borrow::Borrow, collections::HashMap, iter::once, sync::Arc}; - -use axum::extract::State; -use axum_client_ip::InsecureClientIp; -use conduwuit::{ - Err, Result, debug, debug_info, debug_warn, err, info, - matrix::{ - event::{Event, gen_event_id}, - pdu::{PduBuilder, PduEvent}, - }, - result::FlatOk, - trace, - utils::{self, shuffle, stream::IterStream}, - warn, -}; -use futures::{FutureExt, StreamExt}; -use ruma::{ - CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, - RoomVersionId, UserId, - api::{ - client::knock::knock_room, - federation::{self}, - }, - canonical_json::to_canonical_value, - events::{ - StateEventType, - room::{ - join_rules::{AllowRule, JoinRule}, - member::{MembershipState, RoomMemberEventContent}, - }, - }, -}; -use service::{ - Services, - rooms::{ - state::RoomMutexGuard, - state_compressor::{CompressedState, HashSetCompressStateEvent}, - }, -}; - -use super::{banned_room_check, join::join_room_by_id_helper}; -use crate::Ruma; - -/// # `POST /_matrix/client/*/knock/{roomIdOrAlias}` -/// -/// Tries to knock the room to ask permission to join for the sender user. -#[tracing::instrument(skip_all, fields(%client), name = "knock")] -pub(crate) async fn knock_room_route( - State(services): State, - InsecureClientIp(client): InsecureClientIp, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - let body = &body.body; - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - - let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias.clone()) { - | Ok(room_id) => { - banned_room_check( - &services, - sender_user, - Some(&room_id), - room_id.server_name(), - client, - ) - .await?; - - let mut servers = body.via.clone(); - servers.extend( - services - .rooms - .state_cache - .servers_invite_via(&room_id) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - - servers.extend( - services - .rooms - .state_cache - .invite_state(sender_user, &room_id) - .await - .unwrap_or_default() - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - - if let Some(server) = room_id.server_name() { - servers.push(server.to_owned()); - } - - servers.sort_unstable(); - servers.dedup(); - shuffle(&mut servers); - - (servers, room_id) - }, - | Err(room_alias) => { - let (room_id, mut servers) = services - .rooms - .alias - .resolve_alias(&room_alias, Some(body.via.clone())) - .await?; - - banned_room_check( - &services, - sender_user, - Some(&room_id), - Some(room_alias.server_name()), - client, - ) - .await?; - - let addl_via_servers = services - .rooms - .state_cache - .servers_invite_via(&room_id) - .map(ToOwned::to_owned); - - let addl_state_servers = services - .rooms - .state_cache - .invite_state(sender_user, &room_id) - .await - .unwrap_or_default(); - - let mut addl_servers: Vec<_> = addl_state_servers - .iter() - .map(|event| event.get_field("sender")) - .filter_map(FlatOk::flat_ok) - .map(|user: &UserId| user.server_name().to_owned()) - .stream() - .chain(addl_via_servers) - .collect() - .await; - - addl_servers.sort_unstable(); - addl_servers.dedup(); - shuffle(&mut addl_servers); - servers.append(&mut addl_servers); - - (servers, room_id) - }, - }; - - knock_room_by_id_helper(&services, sender_user, &room_id, body.reason.clone(), &servers) - .boxed() - .await -} - -async fn knock_room_by_id_helper( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], -) -> Result { - let state_lock = services.rooms.state.mutex.lock(room_id).await; - - if services - .rooms - .state_cache - .is_invited(sender_user, room_id) - .await - { - debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock"); - return Err!(Request(Forbidden( - "You cannot knock on a room you are already invited/accepted to." - ))); - } - - if services - .rooms - .state_cache - .is_joined(sender_user, room_id) - .await - { - debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock"); - return Err!(Request(Forbidden("You cannot knock on a room you are already joined in."))); - } - - if services - .rooms - .state_cache - .is_knocked(sender_user, room_id) - .await - { - debug_warn!("{sender_user} is already knocked in {room_id}"); - return Ok(knock_room::v3::Response { room_id: room_id.into() }); - } - - if let Ok(membership) = services - .rooms - .state_accessor - .get_member(room_id, sender_user) - .await - { - if membership.membership == MembershipState::Ban { - debug_warn!("{sender_user} is banned from {room_id} but attempted to knock"); - return Err!(Request(Forbidden("You cannot knock on a room you are banned from."))); - } - } - - // For knock_restricted rooms, check if the user meets the restricted conditions - // If they do, attempt to join instead of knock - // This is not mentioned in the spec, but should be allowable (we're allowed to - // auto-join invites to knocked rooms) - let join_rule = services.rooms.state_accessor.get_join_rules(room_id).await; - - if let JoinRule::KnockRestricted(restricted) = &join_rule { - let restriction_rooms: Vec<_> = restricted - .allow - .iter() - .filter_map(|a| match a { - | AllowRule::RoomMembership(r) => Some(&r.room_id), - | _ => None, - }) - .collect(); - - // Check if the user is in any of the allowed rooms - let mut user_meets_restrictions = false; - for restriction_room_id in &restriction_rooms { - if services - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .await - { - user_meets_restrictions = true; - break; - } - } - - // If the user meets the restrictions, try joining instead - if user_meets_restrictions { - debug_info!( - "{sender_user} meets the restricted criteria in knock_restricted room \ - {room_id}, attempting to join instead of knock" - ); - // For this case, we need to drop the state lock and get a new one in - // join_room_by_id_helper We need to release the lock here and let - // join_room_by_id_helper acquire it again - drop(state_lock); - match join_room_by_id_helper( - services, - sender_user, - room_id, - reason.clone(), - servers, - None, - &None, - ) - .await - { - | Ok(_) => return Ok(knock_room::v3::Response::new(room_id.to_owned())), - | Err(e) => { - debug_warn!( - "Failed to convert knock to join for {sender_user} in {room_id}: {e:?}" - ); - // Get a new state lock for the remaining knock logic - let new_state_lock = services.rooms.state.mutex.lock(room_id).await; - - let server_in_room = services - .rooms - .state_cache - .server_in_room(services.globals.server_name(), room_id) - .await; - - let local_knock = server_in_room - || servers.is_empty() - || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); - - if local_knock { - knock_room_helper_local( - services, - sender_user, - room_id, - reason, - servers, - new_state_lock, - ) - .boxed() - .await?; - } else { - knock_room_helper_remote( - services, - sender_user, - room_id, - reason, - servers, - new_state_lock, - ) - .boxed() - .await?; - } - - return Ok(knock_room::v3::Response::new(room_id.to_owned())); - }, - } - } - } else if !matches!(join_rule, JoinRule::Knock | JoinRule::KnockRestricted(_)) { - debug_warn!( - "{sender_user} attempted to knock on room {room_id} but its join rule is \ - {join_rule:?}, not knock or knock_restricted" - ); - } - - let server_in_room = services - .rooms - .state_cache - .server_in_room(services.globals.server_name(), room_id) - .await; - - let local_knock = server_in_room - || servers.is_empty() - || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); - - if local_knock { - knock_room_helper_local(services, sender_user, room_id, reason, servers, state_lock) - .boxed() - .await?; - } else { - knock_room_helper_remote(services, sender_user, room_id, reason, servers, state_lock) - .boxed() - .await?; - } - - Ok(knock_room::v3::Response::new(room_id.to_owned())) -} - -async fn knock_room_helper_local( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - state_lock: RoomMutexGuard, -) -> Result { - debug_info!("We can knock locally"); - - let room_version_id = services.rooms.state.get_room_version(room_id).await?; - - if matches!( - room_version_id, - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - ) { - return Err!(Request(Forbidden("This room does not support knocking."))); - } - - let content = RoomMemberEventContent { - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - blurhash: services.users.blurhash(sender_user).await.ok(), - reason: reason.clone(), - ..RoomMemberEventContent::new(MembershipState::Knock) - }; - - // Try normal knock first - let Err(error) = services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(sender_user.to_string(), &content), - sender_user, - room_id, - &state_lock, - ) - .await - else { - return Ok(()); - }; - - if servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) - { - return Err(error); - } - - warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock"); - - let (make_knock_response, remote_server) = - make_knock_request(services, sender_user, room_id, servers).await?; - - info!("make_knock finished"); - - let room_version_id = make_knock_response.room_version; - - if !services.server.supported_room_version(&room_version_id) { - return Err!(BadServerResponse( - "Remote room version {room_version_id} is not supported by conduwuit" - )); - } - - let mut knock_event_stub = serde_json::from_str::( - make_knock_response.event.get(), - ) - .map_err(|e| { - err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}")) - })?; - - knock_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), - ); - knock_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - knock_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - blurhash: services.users.blurhash(sender_user).await.ok(), - reason, - ..RoomMemberEventContent::new(MembershipState::Knock) - }) - .expect("event is valid, we just created it"), - ); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs - // to be present - services - .server_keys - .hash_and_sign_event(&mut knock_event_stub, &room_version_id)?; - - // Generate event id - let event_id = gen_event_id(&knock_event_stub, &room_version_id)?; - - // Add event_id - knock_event_stub - .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); - - // It has enough fields to be called a proper event now - let knock_event = knock_event_stub; - - info!("Asking {remote_server} for send_knock in room {room_id}"); - let send_knock_request = federation::knock::send_knock::v1::Request { - room_id: room_id.to_owned(), - event_id: event_id.clone(), - pdu: services - .sending - .convert_to_outgoing_federation_event(knock_event.clone()) - .await, - }; - - let send_knock_response = services - .sending - .send_federation_request(&remote_server, send_knock_request) - .await?; - - info!("send_knock finished"); - - services - .rooms - .short - .get_or_create_shortroomid(room_id) - .await; - - info!("Parsing knock event"); - - let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) - .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; - - info!("Updating membership locally to knock state with provided stripped state events"); - services - .rooms - .state_cache - .update_membership( - room_id, - sender_user, - parsed_knock_pdu - .get_content::() - .expect("we just created this"), - sender_user, - Some(send_knock_response.knock_room_state), - None, - false, - ) - .await?; - - info!("Appending room knock event locally"); - services - .rooms - .timeline - .append_pdu( - &parsed_knock_pdu, - knock_event, - once(parsed_knock_pdu.event_id.borrow()), - &state_lock, - ) - .await?; - - Ok(()) -} - -async fn knock_room_helper_remote( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - state_lock: RoomMutexGuard, -) -> Result { - info!("Knocking {room_id} over federation."); - - let (make_knock_response, remote_server) = - make_knock_request(services, sender_user, room_id, servers).await?; - - info!("make_knock finished"); - - let room_version_id = make_knock_response.room_version; - - if !services.server.supported_room_version(&room_version_id) { - return Err!(BadServerResponse( - "Remote room version {room_version_id} is not supported by conduwuit" - )); - } - - let mut knock_event_stub: CanonicalJsonObject = - serde_json::from_str(make_knock_response.event.get()).map_err(|e| { - err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}")) - })?; - - knock_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), - ); - knock_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - knock_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - displayname: services.users.displayname(sender_user).await.ok(), - avatar_url: services.users.avatar_url(sender_user).await.ok(), - blurhash: services.users.blurhash(sender_user).await.ok(), - reason, - ..RoomMemberEventContent::new(MembershipState::Knock) - }) - .expect("event is valid, we just created it"), - ); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs - // to be present - services - .server_keys - .hash_and_sign_event(&mut knock_event_stub, &room_version_id)?; - - // Generate event id - let event_id = gen_event_id(&knock_event_stub, &room_version_id)?; - - // Add event_id - knock_event_stub - .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); - - // It has enough fields to be called a proper event now - let knock_event = knock_event_stub; - - info!("Asking {remote_server} for send_knock in room {room_id}"); - let send_knock_request = federation::knock::send_knock::v1::Request { - room_id: room_id.to_owned(), - event_id: event_id.clone(), - pdu: services - .sending - .convert_to_outgoing_federation_event(knock_event.clone()) - .await, - }; - - let send_knock_response = services - .sending - .send_federation_request(&remote_server, send_knock_request) - .await?; - - info!("send_knock finished"); - - services - .rooms - .short - .get_or_create_shortroomid(room_id) - .await; - - info!("Parsing knock event"); - let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) - .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; - - info!("Going through send_knock response knock state events"); - let state = send_knock_response - .knock_room_state - .iter() - .map(|event| serde_json::from_str::(event.clone().into_json().get())) - .filter_map(Result::ok); - - let mut state_map: HashMap = HashMap::new(); - - for event in state { - let Some(state_key) = event.get("state_key") else { - debug_warn!("send_knock stripped state event missing state_key: {event:?}"); - continue; - }; - let Some(event_type) = event.get("type") else { - debug_warn!("send_knock stripped state event missing event type: {event:?}"); - continue; - }; - - let Ok(state_key) = serde_json::from_value::(state_key.clone().into()) else { - debug_warn!("send_knock stripped state event has invalid state_key: {event:?}"); - continue; - }; - let Ok(event_type) = serde_json::from_value::(event_type.clone().into()) - else { - debug_warn!("send_knock stripped state event has invalid event type: {event:?}"); - continue; - }; - - let event_id = gen_event_id(&event, &room_version_id)?; - let shortstatekey = services - .rooms - .short - .get_or_create_shortstatekey(&event_type, &state_key) - .await; - - services.rooms.outlier.add_pdu_outlier(&event_id, &event); - state_map.insert(shortstatekey, event_id.clone()); - } - - info!("Compressing state from send_knock"); - let compressed: CompressedState = services - .rooms - .state_compressor - .compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) - .collect() - .await; - - debug!("Saving compressed state"); - let HashSetCompressStateEvent { - shortstatehash: statehash_before_knock, - added, - removed, - } = services - .rooms - .state_compressor - .save_state(room_id, Arc::new(compressed)) - .await?; - - debug!("Forcing state for new room"); - services - .rooms - .state - .force_state(room_id, statehash_before_knock, added, removed, &state_lock) - .await?; - - let statehash_after_knock = services - .rooms - .state - .append_to_state(&parsed_knock_pdu) - .await?; - - info!("Updating membership locally to knock state with provided stripped state events"); - services - .rooms - .state_cache - .update_membership( - room_id, - sender_user, - parsed_knock_pdu - .get_content::() - .expect("we just created this"), - sender_user, - Some(send_knock_response.knock_room_state), - None, - false, - ) - .await?; - - info!("Appending room knock event locally"); - services - .rooms - .timeline - .append_pdu( - &parsed_knock_pdu, - knock_event, - once(parsed_knock_pdu.event_id.borrow()), - &state_lock, - ) - .await?; - - info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment - // in time where events in the current room state do not exist - services - .rooms - .state - .set_room_state(room_id, statehash_after_knock, &state_lock); - - Ok(()) -} - -async fn make_knock_request( - services: &Services, - sender_user: &UserId, - room_id: &RoomId, - servers: &[OwnedServerName], -) -> Result<(federation::knock::create_knock_event_template::v1::Response, OwnedServerName)> { - let mut make_knock_response_and_server = - Err!(BadServerResponse("No server available to assist in knocking.")); - - let mut make_knock_counter: usize = 0; - - for remote_server in servers { - if services.globals.server_is_ours(remote_server) { - continue; - } - - info!("Asking {remote_server} for make_knock ({make_knock_counter})"); - - let make_knock_response = services - .sending - .send_federation_request( - remote_server, - federation::knock::create_knock_event_template::v1::Request { - room_id: room_id.to_owned(), - user_id: sender_user.to_owned(), - ver: services.server.supported_room_versions().collect(), - }, - ) - .await; - - trace!("make_knock response: {make_knock_response:?}"); - make_knock_counter = make_knock_counter.saturating_add(1); - - make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone())); - - if make_knock_response_and_server.is_ok() { - break; - } - - if make_knock_counter > 40 { - warn!( - "50 servers failed to provide valid make_knock response, assuming no server can \ - assist in knocking." - ); - make_knock_response_and_server = - Err!(BadServerResponse("No server available to assist in knocking.")); - - return make_knock_response_and_server; - } - } - - make_knock_response_and_server -} diff --git a/src/api/client/membership/leave.rs b/src/api/client/membership/leave.rs deleted file mode 100644 index f4f1666b..00000000 --- a/src/api/client/membership/leave.rs +++ /dev/null @@ -1,386 +0,0 @@ -use std::collections::HashSet; - -use axum::extract::State; -use conduwuit::{ - Err, Result, debug_info, debug_warn, err, - matrix::{event::gen_event_id, pdu::PduBuilder}, - utils::{self, FutureBoolExt, future::ReadyEqExt}, - warn, -}; -use futures::{FutureExt, StreamExt, TryFutureExt, pin_mut}; -use ruma::{ - CanonicalJsonObject, CanonicalJsonValue, OwnedServerName, RoomId, RoomVersionId, UserId, - api::{ - client::membership::leave_room, - federation::{self}, - }, - events::{ - StateEventType, - room::member::{MembershipState, RoomMemberEventContent}, - }, -}; -use service::Services; - -use crate::Ruma; - -/// # `POST /_matrix/client/v3/rooms/{roomId}/leave` -/// -/// Tries to leave the sender user from a room. -/// -/// - This should always work if the user is currently joined. -pub(crate) async fn leave_room_route( - State(services): State, - body: Ruma, -) -> Result { - leave_room(&services, body.sender_user(), &body.room_id, body.reason.clone()) - .boxed() - .await - .map(|()| leave_room::v3::Response::new()) -} - -// Make a user leave all their joined rooms, rescinds knocks, forgets all rooms, -// and ignores errors -pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { - let rooms_joined = services - .rooms - .state_cache - .rooms_joined(user_id) - .map(ToOwned::to_owned); - - let rooms_invited = services - .rooms - .state_cache - .rooms_invited(user_id) - .map(|(r, _)| r); - - let rooms_knocked = services - .rooms - .state_cache - .rooms_knocked(user_id) - .map(|(r, _)| r); - - let all_rooms: Vec<_> = rooms_joined - .chain(rooms_invited) - .chain(rooms_knocked) - .collect() - .await; - - for room_id in all_rooms { - // ignore errors - if let Err(e) = leave_room(services, user_id, &room_id, None).boxed().await { - warn!(%user_id, "Failed to leave {room_id} remotely: {e}"); - } - - services.rooms.state_cache.forget(&room_id, user_id); - } -} - -pub async fn leave_room( - services: &Services, - user_id: &UserId, - room_id: &RoomId, - reason: Option, -) -> Result { - let default_member_content = RoomMemberEventContent { - membership: MembershipState::Leave, - reason: reason.clone(), - join_authorized_via_users_server: None, - is_direct: None, - avatar_url: None, - displayname: None, - third_party_invite: None, - blurhash: None, - redact_events: None, - }; - - let is_banned = services.rooms.metadata.is_banned(room_id); - let is_disabled = services.rooms.metadata.is_disabled(room_id); - - pin_mut!(is_banned, is_disabled); - if is_banned.or(is_disabled).await { - // the room is banned/disabled, the room must be rejected locally since we - // cant/dont want to federate with this server - services - .rooms - .state_cache - .update_membership( - room_id, - user_id, - default_member_content, - user_id, - None, - None, - true, - ) - .await?; - - return Ok(()); - } - - let dont_have_room = services - .rooms - .state_cache - .server_in_room(services.globals.server_name(), room_id) - .eq(&false); - - let not_knocked = services - .rooms - .state_cache - .is_knocked(user_id, room_id) - .eq(&false); - - // Ask a remote server if we don't have this room and are not knocking on it - if dont_have_room.and(not_knocked).await { - if let Err(e) = remote_leave_room(services, user_id, room_id, reason.clone()) - .boxed() - .await - { - warn!(%user_id, "Failed to leave room {room_id} remotely: {e}"); - // Don't tell the client about this error - } - - let last_state = services - .rooms - .state_cache - .invite_state(user_id, room_id) - .or_else(|_| services.rooms.state_cache.knock_state(user_id, room_id)) - .or_else(|_| services.rooms.state_cache.left_state(user_id, room_id)) - .await - .ok(); - - // We always drop the invite, we can't rely on other servers - services - .rooms - .state_cache - .update_membership( - room_id, - user_id, - default_member_content, - user_id, - last_state, - None, - true, - ) - .await?; - } else { - let state_lock = services.rooms.state.mutex.lock(room_id).await; - - let Ok(event) = services - .rooms - .state_accessor - .room_state_get_content::( - room_id, - &StateEventType::RoomMember, - user_id.as_str(), - ) - .await - else { - debug_warn!( - "Trying to leave a room you are not a member of, marking room as left locally." - ); - - return services - .rooms - .state_cache - .update_membership( - room_id, - user_id, - default_member_content, - user_id, - None, - None, - true, - ) - .await; - }; - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(user_id.to_string(), &RoomMemberEventContent { - membership: MembershipState::Leave, - reason, - join_authorized_via_users_server: None, - is_direct: None, - ..event - }), - user_id, - room_id, - &state_lock, - ) - .await?; - } - - Ok(()) -} - -async fn remote_leave_room( - services: &Services, - user_id: &UserId, - room_id: &RoomId, - reason: Option, -) -> Result<()> { - let mut make_leave_response_and_server = - Err!(BadServerResponse("No remote server available to assist in leaving {room_id}.")); - - let mut servers: HashSet = services - .rooms - .state_cache - .servers_invite_via(room_id) - .map(ToOwned::to_owned) - .collect() - .await; - - match services - .rooms - .state_cache - .invite_state(user_id, room_id) - .await - { - | Ok(invite_state) => { - servers.extend( - invite_state - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - }, - | _ => { - match services - .rooms - .state_cache - .knock_state(user_id, room_id) - .await - { - | Ok(knock_state) => { - servers.extend( - knock_state - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .filter_map(|sender| { - if !services.globals.user_is_local(sender) { - Some(sender.server_name().to_owned()) - } else { - None - } - }), - ); - }, - | _ => {}, - } - }, - } - - if let Some(room_id_server_name) = room_id.server_name() { - servers.insert(room_id_server_name.to_owned()); - } - - debug_info!("servers in remote_leave_room: {servers:?}"); - - for remote_server in servers { - let make_leave_response = services - .sending - .send_federation_request( - &remote_server, - federation::membership::prepare_leave_event::v1::Request { - room_id: room_id.to_owned(), - user_id: user_id.to_owned(), - }, - ) - .await; - - make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); - - if make_leave_response_and_server.is_ok() { - break; - } - } - - let (make_leave_response, remote_server) = make_leave_response_and_server?; - - let Some(room_version_id) = make_leave_response.room_version else { - return Err!(BadServerResponse(warn!( - "No room version was returned by {remote_server} for {room_id}, room version is \ - likely not supported by conduwuit" - ))); - }; - - if !services.server.supported_room_version(&room_version_id) { - return Err!(BadServerResponse(warn!( - "Remote room version {room_version_id} for {room_id} is not supported by conduwuit", - ))); - } - - let mut leave_event_stub = serde_json::from_str::( - make_leave_response.event.get(), - ) - .map_err(|e| { - err!(BadServerResponse(warn!( - "Invalid make_leave event json received from {remote_server} for {room_id}: {e:?}" - ))) - })?; - - // TODO: Is origin needed? - leave_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), - ); - leave_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - // Inject the reason key into the event content dict if it exists - if let Some(reason) = reason { - if let Some(CanonicalJsonValue::Object(content)) = leave_event_stub.get_mut("content") { - content.insert("reason".to_owned(), CanonicalJsonValue::String(reason)); - } - } - - // room v3 and above removed the "event_id" field from remote PDU format - match room_version_id { - | RoomVersionId::V1 | RoomVersionId::V2 => {}, - | _ => { - leave_event_stub.remove("event_id"); - }, - } - - // In order to create a compatible ref hash (EventID) the `hashes` field needs - // to be present - services - .server_keys - .hash_and_sign_event(&mut leave_event_stub, &room_version_id)?; - - // Generate event id - let event_id = gen_event_id(&leave_event_stub, &room_version_id)?; - - // Add event_id back - leave_event_stub - .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); - - // It has enough fields to be called a proper event now - let leave_event = leave_event_stub; - - services - .sending - .send_federation_request( - &remote_server, - federation::membership::create_leave_event::v2::Request { - room_id: room_id.to_owned(), - event_id, - pdu: services - .sending - .convert_to_outgoing_federation_event(leave_event.clone()) - .await, - }, - ) - .await?; - - Ok(()) -} diff --git a/src/api/client/membership/members.rs b/src/api/client/membership/members.rs deleted file mode 100644 index 05ba1c43..00000000 --- a/src/api/client/membership/members.rs +++ /dev/null @@ -1,147 +0,0 @@ -use axum::extract::State; -use conduwuit::{ - Err, Event, Result, at, - utils::{ - future::TryExtExt, - stream::{BroadbandExt, ReadyExt}, - }, -}; -use futures::{FutureExt, StreamExt, future::join}; -use ruma::{ - api::client::membership::{ - get_member_events::{self, v3::MembershipEventFilter}, - joined_members::{self, v3::RoomMember}, - }, - events::{ - StateEventType, - room::member::{MembershipState, RoomMemberEventContent}, - }, -}; - -use crate::Ruma; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/members` -/// -/// Lists all joined users in a room (TODO: at a specific point in time, with a -/// specific membership). -/// -/// - Only works if the user is currently joined -pub(crate) async fn get_member_events_route( - State(services): State, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - let membership = body.membership.as_ref(); - let not_membership = body.not_membership.as_ref(); - - if !services - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id) - .await - { - return Err!(Request(Forbidden("You don't have permission to view this room."))); - } - - Ok(get_member_events::v3::Response { - chunk: services - .rooms - .state_accessor - .room_state_full(&body.room_id) - .ready_filter_map(Result::ok) - .ready_filter(|((ty, _), _)| *ty == StateEventType::RoomMember) - .map(at!(1)) - .ready_filter_map(|pdu| membership_filter(pdu, membership, not_membership)) - .map(Event::into_format) - .collect() - .boxed() - .await, - }) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/joined_members` -/// -/// Lists all members of a room. -/// -/// - The sender user must be in the room -/// - TODO: An appservice just needs a puppet joined -pub(crate) async fn joined_members_route( - State(services): State, - body: Ruma, -) -> Result { - if !services - .rooms - .state_accessor - .user_can_see_state_events(body.sender_user(), &body.room_id) - .await - { - return Err!(Request(Forbidden("You don't have permission to view this room."))); - } - - Ok(joined_members::v3::Response { - joined: services - .rooms - .state_cache - .room_members(&body.room_id) - .map(ToOwned::to_owned) - .broad_then(|user_id| async move { - let (display_name, avatar_url) = join( - services.users.displayname(&user_id).ok(), - services.users.avatar_url(&user_id).ok(), - ) - .await; - - (user_id, RoomMember { display_name, avatar_url }) - }) - .collect() - .await, - }) -} - -fn membership_filter( - pdu: Pdu, - for_membership: Option<&MembershipEventFilter>, - not_membership: Option<&MembershipEventFilter>, -) -> Option { - let membership_state_filter = match for_membership { - | Some(MembershipEventFilter::Ban) => MembershipState::Ban, - | Some(MembershipEventFilter::Invite) => MembershipState::Invite, - | Some(MembershipEventFilter::Knock) => MembershipState::Knock, - | Some(MembershipEventFilter::Leave) => MembershipState::Leave, - | Some(_) | None => MembershipState::Join, - }; - - let not_membership_state_filter = match not_membership { - | Some(MembershipEventFilter::Ban) => MembershipState::Ban, - | Some(MembershipEventFilter::Invite) => MembershipState::Invite, - | Some(MembershipEventFilter::Join) => MembershipState::Join, - | Some(MembershipEventFilter::Knock) => MembershipState::Knock, - | Some(_) | None => MembershipState::Leave, - }; - - let evt_membership = pdu.get_content::().ok()?.membership; - - if for_membership.is_some() && not_membership.is_some() { - if membership_state_filter != evt_membership - || not_membership_state_filter == evt_membership - { - None - } else { - Some(pdu) - } - } else if for_membership.is_some() && not_membership.is_none() { - if membership_state_filter != evt_membership { - None - } else { - Some(pdu) - } - } else if not_membership.is_some() && for_membership.is_none() { - if not_membership_state_filter == evt_membership { - None - } else { - Some(pdu) - } - } else { - Some(pdu) - } -} diff --git a/src/api/client/membership/mod.rs b/src/api/client/membership/mod.rs deleted file mode 100644 index 7a6f19ad..00000000 --- a/src/api/client/membership/mod.rs +++ /dev/null @@ -1,156 +0,0 @@ -mod ban; -mod forget; -mod invite; -mod join; -mod kick; -mod knock; -mod leave; -mod members; -mod unban; - -use std::net::IpAddr; - -use axum::extract::State; -use conduwuit::{Err, Result, warn}; -use futures::{FutureExt, StreamExt}; -use ruma::{OwnedRoomId, RoomId, ServerName, UserId, api::client::membership::joined_rooms}; -use service::Services; - -pub(crate) use self::{ - ban::ban_user_route, - forget::forget_room_route, - invite::{invite_helper, invite_user_route}, - join::{join_room_by_id_or_alias_route, join_room_by_id_route}, - kick::kick_user_route, - knock::knock_room_route, - leave::leave_room_route, - members::{get_member_events_route, joined_members_route}, - unban::unban_user_route, -}; -pub use self::{ - join::join_room_by_id_helper, - leave::{leave_all_rooms, leave_room}, -}; -use crate::{Ruma, client::full_user_deactivate}; - -/// # `POST /_matrix/client/r0/joined_rooms` -/// -/// Lists all rooms the user has joined. -pub(crate) async fn joined_rooms_route( - State(services): State, - body: Ruma, -) -> Result { - Ok(joined_rooms::v3::Response { - joined_rooms: services - .rooms - .state_cache - .rooms_joined(body.sender_user()) - .map(ToOwned::to_owned) - .collect() - .await, - }) -} - -/// Checks if the room is banned in any way possible and the sender user is not -/// an admin. -/// -/// Performs automatic deactivation if `auto_deactivate_banned_room_attempts` is -/// enabled -#[tracing::instrument(skip(services))] -pub(crate) async fn banned_room_check( - services: &Services, - user_id: &UserId, - room_id: Option<&RoomId>, - server_name: Option<&ServerName>, - client_ip: IpAddr, -) -> Result { - if services.users.is_admin(user_id).await { - return Ok(()); - } - - if let Some(room_id) = room_id { - if services.rooms.metadata.is_banned(room_id).await - || services - .moderation - .is_remote_server_forbidden(room_id.server_name().expect("legacy room mxid")) - { - warn!( - "User {user_id} who is not an admin attempted to send an invite for or \ - attempted to join a banned room or banned room server name: {room_id}" - ); - - if services.server.config.auto_deactivate_banned_room_attempts { - warn!( - "Automatically deactivating user {user_id} due to attempted banned room join" - ); - - if services.server.config.admin_room_notices { - services - .admin - .send_text(&format!( - "Automatically deactivating user {user_id} due to attempted banned \ - room join from IP {client_ip}" - )) - .await; - } - - let all_joined_rooms: Vec = services - .rooms - .state_cache - .rooms_joined(user_id) - .map(Into::into) - .collect() - .await; - - full_user_deactivate(services, user_id, &all_joined_rooms) - .boxed() - .await?; - } - - return Err!(Request(Forbidden("This room is banned on this homeserver."))); - } - } else if let Some(server_name) = server_name { - if services - .config - .forbidden_remote_server_names - .is_match(server_name.host()) - { - warn!( - "User {user_id} who is not an admin tried joining a room which has the server \ - name {server_name} that is globally forbidden. Rejecting.", - ); - - if services.server.config.auto_deactivate_banned_room_attempts { - warn!( - "Automatically deactivating user {user_id} due to attempted banned room join" - ); - - if services.server.config.admin_room_notices { - services - .admin - .send_text(&format!( - "Automatically deactivating user {user_id} due to attempted banned \ - room join from IP {client_ip}" - )) - .await; - } - - let all_joined_rooms: Vec = services - .rooms - .state_cache - .rooms_joined(user_id) - .map(Into::into) - .collect() - .await; - - full_user_deactivate(services, user_id, &all_joined_rooms) - .boxed() - .await?; - } - - return Err!(Request(Forbidden("This remote server is banned on this homeserver."))); - } - } - - Ok(()) -} diff --git a/src/api/client/membership/unban.rs b/src/api/client/membership/unban.rs deleted file mode 100644 index 34c5eace..00000000 --- a/src/api/client/membership/unban.rs +++ /dev/null @@ -1,58 +0,0 @@ -use axum::extract::State; -use conduwuit::{Err, Result, matrix::pdu::PduBuilder}; -use ruma::{ - api::client::membership::unban_user, - events::room::member::{MembershipState, RoomMemberEventContent}, -}; - -use crate::Ruma; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/unban` -/// -/// Tries to send an unban event into the room. -pub(crate) async fn unban_user_route( - State(services): State, - body: Ruma, -) -> Result { - let sender_user = body.sender_user(); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - - let current_member_content = services - .rooms - .state_accessor - .get_member(&body.room_id, &body.user_id) - .await - .unwrap_or_else(|_| RoomMemberEventContent::new(MembershipState::Leave)); - - if current_member_content.membership != MembershipState::Ban { - return Err!(Request(Forbidden( - "Cannot unban a user who is not banned (current membership: {})", - current_member_content.membership - ))); - } - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent { - membership: MembershipState::Leave, - reason: body.reason.clone(), - join_authorized_via_users_server: None, - third_party_invite: None, - is_direct: None, - ..current_member_content - }), - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - drop(state_lock); - - Ok(unban_user::v3::Response::new()) -} diff --git a/src/api/client/message.rs b/src/api/client/message.rs index f8818ebb..6087478c 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,11 +1,12 @@ +use core::panic; + use axum::extract::State; use conduwuit::{ - Err, Result, at, + Err, Result, at, debug_warn, matrix::{ - event::{Event, Matches}, - pdu::PduCount, + Event, + pdu::{PduCount, PduEvent}, }, - ref_at, utils::{ IterStream, ReadyExt, result::{FlatOk, LogErr}, @@ -33,7 +34,6 @@ use ruma::{ }, serde::Raw, }; -use tracing::warn; use crate::Ruma; @@ -73,7 +73,7 @@ pub(crate) async fn get_message_events_route( ) -> Result { debug_assert!(IGNORED_MESSAGE_TYPES.is_sorted(), "IGNORED_MESSAGE_TYPES is not sorted"); let sender_user = body.sender_user(); - let sender_device = body.sender_device.as_deref(); + let sender_device = body.sender_device.as_ref(); let room_id = &body.room_id; let filter = &body.filter; @@ -114,14 +114,14 @@ pub(crate) async fn get_message_events_route( | Direction::Forward => services .rooms .timeline - .pdus(Some(sender_user), room_id, Some(from)) + .pdus(room_id, Some(from)) .ignore_err() .boxed(), | Direction::Backward => services .rooms .timeline - .pdus_rev(Some(sender_user), room_id, Some(from)) + .pdus_rev(room_id, Some(from)) .ignore_err() .boxed(), }; @@ -132,22 +132,35 @@ pub(crate) async fn get_message_events_route( .wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user)) .take(limit) + .then(async |mut pdu| { + pdu.1.set_unsigned(Some(sender_user)); + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + pdu + }) .collect() .await; let lazy_loading_context = lazy_loading::Context { user_id: sender_user, - device_id: sender_device.or_else(|| { - if let Some(registration) = body.appservice_info.as_ref() { - Some(<&DeviceId>::from(registration.registration.id.as_str())) - } else { - warn!( - "No device_id provided and no appservice registration found, this should be \ - unreachable" - ); - None - } - }), + device_id: match sender_device { + | Some(device_id) => device_id, + | None => + if let Some(registration) = body.appservice_info.as_ref() { + <&DeviceId>::from(registration.registration.id.as_str()) + } else { + panic!( + "No device_id provided and no appservice registration found, this \ + should be unreachable" + ); + }, + }, room_id, token: Some(from.into_unsigned()), options: Some(&filter.lazy_load_options), @@ -176,7 +189,7 @@ pub(crate) async fn get_message_events_route( let chunk = events .into_iter() .map(at!(1)) - .map(Event::into_format) + .map(PduEvent::into_room_event) .collect(); Ok(get_message_events::v3::Response { @@ -217,9 +230,7 @@ where pin_mut!(receipts); let witness: Witness = events .stream() - .map(ref_at!(1)) - .map(Event::sender) - .map(ToOwned::to_owned) + .map(|(_, pdu)| pdu.sender.clone()) .chain( receipts .ready_take_while(|(_, c, _)| *c <= newest.into_unsigned()) @@ -244,7 +255,7 @@ async fn get_member_event( .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) - .map_ok(Event::into_format) + .map_ok(PduEvent::into_state_event) .await .ok() } @@ -264,33 +275,27 @@ pub(crate) async fn ignored_filter( } #[inline] -pub(crate) async fn is_ignored_pdu( +pub(crate) async fn is_ignored_pdu( services: &Services, - event: &Pdu, + pdu: &PduEvent, user_id: &UserId, -) -> bool -where - Pdu: Event + Send + Sync, -{ +) -> bool { // exclude Synapse's dummy events from bloating up response bodies. clients // don't need to see this. - if event.kind().to_cow_str() == "org.matrix.dummy_event" { + if pdu.kind.to_cow_str() == "org.matrix.dummy_event" { return true; } - let ignored_type = IGNORED_MESSAGE_TYPES.binary_search(event.kind()).is_ok(); + let ignored_type = IGNORED_MESSAGE_TYPES.binary_search(&pdu.kind).is_ok(); let ignored_server = services .moderation - .is_remote_server_ignored(event.sender().server_name()); + .is_remote_server_ignored(pdu.sender().server_name()); if ignored_type && (ignored_server || (!services.config.send_messages_from_ignored_users_to_client - && services - .users - .user_is_ignored(event.sender(), user_id) - .await)) + && services.users.user_is_ignored(&pdu.sender, user_id).await)) { return true; } @@ -309,7 +314,7 @@ pub(crate) async fn visibility_filter( services .rooms .state_accessor - .user_can_see_event(user_id, pdu.room_id(), pdu.event_id()) + .user_can_see_event(user_id, &pdu.room_id, &pdu.event_id) .await .then_some(item) } @@ -317,7 +322,7 @@ pub(crate) async fn visibility_filter( #[inline] pub(crate) fn event_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option { let (_, pdu) = &item; - filter.matches(pdu).then_some(item) + pdu.matches(filter).then_some(item) } #[cfg_attr(debug_assertions, conduwuit::ctor)] diff --git a/src/api/client/openid.rs b/src/api/client/openid.rs index 0390b4b3..8d2de68d 100644 --- a/src/api/client/openid.rs +++ b/src/api/client/openid.rs @@ -1,8 +1,11 @@ use std::time::Duration; use axum::extract::State; -use conduwuit::{Err, Result, utils}; -use ruma::{api::client::account, authentication::TokenType}; +use conduwuit::{Error, Result, utils}; +use ruma::{ + api::client::{account, error::ErrorKind}, + authentication::TokenType, +}; use super::TOKEN_LENGTH; use crate::Ruma; @@ -16,15 +19,17 @@ pub(crate) async fn create_openid_token_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if sender_user != body.user_id { - return Err!(Request(InvalidParam( + if sender_user != &body.user_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, "Not allowed to request OpenID tokens on behalf of other users", - ))); + )); } let access_token = utils::random_string(TOKEN_LENGTH); + let expires_in = services .users .create_openid_token(&body.user_id, &access_token)?; diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index 1882495c..bdba4078 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -2,21 +2,21 @@ use std::collections::BTreeMap; use axum::extract::State; use conduwuit::{ - Err, Result, + Err, Error, Result, matrix::pdu::PduBuilder, - utils::{IterStream, future::TryExtExt, stream::TryIgnore}, + utils::{IterStream, stream::TryIgnore}, warn, }; use conduwuit_service::Services; -use futures::{ - StreamExt, TryStreamExt, - future::{join, join3, join4}, -}; +use futures::{StreamExt, TryStreamExt, future::join3}; use ruma::{ OwnedMxcUri, OwnedRoomId, UserId, api::{ - client::profile::{ - get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, + client::{ + error::ErrorKind, + profile::{ + get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, + }, }, federation, }, @@ -35,7 +35,7 @@ pub(crate) async fn set_displayname_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if services.users.is_suspended(sender_user).await? { return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); } @@ -110,7 +110,7 @@ pub(crate) async fn get_displayname_route( if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err!(Request(NotFound("Profile was not found."))); + return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_display_name::v3::Response { @@ -127,7 +127,7 @@ pub(crate) async fn set_avatar_url_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if services.users.is_suspended(sender_user).await? { return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); } @@ -195,9 +195,11 @@ pub(crate) async fn get_avatar_url_route( services .users .set_displayname(&body.user_id, response.displayname.clone()); + services .users .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users .set_blurhash(&body.user_id, response.blurhash.clone()); @@ -212,16 +214,13 @@ pub(crate) async fn get_avatar_url_route( if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err!(Request(NotFound("Profile was not found."))); + return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } - let (avatar_url, blurhash) = join( - services.users.avatar_url(&body.user_id).ok(), - services.users.blurhash(&body.user_id).ok(), - ) - .await; - - Ok(get_avatar_url::v3::Response { avatar_url, blurhash }) + Ok(get_avatar_url::v3::Response { + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), + }) } /// # `GET /_matrix/client/v3/profile/{userId}` @@ -254,12 +253,15 @@ pub(crate) async fn get_profile_route( services .users .set_displayname(&body.user_id, response.displayname.clone()); + services .users .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users .set_timezone(&body.user_id, response.tz.clone()); @@ -285,7 +287,7 @@ pub(crate) async fn get_profile_route( if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err!(Request(NotFound("Profile was not found."))); + return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } let mut custom_profile_fields: BTreeMap = services @@ -298,19 +300,11 @@ pub(crate) async fn get_profile_route( custom_profile_fields.remove("us.cloke.msc4175.tz"); custom_profile_fields.remove("m.tz"); - let (avatar_url, blurhash, displayname, tz) = join4( - services.users.avatar_url(&body.user_id).ok(), - services.users.blurhash(&body.user_id).ok(), - services.users.displayname(&body.user_id).ok(), - services.users.timezone(&body.user_id).ok(), - ) - .await; - Ok(get_profile::v3::Response { - avatar_url, - blurhash, - displayname, - tz, + avatar_url: services.users.avatar_url(&body.user_id).await.ok(), + blurhash: services.users.blurhash(&body.user_id).await.ok(), + displayname: services.users.displayname(&body.user_id).await.ok(), + tz: services.users.timezone(&body.user_id).await.ok(), custom_profile_fields, }) } @@ -322,12 +316,16 @@ pub async fn update_displayname( all_joined_rooms: &[OwnedRoomId], ) { let (current_avatar_url, current_blurhash, current_displayname) = join3( - services.users.avatar_url(user_id).ok(), - services.users.blurhash(user_id).ok(), - services.users.displayname(user_id).ok(), + services.users.avatar_url(user_id), + services.users.blurhash(user_id), + services.users.displayname(user_id), ) .await; + let current_avatar_url = current_avatar_url.ok(); + let current_blurhash = current_blurhash.ok(); + let current_displayname = current_displayname.ok(); + if displayname == current_displayname { return; } @@ -371,12 +369,16 @@ pub async fn update_avatar_url( all_joined_rooms: &[OwnedRoomId], ) { let (current_avatar_url, current_blurhash, current_displayname) = join3( - services.users.avatar_url(user_id).ok(), - services.users.blurhash(user_id).ok(), - services.users.displayname(user_id).ok(), + services.users.avatar_url(user_id), + services.users.blurhash(user_id), + services.users.displayname(user_id), ) .await; + let current_avatar_url = current_avatar_url.ok(); + let current_blurhash = current_blurhash.ok(); + let current_displayname = current_displayname.ok(); + if current_avatar_url == avatar_url && current_blurhash == blurhash { return; } diff --git a/src/api/client/push.rs b/src/api/client/push.rs index d8d84ec7..81020ffa 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -79,14 +79,17 @@ pub(crate) async fn get_pushrules_all_route( global_ruleset.update_with_server_default(Ruleset::server_default(sender_user)); - let ty = GlobalAccountDataEventType::PushRules; - let event = PushRulesEvent { - content: PushRulesEventContent { global: global_ruleset.clone() }, - }; - services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(event)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { global: global_ruleset.clone() }, + }) + .expect("to json always works"), + ) .await?; } }; @@ -103,7 +106,7 @@ pub(crate) async fn get_pushrules_global_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let Some(content_value) = services .account_data @@ -115,17 +118,19 @@ pub(crate) async fn get_pushrules_global_route( else { // user somehow has non-existent push rule event. recreate it and return server // default silently - - let ty = GlobalAccountDataEventType::PushRules; - let event = PushRulesEvent { - content: PushRulesEventContent { - global: Ruleset::server_default(sender_user), - }, - }; - services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(event)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { + global: Ruleset::server_default(sender_user), + }, + }) + .expect("to json always works"), + ) .await?; return Ok(get_pushrules_global_scope::v3::Response { @@ -218,7 +223,7 @@ pub(crate) async fn get_pushrule_route( if let Some(rule) = rule { Ok(get_pushrule::v3::Response { rule }) } else { - Err!(Request(NotFound("Push rule not found."))) + Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")) } } @@ -229,8 +234,9 @@ pub(crate) async fn set_pushrule_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); - let body = &body.body; + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let body = body.body; + let mut account_data: PushRulesEvent = services .account_data .get_global(sender_user, GlobalAccountDataEventType::PushRules) @@ -269,10 +275,14 @@ pub(crate) async fn set_pushrule_route( return Err(err); } - let ty = GlobalAccountDataEventType::PushRules; services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(account_data)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) .await?; Ok(set_pushrule::v3::Response {}) @@ -285,7 +295,7 @@ pub(crate) async fn get_pushrule_actions_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // remove old deprecated mentions push rules as per MSC4210 #[allow(deprecated)] @@ -319,7 +329,7 @@ pub(crate) async fn set_pushrule_actions_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut account_data: PushRulesEvent = services .account_data @@ -333,13 +343,17 @@ pub(crate) async fn set_pushrule_actions_route( .set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()) .is_err() { - return Err!(Request(NotFound("Push rule not found."))); + return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - let ty = GlobalAccountDataEventType::PushRules; services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(account_data)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) .await?; Ok(set_pushrule_actions::v3::Response {}) @@ -352,7 +366,7 @@ pub(crate) async fn get_pushrule_enabled_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // remove old deprecated mentions push rules as per MSC4210 #[allow(deprecated)] @@ -386,7 +400,7 @@ pub(crate) async fn set_pushrule_enabled_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut account_data: PushRulesEvent = services .account_data @@ -400,13 +414,17 @@ pub(crate) async fn set_pushrule_enabled_route( .set_enabled(body.kind.clone(), &body.rule_id, body.enabled) .is_err() { - return Err!(Request(NotFound("Push rule not found."))); + return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - let ty = GlobalAccountDataEventType::PushRules; services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(account_data)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) .await?; Ok(set_pushrule_enabled::v3::Response {}) @@ -419,7 +437,7 @@ pub(crate) async fn delete_pushrule_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut account_data: PushRulesEvent = services .account_data @@ -445,10 +463,14 @@ pub(crate) async fn delete_pushrule_route( return Err(err); } - let ty = GlobalAccountDataEventType::PushRules; services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(account_data)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) .await?; Ok(delete_pushrule::v3::Response {}) @@ -461,7 +483,7 @@ pub(crate) async fn get_pushers_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { pushers: services.pusher.get_pushers(sender_user).await, @@ -477,7 +499,7 @@ pub(crate) async fn set_pushers_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services .pusher @@ -493,16 +515,19 @@ async fn recreate_push_rules_and_return( services: &Services, sender_user: &ruma::UserId, ) -> Result { - let ty = GlobalAccountDataEventType::PushRules; - let event = PushRulesEvent { - content: PushRulesEventContent { - global: Ruleset::server_default(sender_user), - }, - }; - services .account_data - .update(None, sender_user, ty.to_string().into(), &serde_json::to_value(event)?) + .update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(PushRulesEvent { + content: PushRulesEventContent { + global: Ruleset::server_default(sender_user), + }, + }) + .expect("to json always works"), + ) .await?; Ok(get_pushrules_all::v3::Response { diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index 9d813294..e152869c 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -37,7 +37,7 @@ pub(crate) async fn set_read_marker_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event)?, + &serde_json::to_value(fully_read_event).expect("to json value always works"), ) .await?; } @@ -151,7 +151,7 @@ pub(crate) async fn create_receipt_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event)?, + &serde_json::to_value(fully_read_event).expect("to json value always works"), ) .await?; }, diff --git a/src/api/client/redact.rs b/src/api/client/redact.rs index 86d871ff..a8eaf91d 100644 --- a/src/api/client/redact.rs +++ b/src/api/client/redact.rs @@ -15,8 +15,8 @@ pub(crate) async fn redact_event_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); - let body = &body.body; + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let body = body.body; if services.users.is_suspended(sender_user).await? { // TODO: Users can redact their own messages while suspended return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 1aa34ada..377f0c71 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,10 +1,10 @@ use axum::extract::State; use conduwuit::{ - Result, at, - matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, + Result, at, debug_warn, + matrix::pdu::PduCount, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; -use conduwuit_service::Services; +use conduwuit_service::{Services, rooms::timeline::PdusIterItem}; use futures::StreamExt; use ruma::{ EventId, RoomId, UInt, UserId, @@ -129,7 +129,7 @@ async fn paginate_relations_with_filter( // Spec (v1.10) recommends depth of at least 3 let depth: u8 = if recurse { 3 } else { 1 }; - let events: Vec<_> = services + let events: Vec = services .rooms .pdu_metadata .get_relations(sender_user, room_id, target, start, limit, depth, dir) @@ -138,17 +138,28 @@ async fn paginate_relations_with_filter( .filter(|(_, pdu)| { filter_event_type .as_ref() - .is_none_or(|kind| kind == pdu.kind()) + .is_none_or(|kind| *kind == pdu.kind) }) .filter(|(_, pdu)| { filter_rel_type .as_ref() - .is_none_or(|rel_type| rel_type.relation_type_equal(pdu)) + .is_none_or(|rel_type| pdu.relation_type_equal(rel_type)) }) .stream() .ready_take_while(|(count, _)| Some(*count) != to) .wide_filter_map(|item| visibility_filter(services, sender_user, item)) .take(limit) + .then(async |mut pdu| { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations to relation: {e}"); + } + pdu + }) .collect() .await; @@ -167,22 +178,26 @@ async fn paginate_relations_with_filter( chunk: events .into_iter() .map(at!(1)) - .map(Event::into_format) + .map(|pdu| pdu.to_message_like_event()) .collect(), }) } -async fn visibility_filter( +// TODO: Can we move the visibility filter lower down, to avoid checking events +// that won't be sent? At the moment this also results in getting events that +// appear to have no relation because intermediaries are not visible to the +// user. +async fn visibility_filter( services: &Services, sender_user: &UserId, - item: (PduCount, Pdu), -) -> Option<(PduCount, Pdu)> { + item: PdusIterItem, +) -> Option { let (_, pdu) = &item; services .rooms .state_accessor - .user_can_see_event(sender_user, pdu.room_id(), pdu.event_id()) + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .await .then_some(item) } diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 60a16e1a..4ee8ebe5 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,33 +1,23 @@ -use std::{fmt::Write as _, ops::Mul, time::Duration}; +use std::time::Duration; use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduwuit::{Err, Result, debug_info, info, matrix::pdu::PduEvent, utils::ReadyExt}; +use conduwuit::{Err, Error, Result, debug_info, info, matrix::pdu::PduEvent, utils::ReadyExt}; use conduwuit_service::Services; use rand::Rng; use ruma::{ - EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, + EventId, RoomId, UserId, api::client::{ - report_user, + error::ErrorKind, room::{report_content, report_room}, }, - events::{Mentions, room::message::RoomMessageEventContent}, + events::room::message, int, }; use tokio::time::sleep; use crate::Ruma; -struct Report { - sender: OwnedUserId, - room_id: Option, - event_id: Option, - user_id: Option, - report_type: String, - reason: Option, - score: Option, -} - /// # `POST /_matrix/client/v3/rooms/{roomId}/report` /// /// Reports an abusive room to homeserver admins @@ -37,14 +27,19 @@ pub(crate) async fn report_room_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } + // user authentication + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + info!( + "Received room report by user {sender_user} for room {} with reason: \"{}\"", + body.room_id, + body.reason.as_deref().unwrap_or("") + ); if body.reason.as_ref().is_some_and(|s| s.len() > 750) { - return Err!(Request( - InvalidParam("Reason too long, should be 750 characters or fewer",) + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Reason too long, should be 750 characters or fewer", )); } @@ -60,23 +55,19 @@ pub(crate) async fn report_room_route( "Room does not exist to us, no local users have joined at all" ))); } - info!( - "Received room report by user {sender_user} for room {} with reason: \"{}\"", - body.room_id, - body.reason.as_deref().unwrap_or("") - ); - let report = Report { - sender: sender_user.to_owned(), - room_id: Some(body.room_id.clone()), - event_id: None, - user_id: None, - report_type: "room".to_owned(), - reason: body.reason.clone(), - score: None, - }; - - services.admin.send_message(build_report(report)).await.ok(); + // send admin room message that we received the report with an @room ping for + // urgency + services + .admin + .send_message(message::RoomMessageEventContent::text_markdown(format!( + "@room Room report received from {} -\n\nRoom ID: {}\n\nReport Reason: {}", + sender_user.to_owned(), + body.room_id, + body.reason.as_deref().unwrap_or("") + ))) + .await + .ok(); Ok(report_room::v3::Response {}) } @@ -91,10 +82,15 @@ pub(crate) async fn report_event_route( body: Ruma, ) -> Result { // user authentication - let sender_user = body.sender_user(); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + info!( + "Received event report by user {sender_user} for room {} and event ID {}, with reason: \ + \"{}\"", + body.room_id, + body.event_id, + body.reason.as_deref().unwrap_or("") + ); delay_response().await; @@ -113,73 +109,27 @@ pub(crate) async fn report_event_route( &pdu, ) .await?; - info!( - "Received event report by user {sender_user} for room {} and event ID {}, with reason: \ - \"{}\"", - body.room_id, - body.event_id, - body.reason.as_deref().unwrap_or("") - ); - let report = Report { - sender: sender_user.to_owned(), - room_id: Some(body.room_id.clone()), - event_id: Some(body.event_id.clone()), - user_id: None, - report_type: "event".to_owned(), - reason: body.reason.clone(), - score: body.score, - }; - services.admin.send_message(build_report(report)).await.ok(); + + // send admin room message that we received the report with an @room ping for + // urgency + services + .admin + .send_message(message::RoomMessageEventContent::text_markdown(format!( + "@room Event report received from {} -\n\nEvent ID: {}\nRoom ID: {}\nSent By: \ + {}\n\nReport Score: {}\nReport Reason: {}", + sender_user.to_owned(), + pdu.event_id, + pdu.room_id, + pdu.sender, + body.score.unwrap_or_else(|| ruma::Int::from(0)), + body.reason.as_deref().unwrap_or("") + ))) + .await + .ok(); Ok(report_content::v3::Response {}) } -#[tracing::instrument(skip_all, fields(%client), name = "report_user")] -pub(crate) async fn report_user_route( - State(services): State, - InsecureClientIp(client): InsecureClientIp, - body: Ruma, -) -> Result { - // user authentication - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if services.users.is_suspended(sender_user).await? { - return Err!(Request(UserSuspended("You cannot perform this action while suspended."))); - } - - if body.reason.as_ref().is_some_and(|s| s.len() > 750) { - return Err!(Request( - InvalidParam("Reason too long, should be 750 characters or fewer",) - )); - } - - delay_response().await; - - if !services.users.is_active_local(&body.user_id).await { - // return 200 as to not reveal if the user exists. Recommended by spec. - return Ok(report_user::v3::Response {}); - } - - let report = Report { - sender: sender_user.to_owned(), - room_id: None, - event_id: None, - user_id: Some(body.user_id.clone()), - report_type: "user".to_owned(), - reason: body.reason.clone(), - score: None, - }; - - info!( - "Received room report from {sender_user} for user {} with reason: \"{}\"", - body.user_id, - body.reason.as_deref().unwrap_or("") - ); - - services.admin.send_message(build_report(report)).await.ok(); - - Ok(report_user::v3::Response {}) -} - /// in the following order: /// /// check if the room ID from the URI matches the PDU's room ID @@ -201,16 +151,23 @@ async fn is_event_report_valid( ); if room_id != pdu.room_id { - return Err!(Request(NotFound("Event ID does not belong to the reported room",))); + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Event ID does not belong to the reported room", + )); } if score.is_some_and(|s| s > int!(0) || s < int!(-100)) { - return Err!(Request(InvalidParam("Invalid score, must be within 0 to -100",))); + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid score, must be within 0 to -100", + )); } if reason.as_ref().is_some_and(|s| s.len() > 750) { - return Err!(Request( - InvalidParam("Reason too long, should be 750 characters or fewer",) + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Reason too long, should be 750 characters or fewer", )); } @@ -221,35 +178,15 @@ async fn is_event_report_valid( .ready_any(|user_id| user_id == sender_user) .await { - return Err!(Request(NotFound("You are not in the room you are reporting.",))); + return Err(Error::BadRequest( + ErrorKind::NotFound, + "You are not in the room you are reporting.", + )); } Ok(()) } -/// Builds a report message to be sent to the admin room. -fn build_report(report: Report) -> RoomMessageEventContent { - let mut text = - format!("@room New {} report received from {}:\n\n", report.report_type, report.sender); - if report.user_id.is_some() { - let _ = writeln!(text, "- Reported User ID: `{}`", report.user_id.unwrap()); - } - if report.room_id.is_some() { - let _ = writeln!(text, "- Reported Room ID: `{}`", report.room_id.unwrap()); - } - if report.event_id.is_some() { - let _ = writeln!(text, "- Reported Event ID: `{}`", report.event_id.unwrap()); - } - if let Some(score) = report.score { - let _ = writeln!(text, "- User-supplied offensiveness score: {}%", score.mul(int!(-1))); - } - if let Some(reason) = report.reason { - let _ = writeln!(text, "- Report Reason: {reason}"); - } - - RoomMessageEventContent::text_markdown(text).add_mentions(Mentions::with_room_mention()) -} - /// even though this is kinda security by obscurity, let's still make a small /// random delay sending a response per spec suggestion regarding /// enumerating for potential events existing in our server. @@ -259,6 +196,5 @@ async fn delay_response() { "Got successful /report request, waiting {time_to_wait} seconds before sending \ successful response." ); - sleep(Duration::from_secs(time_to_wait)).await; } diff --git a/src/api/client/room/aliases.rs b/src/api/client/room/aliases.rs index 0b072b74..3f0016af 100644 --- a/src/api/client/room/aliases.rs +++ b/src/api/client/room/aliases.rs @@ -1,7 +1,7 @@ use axum::extract::State; -use conduwuit::{Err, Result}; +use conduwuit::{Error, Result}; use futures::StreamExt; -use ruma::api::client::room::aliases; +use ruma::api::client::{error::ErrorKind, room::aliases}; use crate::Ruma; @@ -15,7 +15,7 @@ pub(crate) async fn get_room_aliases_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if !services .rooms @@ -23,7 +23,10 @@ pub(crate) async fn get_room_aliases_route( .user_can_see_state_events(sender_user, &body.room_id) .await { - return Err!(Request(Forbidden("You don't have permission to view this room.",))); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "You don't have permission to view this room.", + )); } Ok(aliases::v3::Response { diff --git a/src/api/client/room/create.rs b/src/api/client/room/create.rs index 238691d1..d1dffc51 100644 --- a/src/api/client/room/create.rs +++ b/src/api/client/room/create.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use axum::extract::State; use conduwuit::{ - Err, Result, debug_info, debug_warn, err, info, + Err, Error, Result, debug_info, debug_warn, err, error, info, matrix::{StateKey, pdu::PduBuilder}, warn, }; @@ -10,7 +10,10 @@ use conduwuit_service::{Services, appservice::RegistrationInfo}; use futures::FutureExt; use ruma::{ CanonicalJsonObject, Int, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, - api::client::room::{self, create_room}, + api::client::{ + error::ErrorKind, + room::{self, create_room}, + }, events::{ TimelineEventType, room::{ @@ -55,13 +58,16 @@ pub(crate) async fn create_room_route( ) -> Result { use create_room::v3::RoomPreset; - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if !services.globals.allow_room_creation() && body.appservice_info.is_none() && !services.users.is_admin(sender_user).await { - return Err!(Request(Forbidden("Room creation has been disabled.",))); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Room creation has been disabled.", + )); } if services.users.is_suspended(sender_user).await? { @@ -75,7 +81,10 @@ pub(crate) async fn create_room_route( // check if room ID doesn't already exist instead of erroring on auth check if services.rooms.short.get_shortroomid(&room_id).await.is_ok() { - return Err!(Request(RoomInUse("Room with that custom room ID already exists",))); + return Err(Error::BadRequest( + ErrorKind::RoomInUse, + "Room with that custom room ID already exists", + )); } if body.visibility == room::Visibility::Public @@ -83,17 +92,19 @@ pub(crate) async fn create_room_route( && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { - warn!( - "Non-admin user {sender_user} tried to publish {room_id} to the room directory \ - while \"lockdown_public_room_directory\" is enabled" + info!( + "Non-admin user {sender_user} tried to publish {0} to the room directory while \ + \"lockdown_public_room_directory\" is enabled", + &room_id ); if services.server.config.admin_room_notices { services .admin - .notice(&format!( - "Non-admin user {sender_user} tried to publish {room_id} to the room \ - directory while \"lockdown_public_room_directory\" is enabled" + .send_text(&format!( + "Non-admin user {sender_user} tried to publish {0} to the room directory \ + while \"lockdown_public_room_directory\" is enabled", + &room_id )) .await; } @@ -118,9 +129,10 @@ pub(crate) async fn create_room_route( if services.server.supported_room_version(&room_version) { room_version } else { - return Err!(Request(UnsupportedRoomVersion( - "This server does not support that room version." - ))); + return Err(Error::BadRequest( + ErrorKind::UnsupportedRoomVersion, + "This server does not support that room version.", + )); }, | None => services.server.config.default_room_version.clone(), }; @@ -132,17 +144,16 @@ pub(crate) async fn create_room_route( let mut content = content .deserialize_as::() .map_err(|e| { - err!(Request(BadJson(error!( - "Failed to deserialise content as canonical JSON: {e}" - )))) + error!("Failed to deserialise content as canonical JSON: {}", e); + Error::bad_database("Failed to deserialise content as canonical JSON.") })?; - match room_version { | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { content.insert( "creator".into(), json!(&sender_user).try_into().map_err(|e| { - err!(Request(BadJson(debug_error!("Invalid creation content: {e}")))) + info!("Invalid creation content: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") })?, ); }, @@ -152,9 +163,9 @@ pub(crate) async fn create_room_route( } content.insert( "room_version".into(), - json!(room_version.as_str()) - .try_into() - .map_err(|e| err!(Request(BadJson("Invalid creation content: {e}"))))?, + json!(room_version.as_str()).try_into().map_err(|_| { + Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") + })?, ); content }, @@ -163,13 +174,21 @@ pub(crate) async fn create_room_route( let content = match room_version { | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => - RoomCreateEventContent::new_v1(sender_user.to_owned()), + RoomCreateEventContent::new_v1(sender_user.clone()), | _ => RoomCreateEventContent::new_v11(), }; - let mut content = - serde_json::from_str::(to_raw_value(&content)?.get()) - .unwrap(); - content.insert("room_version".into(), json!(room_version.as_str()).try_into()?); + let mut content = serde_json::from_str::( + to_raw_value(&content) + .expect("we just created this as content was None") + .get(), + ) + .unwrap(); + content.insert( + "room_version".into(), + json!(room_version.as_str()) + .try_into() + .expect("we just created this as content was None"), + ); content }, }; @@ -181,7 +200,8 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&create_content)?, + content: to_raw_value(&create_content) + .expect("create event content serialization"), state_key: Some(StateKey::new()), ..Default::default() }, @@ -219,7 +239,7 @@ pub(crate) async fn create_room_route( | _ => RoomPreset::PrivateChat, // Room visibility should not be custom }); - let mut users = BTreeMap::from_iter([(sender_user.to_owned(), int!(100))]); + let mut users = BTreeMap::from_iter([(sender_user.clone(), int!(100))]); if preset == RoomPreset::TrustedPrivateChat { for invite in &body.invite { @@ -247,7 +267,8 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content)?, + content: to_raw_value(&power_levels_content) + .expect("serialized power_levels event content"), state_key: Some(StateKey::new()), ..Default::default() }, @@ -336,7 +357,8 @@ pub(crate) async fn create_room_route( // 6. Events listed in initial_state for event in &body.initial_state { let mut pdu_builder = event.deserialize_as::().map_err(|e| { - err!(Request(InvalidParam(warn!("Invalid initial state event: {e:?}")))) + warn!("Invalid initial state event: {:?}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid initial state event.") })?; debug_info!("Room creation initial state event: {event:?}"); @@ -345,7 +367,7 @@ pub(crate) async fn create_room_route( // state event in there with the content of literally `{}` (not null or empty // string), let's just skip it over and warn. if pdu_builder.content.get().eq("{}") { - debug_warn!("skipping empty initial state event with content of `{{}}`: {event:?}"); + info!("skipping empty initial state event with content of `{{}}`: {event:?}"); debug_warn!("content: {}", pdu_builder.content.get()); continue; } @@ -492,7 +514,9 @@ fn default_power_levels_content( if let Some(power_level_content_override) = power_level_content_override { let json: JsonObject = serde_json::from_str(power_level_content_override.json().get()) - .map_err(|e| err!(Request(BadJson("Invalid power_level_content_override: {e:?}"))))?; + .map_err(|_| { + Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.") + })?; for (key, value) in json { power_levels_content[key] = value; @@ -510,14 +534,16 @@ async fn room_alias_check( ) -> Result { // Basic checks on the room alias validity if room_alias_name.contains(':') { - return Err!(Request(InvalidParam( + return Err(Error::BadRequest( + ErrorKind::InvalidParam, "Room alias contained `:` which is not allowed. Please note that this expects a \ localpart, not the full room alias.", - ))); + )); } else if room_alias_name.contains(char::is_whitespace) { - return Err!(Request(InvalidParam( + return Err(Error::BadRequest( + ErrorKind::InvalidParam, "Room alias contained spaces which is not a valid room alias.", - ))); + )); } // check if room alias is forbidden @@ -526,7 +552,7 @@ async fn room_alias_check( .forbidden_alias_names() .is_match(room_alias_name) { - return Err!(Request(Unknown("Room alias name is forbidden."))); + return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); } let server_name = services.globals.server_name(); @@ -546,19 +572,25 @@ async fn room_alias_check( .await .is_ok() { - return Err!(Request(RoomInUse("Room alias already exists."))); + return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); } if let Some(info) = appservice_info { if !info.aliases.is_match(full_room_alias.as_str()) { - return Err!(Request(Exclusive("Room alias is not in namespace."))); + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "Room alias is not in namespace.", + )); } } else if services .appservice .is_exclusive_alias(&full_room_alias) .await { - return Err!(Request(Exclusive("Room alias reserved by appservice.",))); + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "Room alias reserved by appservice.", + )); } debug_info!("Full room alias: {full_room_alias}"); @@ -574,33 +606,24 @@ fn custom_room_id_check(services: &Services, custom_room_id: &str) -> Result Result = services .rooms - .state_cache - .user_membership(body.sender_user(), room_id) - .map(Ok); + .timeline + .pdus_rev(room_id, None) + .try_take(limit) + .and_then(async |mut pdu| { + pdu.1.set_unsigned(body.sender_user.as_deref()); + if let Some(sender_user) = body.sender_user.as_deref() { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + } + Ok(pdu) + }) + .try_collect() + .await?; - let visibility = services.rooms.directory.visibility(room_id).map(Ok); - - let state = services + let state: Vec<_> = services .rooms .state_accessor .room_state_full_pdus(room_id) - .map_ok(Event::into_format) - .try_collect::>(); - - let limit = LIMIT_MAX; - let events = services - .rooms - .timeline - .pdus_rev(None, room_id, None) - .try_take(limit) - .try_collect::>(); - - let (membership, visibility, state, events) = - try_join4(membership, visibility, state, events) - .boxed() - .await?; + .map_ok(PduEvent::into_state_event) + .try_collect() + .await?; let messages = PaginationChunk { start: events.last().map(at!(0)).as_ref().map(ToString::to_string), @@ -66,7 +71,7 @@ pub(crate) async fn room_initial_sync_route( chunk: events .into_iter() .map(at!(1)) - .map(Event::into_format) + .map(PduEvent::into_room_event) .collect(), }; @@ -75,7 +80,11 @@ pub(crate) async fn room_initial_sync_route( account_data: None, state: state.into(), messages: messages.chunk.is_empty().or_some(messages), - visibility: visibility.into(), - membership, + visibility: services.rooms.directory.visibility(room_id).await.into(), + membership: services + .rooms + .state_cache + .user_membership(body.sender_user(), room_id) + .await, }) } diff --git a/src/api/client/room/summary.rs b/src/api/client/room/summary.rs index 635f5a8a..67d2e2ad 100644 --- a/src/api/client/room/summary.rs +++ b/src/api/client/room/summary.rs @@ -43,9 +43,10 @@ pub(crate) async fn get_room_summary_legacy( } /// # `GET /_matrix/client/unstable/im.nheko.summary/summary/{roomIdOrAlias}` -/// # `GET /_matrix/client/v1/room_summary/{roomIdOrAlias}` /// /// Returns a short description of the state of a room. +/// +/// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266) #[tracing::instrument(skip_all, fields(%client), name = "room_summary")] pub(crate) async fn get_room_summary( State(services): State, @@ -112,15 +113,13 @@ async fn local_room_summary_response( ) -> Result { trace!(?sender_user, "Sending local room summary response for {room_id:?}"); let join_rule = services.rooms.state_accessor.get_join_rules(room_id); - let world_readable = services.rooms.state_accessor.is_world_readable(room_id); - let guest_can_join = services.rooms.state_accessor.guest_can_join(room_id); let (join_rule, world_readable, guest_can_join) = join3(join_rule, world_readable, guest_can_join).await; - trace!("{join_rule:?}, {world_readable:?}, {guest_can_join:?}"); + user_can_see_summary( services, room_id, diff --git a/src/api/client/room/upgrade.rs b/src/api/client/room/upgrade.rs index ae632235..d8f5ea83 100644 --- a/src/api/client/room/upgrade.rs +++ b/src/api/client/room/upgrade.rs @@ -2,7 +2,7 @@ use std::cmp::max; use axum::extract::State; use conduwuit::{ - Err, Error, Event, Result, err, info, + Err, Error, Result, err, info, matrix::{StateKey, pdu::PduBuilder}, }; use futures::StreamExt; @@ -215,7 +215,7 @@ pub(crate) async fn upgrade_room_route( .room_state_get(&body.room_id, event_type, "") .await { - | Ok(v) => v.content().to_owned(), + | Ok(v) => v.content.clone(), | Err(_) => continue, // Skipping missing events. }; diff --git a/src/api/client/search.rs b/src/api/client/search.rs index cc745694..af5fccec 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -2,8 +2,8 @@ use std::collections::BTreeMap; use axum::extract::State; use conduwuit::{ - Err, Result, at, is_true, - matrix::Event, + Err, Result, at, debug_warn, is_true, + matrix::pdu::PduEvent, result::FlatOk, utils::{IterStream, stream::ReadyExt}, }; @@ -144,7 +144,18 @@ async fn category_room_events( .map(at!(2)) .flatten() .stream() - .map(Event::into_format) + .then(|mut pdu| async { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu) + .await + { + debug_warn!("Failed to add bundled aggregations to search result: {e}"); + } + pdu + }) + .map(PduEvent::into_room_event) .map(|result| SearchResult { rank: None, result: Some(result), @@ -185,7 +196,7 @@ async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + services .users - .remove_device(body.sender_user(), body.sender_device()) + .remove_device(sender_user, sender_device) .await; Ok(logout::v3::Response::new()) @@ -360,10 +365,12 @@ pub(crate) async fn logout_all_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + services .users - .all_device_ids(body.sender_user()) - .for_each(|device_id| services.users.remove_device(body.sender_user(), device_id)) + .all_device_ids(sender_user) + .for_each(|device_id| services.users.remove_device(sender_user, device_id)) .await; Ok(logout_all::v3::Response::new()) diff --git a/src/api/client/state.rs b/src/api/client/state.rs index c0f5fe7c..d0b210c8 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,11 +1,12 @@ use axum::extract::State; use conduwuit::{ Err, Result, err, - matrix::{Event, pdu::PduBuilder}, + matrix::pdu::{PduBuilder, PduEvent}, utils::BoolExt, }; use conduwuit_service::Services; -use futures::{FutureExt, TryStreamExt}; +use futures::TryStreamExt; +use futures::FutureExt; use ruma::{ OwnedEventId, RoomId, UserId, api::client::state::{get_state_events, get_state_events_for_key, send_state_event}, @@ -21,7 +22,6 @@ use ruma::{ }, serde::Raw, }; -use serde_json::json; use crate::{Ruma, RumaResponse}; @@ -79,7 +79,7 @@ pub(crate) async fn get_state_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if !services .rooms @@ -95,7 +95,7 @@ pub(crate) async fn get_state_events_route( .rooms .state_accessor .room_state_full_pdus(&body.room_id) - .map_ok(Event::into_format) + .map_ok(PduEvent::into_state_event) .try_collect() .await?, }) @@ -146,18 +146,7 @@ pub(crate) async fn get_state_events_for_key_route( Ok(get_state_events_for_key::v3::Response { content: event_format.or(|| event.get_content_as_value()), - event: event_format.then(|| { - json!({ - "content": event.content(), - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "room_id": event.room_id(), - "sender": event.sender(), - "state_key": event.state_key(), - "type": event.kind(), - "unsigned": event.unsigned(), - }) - }), + event: event_format.then(|| event.into_state_event_value()), }) } diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index 40370160..1ea62883 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -3,7 +3,7 @@ mod v4; mod v5; use conduwuit::{ - Error, PduCount, Result, + Error, PduCount, Result, debug_warn, matrix::pdu::PduEvent, utils::stream::{BroadbandExt, ReadyExt, TryIgnore}, }; @@ -31,11 +31,7 @@ async fn load_timeline( next_batch: Option, limit: usize, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { - let last_timeline_count = services - .rooms - .timeline - .last_timeline_count(Some(sender_user), room_id) - .await?; + let last_timeline_count = services.rooms.timeline.last_timeline_count(room_id).await?; if last_timeline_count <= roomsincecount { return Ok((Vec::new(), false)); @@ -44,10 +40,25 @@ async fn load_timeline( let non_timeline_pdus = services .rooms .timeline - .pdus_rev(Some(sender_user), room_id, None) + .pdus_rev(room_id, None) .ignore_err() .ready_skip_while(|&(pducount, _)| pducount > next_batch.unwrap_or_else(PduCount::max)) - .ready_take_while(|&(pducount, _)| pducount > roomsincecount); + .ready_take_while(|&(pducount, _)| pducount > roomsincecount) + .map(move |mut pdu| { + pdu.1.set_unsigned(Some(sender_user)); + pdu + }) + .then(async move |mut pdu| { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + pdu + }); // Take the last events for the timeline pin_mut!(non_timeline_pdus); diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 01428c08..7eb8c7e1 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -473,7 +473,9 @@ async fn handle_left_room( prev_batch: Some(next_batch.to_string()), events: Vec::new(), }, - state: RoomState { events: vec![event.into_format()] }, + state: RoomState { + events: vec![event.into_sync_state_event()], + }, })); } @@ -557,7 +559,7 @@ async fn handle_left_room( continue; } - left_state_events.push(pdu.into_format()); + left_state_events.push(pdu.into_sync_state_event()); } } @@ -643,7 +645,7 @@ async fn load_joined_room( let lazy_loading_context = &lazy_loading::Context { user_id: sender_user, - device_id: Some(sender_device), + device_id: sender_device, room_id, token: Some(since), options: Some(&filter.room.state.lazy_load_options), @@ -753,7 +755,7 @@ async fn load_joined_room( .wide_filter_map(|item| ignored_filter(services, item, sender_user)) .map(at!(1)) .chain(joined_sender_member.into_iter().stream()) - .map(Event::into_format) + .map(|pdu| pdu.to_sync_room_event()) .collect::>(); let account_data_events = services @@ -875,7 +877,10 @@ async fn load_joined_room( events: room_events, }, state: RoomState { - events: state_events.into_iter().map(Event::into_format).collect(), + events: state_events + .into_iter() + .map(PduEvent::into_sync_state_event) + .collect(), }, ephemeral: Ephemeral { events: edus }, unread_thread_notifications: BTreeMap::new(), @@ -1184,7 +1189,7 @@ async fn calculate_heroes( services .rooms .timeline - .all_pdus(sender_user, room_id) + .all_pdus(room_id) .ready_filter(|(_, pdu)| pdu.kind == RoomMember) .fold_default(|heroes: Vec<_>, (_, pdu)| { fold_hero(heroes, services, room_id, sender_user, pdu) diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 14cd50d8..f153b2da 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -6,7 +6,7 @@ use std::{ use axum::extract::State; use conduwuit::{ - Err, Error, Event, PduCount, Result, at, debug, error, extract_variant, + Err, Error, PduCount, PduEvent, Result, debug, error, extract_variant, matrix::TypeStateKey, utils::{ BoolExt, IterStream, ReadyExt, TryFutureExtExt, @@ -604,8 +604,7 @@ pub(crate) async fn sync_events_v4_route( .iter() .stream() .filter_map(|item| ignored_filter(&services, item.clone(), sender_user)) - .map(at!(1)) - .map(Event::into_format) + .map(|(_, pdu)| pdu.to_sync_room_event()) .collect() .await; @@ -627,7 +626,7 @@ pub(crate) async fn sync_events_v4_route( .state_accessor .room_state_get(room_id, &state.0, &state.1) .await - .map(Event::into_format) + .map(PduEvent::into_sync_state_event) .ok() }) .collect() diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index e4cefba0..f3fc0f44 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -7,8 +7,11 @@ use std::{ use axum::extract::State; use conduwuit::{ - Err, Error, Result, at, error, extract_variant, is_equal_to, - matrix::{Event, TypeStateKey, pdu::PduCount}, + Err, Error, Result, error, extract_variant, is_equal_to, + matrix::{ + TypeStateKey, + pdu::{PduCount, PduEvent}, + }, trace, utils::{ BoolExt, FutureBoolExt, IterStream, ReadyExt, TryFutureExtExt, @@ -512,8 +515,7 @@ where .iter() .stream() .filter_map(|item| ignored_filter(services, item.clone(), sender_user)) - .map(at!(1)) - .map(Event::into_format) + .map(|(_, pdu)| pdu.to_sync_room_event()) .collect() .await; @@ -535,7 +537,7 @@ where .state_accessor .room_state_get(room_id, &state.0, &state.1) .await - .map(Event::into_format) + .map(PduEvent::into_sync_state_event) .ok() }) .collect() diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index 68105e4f..caafe10d 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -21,7 +21,7 @@ pub(crate) async fn update_tag_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut tags_event = services .account_data @@ -42,7 +42,7 @@ pub(crate) async fn update_tag_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event)?, + &serde_json::to_value(tags_event).expect("to json value always works"), ) .await?; @@ -58,7 +58,7 @@ pub(crate) async fn delete_tag_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut tags_event = services .account_data @@ -76,7 +76,7 @@ pub(crate) async fn delete_tag_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event)?, + &serde_json::to_value(tags_event).expect("to json value always works"), ) .await?; @@ -92,7 +92,7 @@ pub(crate) async fn get_tags_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let tags_event = services .account_data diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index ca176eda..09fb75d6 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,10 +1,7 @@ use axum::extract::State; use conduwuit::{ - Result, at, - matrix::{ - Event, - pdu::{PduCount, PduEvent}, - }, + Result, at, debug_warn, + matrix::pdu::{PduCount, PduEvent}, }; use futures::StreamExt; use ruma::{api::client::threads::get_threads, uint}; @@ -31,6 +28,8 @@ pub(crate) async fn get_threads_route( .transpose()? .unwrap_or_else(PduCount::max); + // TODO: user_can_see_event and set_unsigned should be at the same level / + // function, so unsigned is only set for seen events. let threads: Vec<(PduCount, PduEvent)> = services .rooms .threads @@ -45,6 +44,17 @@ pub(crate) async fn get_threads_route( .await .then_some((count, pdu)) }) + .then(|(count, mut pdu)| async move { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(body.sender_user(), &mut pdu) + .await + { + debug_warn!("Failed to add bundled aggregations to thread: {e}"); + } + (count, pdu) + }) .collect() .await; @@ -59,7 +69,7 @@ pub(crate) async fn get_threads_route( chunk: threads .into_iter() .map(at!(1)) - .map(Event::into_format) + .map(PduEvent::into_room_event) .collect(), }) } diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 581f4a72..8ad9dc99 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -21,7 +21,7 @@ pub(crate) async fn send_event_to_device_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); // Check if this is a new transaction id @@ -47,7 +47,7 @@ pub(crate) async fn send_event_to_device_route( serde_json::to_writer( &mut buf, &federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { - sender: sender_user.to_owned(), + sender: sender_user.clone(), ev_type: body.event_type.clone(), message_id: count.to_string().into(), messages, diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index 08f70975..e21eaf21 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -69,7 +69,7 @@ pub(crate) async fn delete_timezone_key_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if *sender_user != body.user_id && body.appservice_info.is_none() { return Err!(Request(Forbidden("You cannot update the profile of another user"))); @@ -97,7 +97,7 @@ pub(crate) async fn set_timezone_key_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if *sender_user != body.user_id && body.appservice_info.is_none() { return Err!(Request(Forbidden("You cannot update the profile of another user"))); @@ -125,7 +125,7 @@ pub(crate) async fn set_profile_key_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if *sender_user != body.user_id && body.appservice_info.is_none() { return Err!(Request(Forbidden("You cannot update the profile of another user"))); @@ -218,7 +218,7 @@ pub(crate) async fn delete_profile_key_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if *sender_user != body.user_id && body.appservice_info.is_none() { return Err!(Request(Forbidden("You cannot update the profile of another user"))); diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index a4136d1a..232d5b28 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -37,11 +37,7 @@ pub(crate) async fn get_supported_versions_route( "v1.3".to_owned(), "v1.4".to_owned(), "v1.5".to_owned(), - "v1.8".to_owned(), "v1.11".to_owned(), - "v1.12".to_owned(), - "v1.13".to_owned(), - "v1.14".to_owned(), ], unstable_features: BTreeMap::from_iter([ ("org.matrix.e2e_cross_signing".to_owned(), true), diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 9a1f86b8..748fc049 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,7 +1,10 @@ use axum::extract::State; use conduwuit::{ Result, - utils::{future::BoolExt, stream::BroadbandExt}, + utils::{ + future::BoolExt, + stream::{BroadbandExt, ReadyExt}, + }, }; use futures::{FutureExt, StreamExt, pin_mut}; use ruma::{ @@ -34,18 +37,17 @@ pub(crate) async fn search_users_route( let mut users = services .users .stream() + .ready_filter(|user_id| user_id.as_str().to_lowercase().contains(&search_term)) .map(ToOwned::to_owned) .broad_filter_map(async |user_id| { let display_name = services.users.displayname(&user_id).await.ok(); - let user_id_matches = user_id.as_str().to_lowercase().contains(&search_term); - let display_name_matches = display_name .as_deref() .map(str::to_lowercase) .is_some_and(|display_name| display_name.contains(&search_term)); - if !user_id_matches && !display_name_matches { + if !display_name_matches { return None; } diff --git a/src/api/router.rs b/src/api/router.rs index d1b05a91..5416e9e9 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -94,7 +94,6 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::redact_event_route) .ruma_route(&client::report_event_route) .ruma_route(&client::report_room_route) - .ruma_route(&client::report_user_route) .ruma_route(&client::create_alias_route) .ruma_route(&client::delete_alias_route) .ruma_route(&client::get_alias_route) diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 3cfbcedc..058fc273 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -3,6 +3,7 @@ use std::cmp; use axum::extract::State; use conduwuit::{ PduCount, Result, + result::LogErr, utils::{IterStream, ReadyExt, stream::TryTools}, }; use futures::{FutureExt, StreamExt, TryStreamExt}; @@ -62,7 +63,7 @@ pub(crate) async fn get_backfill_route( pdus: services .rooms .timeline - .pdus_rev(None, &body.room_id, Some(from.saturating_add(1))) + .pdus_rev(&body.room_id, Some(from.saturating_add(1))) .try_take(limit) .try_filter_map(|(_, pdu)| async move { Ok(services @@ -72,6 +73,15 @@ pub(crate) async fn get_backfill_route( .await .then_some(pdu)) }) + .and_then(async |mut pdu| { + // Strip the transaction ID, as that is private + pdu.remove_transaction_id().log_err().ok(); + // Add age, as this is specified + pdu.add_age().log_err().ok(); + // It's not clear if we should strip or add any more data, leave as is. + // In particular: Redaction? + Ok(pdu) + }) .try_filter_map(|pdu| async move { Ok(services .rooms diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 0a9b2e10..f53e1a15 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -2,10 +2,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use base64::{Engine as _, engine::general_purpose}; use conduwuit::{ - Err, Error, PduEvent, Result, err, - matrix::{Event, event::gen_event_id}, - utils::{self, hash::sha256}, - warn, + Err, Error, PduEvent, Result, err, pdu::gen_event_id, utils, utils::hash::sha256, warn, }; use ruma::{ CanonicalJsonValue, OwnedUserId, UserId, @@ -59,7 +56,7 @@ pub(crate) async fn create_invite_route( } let mut signed_event = utils::to_canonical_object(&body.event) - .map_err(|_| err!(Request(InvalidParam("Invite event is invalid."))))?; + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; let invited_user: OwnedUserId = signed_event .get("state_key") @@ -114,7 +111,7 @@ pub(crate) async fn create_invite_route( let pdu: PduEvent = serde_json::from_value(event.into()) .map_err(|e| err!(Request(BadJson("Invalid invite event PDU: {e}"))))?; - invite_state.push(pdu.to_format()); + invite_state.push(pdu.to_stripped_state_event()); // If we are active in the room, the remote server will notify us about the // join/invite through /send. If we are not in the room, we need to manually @@ -147,7 +144,7 @@ pub(crate) async fn create_invite_route( .send_appservice_request( appservice.registration.clone(), ruma::api::appservice::event::push_events::v1::Request { - events: vec![pdu.to_format()], + events: vec![pdu.to_room_event()], txn_id: general_purpose::URL_SAFE_NO_PAD .encode(sha256::hash(pdu.event_id.as_bytes())) .into(), diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 652451c7..895eca81 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -5,7 +5,7 @@ use std::borrow::Borrow; use axum::extract::State; use conduwuit::{ Err, Result, at, err, - matrix::event::gen_event_id_canonical_json, + pdu::gen_event_id_canonical_json, utils::stream::{IterStream, TryBroadbandExt}, warn, }; diff --git a/src/api/server/send_knock.rs b/src/api/server/send_knock.rs index ffd41ada..8d3697d2 100644 --- a/src/api/server/send_knock.rs +++ b/src/api/server/send_knock.rs @@ -1,7 +1,7 @@ use axum::extract::State; use conduwuit::{ Err, Result, err, - matrix::{event::gen_event_id_canonical_json, pdu::PduEvent}, + matrix::pdu::{PduEvent, gen_event_id_canonical_json}, warn, }; use futures::FutureExt; diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index b6336e1a..d3dc994c 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -1,7 +1,7 @@ #![allow(deprecated)] use axum::extract::State; -use conduwuit::{Err, Result, err, matrix::event::gen_event_id_canonical_json}; +use conduwuit::{Err, Result, err, matrix::pdu::gen_event_id_canonical_json}; use conduwuit_service::Services; use futures::FutureExt; use ruma::{ diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 8ef580f1..e19a5974 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -513,22 +513,6 @@ pub struct Config { #[serde(default)] pub allow_registration: bool, - /// If registration is enabled, and this setting is true, new users - /// registered after the first admin user will be automatically suspended - /// and will require an admin to run `!admin users unsuspend `. - /// - /// Suspended users are still able to read messages, make profile updates, - /// leave rooms, and deactivate their account, however cannot send messages, - /// invites, or create/join or otherwise modify rooms. - /// They are effectively read-only. - /// - /// If you want to use this to screen people who register on your server, - /// you should add a room to `auto_join_rooms` that is public, and contains - /// information that new users can read (since they won't be able to DM - /// anyone, or send a message, and may be confused). - #[serde(default)] - pub suspend_on_register: bool, - /// Enabling this setting opens registration to anyone without restrictions. /// This makes your server vulnerable to abuse #[serde(default)] diff --git a/src/core/config/proxy.rs b/src/core/config/proxy.rs index 77c4531a..ea388f24 100644 --- a/src/core/config/proxy.rs +++ b/src/core/config/proxy.rs @@ -88,7 +88,10 @@ impl PartialProxyConfig { } } match (included_because, excluded_because) { - | (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), + | (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for + * a more specific + * reason */ + // than excluded | (Some(_), None) => Some(&self.url), | _ => None, } diff --git a/src/core/info/cargo.rs b/src/core/info/cargo.rs index 61a97508..e70bdcd5 100644 --- a/src/core/info/cargo.rs +++ b/src/core/info/cargo.rs @@ -84,12 +84,10 @@ fn append_features(features: &mut Vec, manifest: &str) -> Result<()> { fn init_dependencies() -> Result { let manifest = Manifest::from_str(WORKSPACE_MANIFEST)?; - let deps_set = manifest + Ok(manifest .workspace .as_ref() .expect("manifest has workspace section") .dependencies - .clone(); - - Ok(deps_set) + .clone()) } diff --git a/src/core/matrix/event.rs b/src/core/matrix/event.rs index a1d1339e..e4c478cd 100644 --- a/src/core/matrix/event.rs +++ b/src/core/matrix/event.rs @@ -1,188 +1,63 @@ -mod content; -mod filter; -mod format; -mod id; -mod redact; -mod relation; -mod type_ext; -mod unsigned; - -use std::fmt::Debug; - -use ruma::{ - CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, - RoomVersionId, UserId, events::TimelineEventType, -}; -use serde::Deserialize; -use serde_json::{Value as JsonValue, value::RawValue as RawJsonValue}; - -pub use self::{filter::Matches, id::*, relation::RelationTypeEqual, type_ext::TypeExt}; -use super::{pdu::Pdu, state_key::StateKey}; -use crate::{Result, utils}; +use ruma::{EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId, events::TimelineEventType}; +use serde_json::value::RawValue as RawJsonValue; /// Abstraction of a PDU so users can have their own PDU types. -pub trait Event: Clone + Debug { - /// Serialize into a Ruma JSON format, consuming. - #[inline] - fn into_format(self) -> T - where - T: From>, - Self: Sized, - { - format::Owned(self).into() - } - - /// Serialize into a Ruma JSON format - #[inline] - fn to_format<'a, T>(&'a self) -> T - where - T: From>, - Self: Sized + 'a, - { - format::Ref(self).into() - } - - #[inline] - fn contains_unsigned_property(&self, property: &str, is_type: T) -> bool - where - T: FnOnce(&JsonValue) -> bool, - Self: Sized, - { - unsigned::contains_unsigned_property::(self, property, is_type) - } - - #[inline] - fn get_unsigned_property(&self, property: &str) -> Result - where - T: for<'de> Deserialize<'de>, - Self: Sized, - { - unsigned::get_unsigned_property::(self, property) - } - - #[inline] - fn get_unsigned_as_value(&self) -> JsonValue - where - Self: Sized, - { - unsigned::get_unsigned_as_value(self) - } - - #[inline] - fn get_unsigned(&self) -> Result - where - T: for<'de> Deserialize<'de>, - Self: Sized, - { - unsigned::get_unsigned::(self) - } - - #[inline] - fn get_content_as_value(&self) -> JsonValue - where - Self: Sized, - { - content::as_value(self) - } - - #[inline] - fn get_content(&self) -> Result - where - for<'de> T: Deserialize<'de>, - Self: Sized, - { - content::get::(self) - } - - #[inline] - fn redacts_id(&self, room_version: &RoomVersionId) -> Option - where - Self: Sized, - { - redact::redacts_id(self, room_version) - } - - #[inline] - fn is_redacted(&self) -> bool - where - Self: Sized, - { - redact::is_redacted(self) - } - - #[inline] - fn into_canonical_object(self) -> CanonicalJsonObject - where - Self: Sized, - { - utils::to_canonical_object(self.into_pdu()).expect("failed to create Value::Object") - } - - #[inline] - fn to_canonical_object(&self) -> CanonicalJsonObject { - utils::to_canonical_object(self.as_pdu()).expect("failed to create Value::Object") - } - - #[inline] - fn into_value(self) -> JsonValue - where - Self: Sized, - { - serde_json::to_value(self.into_pdu()).expect("failed to create JSON Value") - } - - #[inline] - fn to_value(&self) -> JsonValue { - serde_json::to_value(self.as_pdu()).expect("failed to create JSON Value") - } - - #[inline] - fn as_mut_pdu(&mut self) -> &mut Pdu { unimplemented!("not a mutable Pdu") } - - fn as_pdu(&self) -> &Pdu; - - fn into_pdu(self) -> Pdu; - - fn is_owned(&self) -> bool; - - // - // Canonical properties - // - - /// All the authenticating events for this event. - fn auth_events(&self) -> impl DoubleEndedIterator + Clone + Send + '_; - - /// The event's content. - fn content(&self) -> &RawJsonValue; - +pub trait Event { /// The `EventId` of this event. fn event_id(&self) -> &EventId; - /// The time of creation on the originating server. - fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch; - - /// The events before this event. - fn prev_events(&self) -> impl DoubleEndedIterator + Clone + Send + '_; - - /// If this event is a redaction event this is the event it redacts. - fn redacts(&self) -> Option<&EventId>; - /// The `RoomId` of this event. fn room_id(&self) -> &RoomId; /// The `UserId` of this event. fn sender(&self) -> &UserId; + /// The time of creation on the originating server. + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch; + + /// The event type. + fn event_type(&self) -> &TimelineEventType; + + /// The event's content. + fn content(&self) -> &RawJsonValue; + /// The state key for this event. fn state_key(&self) -> Option<&str>; - /// The event type. - fn kind(&self) -> &TimelineEventType; + /// The events before this event. + // Requires GATs to avoid boxing (and TAIT for making it convenient). + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_; - /// Metadata container; peer-trusted only. - fn unsigned(&self) -> Option<&RawJsonValue>; + /// All the authenticating events for this event. + // Requires GATs to avoid boxing (and TAIT for making it convenient). + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_; - //#[deprecated] - #[inline] - fn event_type(&self) -> &TimelineEventType { self.kind() } + /// If this event is a redaction event this is the event it redacts. + fn redacts(&self) -> Option<&EventId>; +} + +impl Event for &T { + fn event_id(&self) -> &EventId { (*self).event_id() } + + fn room_id(&self) -> &RoomId { (*self).room_id() } + + fn sender(&self) -> &UserId { (*self).sender() } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { (*self).origin_server_ts() } + + fn event_type(&self) -> &TimelineEventType { (*self).event_type() } + + fn content(&self) -> &RawJsonValue { (*self).content() } + + fn state_key(&self) -> Option<&str> { (*self).state_key() } + + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { + (*self).prev_events() + } + + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { + (*self).auth_events() + } + + fn redacts(&self) -> Option<&EventId> { (*self).redacts() } } diff --git a/src/core/matrix/event/content.rs b/src/core/matrix/event/content.rs deleted file mode 100644 index 1ee7ebd2..00000000 --- a/src/core/matrix/event/content.rs +++ /dev/null @@ -1,21 +0,0 @@ -use serde::Deserialize; -use serde_json::value::Value as JsonValue; - -use super::Event; -use crate::{Result, err}; - -#[inline] -#[must_use] -pub(super) fn as_value(event: &E) -> JsonValue { - get(event).expect("Failed to represent Event content as JsonValue") -} - -#[inline] -pub(super) fn get(event: &E) -> Result -where - T: for<'de> Deserialize<'de>, - E: Event, -{ - serde_json::from_str(event.content().get()) - .map_err(|e| err!(Request(BadJson("Failed to deserialize content into type: {e}")))) -} diff --git a/src/core/matrix/event/filter.rs b/src/core/matrix/event/filter.rs deleted file mode 100644 index d3a225b6..00000000 --- a/src/core/matrix/event/filter.rs +++ /dev/null @@ -1,93 +0,0 @@ -use ruma::api::client::filter::{RoomEventFilter, UrlFilter}; -use serde_json::Value; - -use super::Event; -use crate::is_equal_to; - -pub trait Matches { - fn matches(&self, event: &E) -> bool; -} - -impl Matches for &RoomEventFilter { - #[inline] - fn matches(&self, event: &E) -> bool { - if !matches_sender(event, self) { - return false; - } - - if !matches_room(event, self) { - return false; - } - - if !matches_type(event, self) { - return false; - } - - if !matches_url(event, self) { - return false; - } - - true - } -} - -fn matches_room(event: &E, filter: &RoomEventFilter) -> bool { - if filter.not_rooms.iter().any(is_equal_to!(event.room_id())) { - return false; - } - - if let Some(rooms) = filter.rooms.as_ref() { - if !rooms.iter().any(is_equal_to!(event.room_id())) { - return false; - } - } - - true -} - -fn matches_sender(event: &E, filter: &RoomEventFilter) -> bool { - if filter.not_senders.iter().any(is_equal_to!(event.sender())) { - return false; - } - - if let Some(senders) = filter.senders.as_ref() { - if !senders.iter().any(is_equal_to!(event.sender())) { - return false; - } - } - - true -} - -fn matches_type(event: &E, filter: &RoomEventFilter) -> bool { - let kind = event.kind().to_cow_str(); - - if filter.not_types.iter().any(is_equal_to!(&kind)) { - return false; - } - - if let Some(types) = filter.types.as_ref() { - if !types.iter().any(is_equal_to!(&kind)) { - return false; - } - } - - true -} - -fn matches_url(event: &E, filter: &RoomEventFilter) -> bool { - let Some(url_filter) = filter.url_filter.as_ref() else { - return true; - }; - - //TODO: might be better to use Ruma's Raw rather than serde here - let url = event - .get_content_as_value() - .get("url") - .is_some_and(Value::is_string); - - match url_filter { - | UrlFilter::EventsWithUrl => url, - | UrlFilter::EventsWithoutUrl => !url, - } -} diff --git a/src/core/matrix/event/format.rs b/src/core/matrix/event/format.rs deleted file mode 100644 index 988cf4f0..00000000 --- a/src/core/matrix/event/format.rs +++ /dev/null @@ -1,219 +0,0 @@ -use ruma::{ - events::{ - AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, - AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, room::member::RoomMemberEventContent, - space::child::HierarchySpaceChildEvent, - }, - serde::Raw, -}; -use serde_json::json; - -use super::{Event, redact}; - -pub struct Owned(pub(super) E); - -pub struct Ref<'a, E: Event>(pub(super) &'a E); - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let (redacts, content) = redact::copy(event); - let mut json = json!({ - "content": content, - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "sender": event.sender(), - "type": event.event_type(), - }); - - if let Some(redacts) = redacts { - json["redacts"] = json!(redacts); - } - if let Some(state_key) = event.state_key() { - json["state_key"] = json!(state_key); - } - if let Some(unsigned) = event.unsigned() { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let (redacts, content) = redact::copy(event); - let mut json = json!({ - "content": content, - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "room_id": event.room_id(), - "sender": event.sender(), - "type": event.kind(), - }); - - if let Some(redacts) = redacts { - json["redacts"] = json!(redacts); - } - if let Some(state_key) = event.state_key() { - json["state_key"] = json!(state_key); - } - if let Some(unsigned) = event.unsigned() { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let (redacts, content) = redact::copy(event); - let mut json = json!({ - "content": content, - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "room_id": event.room_id(), - "sender": event.sender(), - "type": event.kind(), - }); - - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - if let Some(state_key) = event.state_key() { - json["state_key"] = json!(state_key); - } - if let Some(unsigned) = event.unsigned() { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let mut json = json!({ - "content": event.content(), - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "room_id": event.room_id(), - "sender": event.sender(), - "state_key": event.state_key(), - "type": event.kind(), - }); - - if let Some(unsigned) = event.unsigned() { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let mut json = json!({ - "content": event.content(), - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "sender": event.sender(), - "state_key": event.state_key(), - "type": event.kind(), - }); - - if let Some(unsigned) = event.unsigned() { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let json = json!({ - "content": event.content(), - "sender": event.sender(), - "state_key": event.state_key(), - "type": event.kind(), - }); - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let json = json!({ - "content": event.content(), - "origin_server_ts": event.origin_server_ts(), - "sender": event.sender(), - "state_key": event.state_key(), - "type": event.kind(), - }); - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} - -impl From> for Raw> { - fn from(event: Owned) -> Self { Ref(&event.0).into() } -} - -impl<'a, E: Event> From> for Raw> { - fn from(event: Ref<'a, E>) -> Self { - let event = event.0; - let mut json = json!({ - "content": event.content(), - "event_id": event.event_id(), - "origin_server_ts": event.origin_server_ts(), - "redacts": event.redacts(), - "room_id": event.room_id(), - "sender": event.sender(), - "state_key": event.state_key(), - "type": event.kind(), - }); - - if let Some(unsigned) = event.unsigned() { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Failed to serialize Event value") - } -} diff --git a/src/core/matrix/event/redact.rs b/src/core/matrix/event/redact.rs deleted file mode 100644 index 5deac874..00000000 --- a/src/core/matrix/event/redact.rs +++ /dev/null @@ -1,86 +0,0 @@ -use ruma::{ - OwnedEventId, RoomVersionId, - events::{TimelineEventType, room::redaction::RoomRedactionEventContent}, -}; -use serde::Deserialize; -use serde_json::value::{RawValue as RawJsonValue, to_raw_value}; - -use super::Event; - -/// Copies the `redacts` property of the event to the `content` dict and -/// vice-versa. -/// -/// This follows the specification's -/// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): -/// -/// > For backwards-compatibility with older clients, servers should add a -/// > redacts property to the top level of m.room.redaction events in when -/// > serving such events over the Client-Server API. -/// -/// > For improved compatibility with newer clients, servers should add a -/// > redacts property to the content of m.room.redaction events in older -/// > room versions when serving such events over the Client-Server API. -#[must_use] -pub(super) fn copy(event: &E) -> (Option, Box) { - if *event.event_type() != TimelineEventType::RoomRedaction { - return (event.redacts().map(ToOwned::to_owned), event.content().to_owned()); - } - - let Ok(mut content) = event.get_content::() else { - return (event.redacts().map(ToOwned::to_owned), event.content().to_owned()); - }; - - if let Some(redacts) = content.redacts { - return (Some(redacts), event.content().to_owned()); - } - - if let Some(redacts) = event.redacts().map(ToOwned::to_owned) { - content.redacts = Some(redacts); - return ( - event.redacts().map(ToOwned::to_owned), - to_raw_value(&content).expect("Must be valid, we only added redacts field"), - ); - } - - (event.redacts().map(ToOwned::to_owned), event.content().to_owned()) -} - -#[must_use] -pub(super) fn is_redacted(event: &E) -> bool { - let Some(unsigned) = event.unsigned() else { - return false; - }; - - let Ok(unsigned) = ExtractRedactedBecause::deserialize(unsigned) else { - return false; - }; - - unsigned.redacted_because.is_some() -} - -#[must_use] -pub(super) fn redacts_id( - event: &E, - room_version: &RoomVersionId, -) -> Option { - use RoomVersionId::*; - - if *event.kind() != TimelineEventType::RoomRedaction { - return None; - } - - match *room_version { - | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => - event.redacts().map(ToOwned::to_owned), - | _ => - event - .get_content::() - .ok()? - .redacts, - } -} - -#[derive(Deserialize)] -struct ExtractRedactedBecause { - redacted_because: Option, -} diff --git a/src/core/matrix/event/relation.rs b/src/core/matrix/event/relation.rs deleted file mode 100644 index 58324e86..00000000 --- a/src/core/matrix/event/relation.rs +++ /dev/null @@ -1,28 +0,0 @@ -use ruma::events::relation::RelationType; -use serde::Deserialize; - -use super::Event; - -pub trait RelationTypeEqual { - fn relation_type_equal(&self, event: &E) -> bool; -} - -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelatesToEventId { - #[serde(rename = "m.relates_to")] - relates_to: ExtractRelType, -} - -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelType { - rel_type: RelationType, -} - -impl RelationTypeEqual for RelationType { - fn relation_type_equal(&self, event: &E) -> bool { - event - .get_content() - .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) - .is_ok_and(|r| r == *self) - } -} diff --git a/src/core/matrix/event/type_ext.rs b/src/core/matrix/event/type_ext.rs deleted file mode 100644 index 9b824d41..00000000 --- a/src/core/matrix/event/type_ext.rs +++ /dev/null @@ -1,32 +0,0 @@ -use ruma::events::{StateEventType, TimelineEventType}; - -use super::StateKey; - -/// Convenience trait for adding event type plus state key to state maps. -pub trait TypeExt { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey); -} - -impl TypeExt for StateEventType { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { - (self, state_key.into()) - } -} - -impl TypeExt for &StateEventType { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { - (self.clone(), state_key.into()) - } -} - -impl TypeExt for TimelineEventType { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { - (self.into(), state_key.into()) - } -} - -impl TypeExt for &TimelineEventType { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { - (self.clone().into(), state_key.into()) - } -} diff --git a/src/core/matrix/event/unsigned.rs b/src/core/matrix/event/unsigned.rs deleted file mode 100644 index 42928af4..00000000 --- a/src/core/matrix/event/unsigned.rs +++ /dev/null @@ -1,51 +0,0 @@ -use serde::Deserialize; -use serde_json::value::Value as JsonValue; - -use super::Event; -use crate::{Result, err, is_true}; - -pub(super) fn contains_unsigned_property(event: &E, property: &str, is_type: F) -> bool -where - F: FnOnce(&JsonValue) -> bool, - E: Event, -{ - get_unsigned_as_value(event) - .get(property) - .map(is_type) - .is_some_and(is_true!()) -} - -pub(super) fn get_unsigned_property(event: &E, property: &str) -> Result -where - T: for<'de> Deserialize<'de>, - E: Event, -{ - get_unsigned_as_value(event) - .get_mut(property) - .map(JsonValue::take) - .map(serde_json::from_value) - .ok_or(err!(Request(NotFound("property not found in unsigned object"))))? - .map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}"))) -} - -#[must_use] -pub(super) fn get_unsigned_as_value(event: &E) -> JsonValue -where - E: Event, -{ - get_unsigned::(event).unwrap_or_default() -} - -pub(super) fn get_unsigned(event: &E) -> Result -where - T: for<'de> Deserialize<'de>, - E: Event, -{ - event - .unsigned() - .as_ref() - .map(|raw| raw.get()) - .map(serde_json::from_str) - .ok_or(err!(Request(NotFound("\"unsigned\" property not found in pdu"))))? - .map_err(|e| err!(Database("Failed to deserialize \"unsigned\" into value: {e}"))) -} diff --git a/src/core/matrix/mod.rs b/src/core/matrix/mod.rs index b38d4c9a..8c978173 100644 --- a/src/core/matrix/mod.rs +++ b/src/core/matrix/mod.rs @@ -2,10 +2,8 @@ pub mod event; pub mod pdu; -pub mod state_key; pub mod state_res; -pub use event::{Event, TypeExt as EventTypeExt}; -pub use pdu::{Pdu, PduBuilder, PduCount, PduEvent, PduId, RawPduId, ShortId}; -pub use state_key::StateKey; -pub use state_res::{RoomVersion, StateMap, TypeStateKey}; +pub use event::Event; +pub use pdu::{PduBuilder, PduCount, PduEvent, PduId, RawPduId, StateKey}; +pub use state_res::{EventTypeExt, RoomVersion, StateMap, TypeStateKey}; diff --git a/src/core/matrix/pdu.rs b/src/core/matrix/pdu.rs index bff0c203..188586bd 100644 --- a/src/core/matrix/pdu.rs +++ b/src/core/matrix/pdu.rs @@ -1,8 +1,14 @@ mod builder; +mod content; mod count; +mod event_id; +mod filter; mod id; mod raw_id; mod redact; +mod relation; +mod state_key; +mod strip; #[cfg(test)] mod tests; mod unsigned; @@ -20,50 +26,38 @@ pub use self::{ Count as PduCount, Id as PduId, Pdu as PduEvent, RawId as RawPduId, builder::{Builder, Builder as PduBuilder}, count::Count, - id::{ShortId, *}, + event_id::*, + id::*, raw_id::*, + state_key::{ShortStateKey, StateKey}, }; -use super::{Event, StateKey}; +use super::Event; use crate::Result; /// Persistent Data Unit (Event) #[derive(Clone, Deserialize, Serialize, Debug)] pub struct Pdu { pub event_id: OwnedEventId, - pub room_id: OwnedRoomId, - pub sender: OwnedUserId, - #[serde(skip_serializing_if = "Option::is_none")] pub origin: Option, - pub origin_server_ts: UInt, - #[serde(rename = "type")] pub kind: TimelineEventType, - pub content: Box, - #[serde(skip_serializing_if = "Option::is_none")] pub state_key: Option, - pub prev_events: Vec, - pub depth: UInt, - pub auth_events: Vec, - #[serde(skip_serializing_if = "Option::is_none")] pub redacts: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] pub unsigned: Option>, - pub hashes: EventHash, - - // BTreeMap, BTreeMap> #[serde(default, skip_serializing_if = "Option::is_none")] + // BTreeMap, BTreeMap> pub signatures: Option>, } @@ -85,106 +79,31 @@ impl Pdu { } impl Event for Pdu { - #[inline] - fn auth_events(&self) -> impl DoubleEndedIterator + Clone + Send + '_ { - self.auth_events.iter().map(AsRef::as_ref) - } - - #[inline] - fn content(&self) -> &RawJsonValue { &self.content } - - #[inline] fn event_id(&self) -> &EventId { &self.event_id } - #[inline] + fn room_id(&self) -> &RoomId { &self.room_id } + + fn sender(&self) -> &UserId { &self.sender } + + fn event_type(&self) -> &TimelineEventType { &self.kind } + + fn content(&self) -> &RawJsonValue { &self.content } + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) } - #[inline] - fn prev_events(&self) -> impl DoubleEndedIterator + Clone + Send + '_ { + fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } + + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.prev_events.iter().map(AsRef::as_ref) } - #[inline] - fn redacts(&self) -> Option<&EventId> { self.redacts.as_deref() } - - #[inline] - fn room_id(&self) -> &RoomId { &self.room_id } - - #[inline] - fn sender(&self) -> &UserId { &self.sender } - - #[inline] - fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - - #[inline] - fn kind(&self) -> &TimelineEventType { &self.kind } - - #[inline] - fn unsigned(&self) -> Option<&RawJsonValue> { self.unsigned.as_deref() } - - #[inline] - fn as_mut_pdu(&mut self) -> &mut Pdu { self } - - #[inline] - fn as_pdu(&self) -> &Pdu { self } - - #[inline] - fn into_pdu(self) -> Pdu { self } - - #[inline] - fn is_owned(&self) -> bool { true } -} - -impl Event for &Pdu { - #[inline] - fn auth_events(&self) -> impl DoubleEndedIterator + Clone + Send + '_ { + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { self.auth_events.iter().map(AsRef::as_ref) } - #[inline] - fn content(&self) -> &RawJsonValue { &self.content } - - #[inline] - fn event_id(&self) -> &EventId { &self.event_id } - - #[inline] - fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { - MilliSecondsSinceUnixEpoch(self.origin_server_ts) - } - - #[inline] - fn prev_events(&self) -> impl DoubleEndedIterator + Clone + Send + '_ { - self.prev_events.iter().map(AsRef::as_ref) - } - - #[inline] fn redacts(&self) -> Option<&EventId> { self.redacts.as_deref() } - - #[inline] - fn room_id(&self) -> &RoomId { &self.room_id } - - #[inline] - fn sender(&self) -> &UserId { &self.sender } - - #[inline] - fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - - #[inline] - fn kind(&self) -> &TimelineEventType { &self.kind } - - #[inline] - fn unsigned(&self) -> Option<&RawJsonValue> { self.unsigned.as_deref() } - - #[inline] - fn as_pdu(&self) -> &Pdu { self } - - #[inline] - fn into_pdu(self) -> Pdu { self.clone() } - - #[inline] - fn is_owned(&self) -> bool { false } } /// Prevent derived equality which wouldn't limit itself to event_id diff --git a/src/core/matrix/pdu/content.rs b/src/core/matrix/pdu/content.rs new file mode 100644 index 00000000..4e60ce6e --- /dev/null +++ b/src/core/matrix/pdu/content.rs @@ -0,0 +1,20 @@ +use serde::Deserialize; +use serde_json::value::Value as JsonValue; + +use crate::{Result, err, implement}; + +#[must_use] +#[implement(super::Pdu)] +pub fn get_content_as_value(&self) -> JsonValue { + self.get_content() + .expect("pdu content must be a valid JSON value") +} + +#[implement(super::Pdu)] +pub fn get_content(&self) -> Result +where + T: for<'de> Deserialize<'de>, +{ + serde_json::from_str(self.content.get()) + .map_err(|e| err!(Database("Failed to deserialize pdu content into type: {e}"))) +} diff --git a/src/core/matrix/event/id.rs b/src/core/matrix/pdu/event_id.rs similarity index 100% rename from src/core/matrix/event/id.rs rename to src/core/matrix/pdu/event_id.rs diff --git a/src/core/matrix/pdu/filter.rs b/src/core/matrix/pdu/filter.rs new file mode 100644 index 00000000..aabf13db --- /dev/null +++ b/src/core/matrix/pdu/filter.rs @@ -0,0 +1,90 @@ +use ruma::api::client::filter::{RoomEventFilter, UrlFilter}; +use serde_json::Value; + +use crate::{implement, is_equal_to}; + +#[implement(super::Pdu)] +#[must_use] +pub fn matches(&self, filter: &RoomEventFilter) -> bool { + if !self.matches_sender(filter) { + return false; + } + + if !self.matches_room(filter) { + return false; + } + + if !self.matches_type(filter) { + return false; + } + + if !self.matches_url(filter) { + return false; + } + + true +} + +#[implement(super::Pdu)] +fn matches_room(&self, filter: &RoomEventFilter) -> bool { + if filter.not_rooms.contains(&self.room_id) { + return false; + } + + if let Some(rooms) = filter.rooms.as_ref() { + if !rooms.contains(&self.room_id) { + return false; + } + } + + true +} + +#[implement(super::Pdu)] +fn matches_sender(&self, filter: &RoomEventFilter) -> bool { + if filter.not_senders.contains(&self.sender) { + return false; + } + + if let Some(senders) = filter.senders.as_ref() { + if !senders.contains(&self.sender) { + return false; + } + } + + true +} + +#[implement(super::Pdu)] +fn matches_type(&self, filter: &RoomEventFilter) -> bool { + let event_type = &self.kind.to_cow_str(); + if filter.not_types.iter().any(is_equal_to!(event_type)) { + return false; + } + + if let Some(types) = filter.types.as_ref() { + if !types.iter().any(is_equal_to!(event_type)) { + return false; + } + } + + true +} + +#[implement(super::Pdu)] +fn matches_url(&self, filter: &RoomEventFilter) -> bool { + let Some(url_filter) = filter.url_filter.as_ref() else { + return true; + }; + + //TODO: might be better to use Ruma's Raw rather than serde here + let url = serde_json::from_str::(self.content.get()) + .expect("parsing content JSON failed") + .get("url") + .is_some_and(Value::is_string); + + match url_filter { + | UrlFilter::EventsWithUrl => url, + | UrlFilter::EventsWithoutUrl => !url, + } +} diff --git a/src/core/matrix/pdu/id.rs b/src/core/matrix/pdu/id.rs index 896d677b..0b23a29f 100644 --- a/src/core/matrix/pdu/id.rs +++ b/src/core/matrix/pdu/id.rs @@ -3,7 +3,6 @@ use crate::utils::u64_from_u8x8; pub type ShortRoomId = ShortId; pub type ShortEventId = ShortId; -pub type ShortStateKey = ShortId; pub type ShortId = u64; #[derive(Clone, Copy, Debug, Eq, PartialEq)] diff --git a/src/core/matrix/pdu/redact.rs b/src/core/matrix/pdu/redact.rs index 896e03f8..409debfe 100644 --- a/src/core/matrix/pdu/redact.rs +++ b/src/core/matrix/pdu/redact.rs @@ -1,29 +1,117 @@ -use ruma::{RoomVersionId, canonical_json::redact_content_in_place}; -use serde_json::{Value as JsonValue, json, value::to_raw_value}; +use ruma::{ + OwnedEventId, RoomVersionId, + canonical_json::redact_content_in_place, + events::{TimelineEventType, room::redaction::RoomRedactionEventContent}, +}; +use serde::Deserialize; +use serde_json::{ + json, + value::{RawValue as RawJsonValue, to_raw_value}, +}; -use crate::{Error, Result, err, implement}; +use crate::{Error, Result, implement}; + +#[derive(Deserialize)] +struct ExtractRedactedBecause { + redacted_because: Option, +} #[implement(super::Pdu)] -pub fn redact(&mut self, room_version_id: &RoomVersionId, reason: JsonValue) -> Result { +pub fn redact(&mut self, room_version_id: &RoomVersionId, reason: &Self) -> Result { self.unsigned = None; let mut content = serde_json::from_str(self.content.get()) - .map_err(|e| err!(Request(BadJson("Failed to deserialize content into type: {e}"))))?; + .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; redact_content_in_place(&mut content, room_version_id, self.kind.to_string()) .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; - let reason = serde_json::to_value(reason).expect("Failed to preserialize reason"); + self.unsigned = Some( + to_raw_value(&json!({ + "redacted_because": serde_json::to_value(reason).expect("to_value(Pdu) always works") + })) + .expect("to string always works"), + ); - let redacted_because = json!({ - "redacted_because": reason, - }); - - self.unsigned = to_raw_value(&redacted_because) - .expect("Failed to serialize unsigned") - .into(); - - self.content = to_raw_value(&content).expect("Failed to serialize content"); + self.content = to_raw_value(&content).expect("to string always works"); Ok(()) } + +#[implement(super::Pdu)] +#[must_use] +pub fn is_redacted(&self) -> bool { + let Some(unsigned) = &self.unsigned else { + return false; + }; + + let Ok(unsigned) = ExtractRedactedBecause::deserialize(&**unsigned) else { + return false; + }; + + unsigned.redacted_because.is_some() +} + +/// Copies the `redacts` property of the event to the `content` dict and +/// vice-versa. +/// +/// This follows the specification's +/// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): +/// +/// > For backwards-compatibility with older clients, servers should add a +/// > redacts +/// > property to the top level of m.room.redaction events in when serving +/// > such events +/// > over the Client-Server API. +/// +/// > For improved compatibility with newer clients, servers should add a +/// > redacts property +/// > to the content of m.room.redaction events in older room versions when +/// > serving +/// > such events over the Client-Server API. +#[implement(super::Pdu)] +#[must_use] +pub fn copy_redacts(&self) -> (Option, Box) { + if self.kind == TimelineEventType::RoomRedaction { + if let Ok(mut content) = + serde_json::from_str::(self.content.get()) + { + match content.redacts { + | Some(redacts) => { + return (Some(redacts), self.content.clone()); + }, + | _ => match self.redacts.clone() { + | Some(redacts) => { + content.redacts = Some(redacts); + return ( + self.redacts.clone(), + to_raw_value(&content) + .expect("Must be valid, we only added redacts field"), + ); + }, + | _ => {}, + }, + } + } + } + + (self.redacts.clone(), self.content.clone()) +} + +#[implement(super::Pdu)] +#[must_use] +pub fn redacts_id(&self, room_version: &RoomVersionId) -> Option { + use RoomVersionId::*; + + if self.kind != TimelineEventType::RoomRedaction { + return None; + } + + match *room_version { + | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => self.redacts.clone(), + | _ => + self.get_content::() + .ok()? + .redacts, + } +} diff --git a/src/core/matrix/pdu/relation.rs b/src/core/matrix/pdu/relation.rs new file mode 100644 index 00000000..2968171e --- /dev/null +++ b/src/core/matrix/pdu/relation.rs @@ -0,0 +1,22 @@ +use ruma::events::relation::RelationType; +use serde::Deserialize; + +use crate::implement; + +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelType { + rel_type: RelationType, +} +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelatesToEventId { + #[serde(rename = "m.relates_to")] + relates_to: ExtractRelType, +} + +#[implement(super::Pdu)] +#[must_use] +pub fn relation_type_equal(&self, rel_type: &RelationType) -> bool { + self.get_content() + .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) + .is_ok_and(|r| r == *rel_type) +} diff --git a/src/core/matrix/state_key.rs b/src/core/matrix/pdu/state_key.rs similarity index 67% rename from src/core/matrix/state_key.rs rename to src/core/matrix/pdu/state_key.rs index 06d614f8..4af4fcf7 100644 --- a/src/core/matrix/state_key.rs +++ b/src/core/matrix/pdu/state_key.rs @@ -1,5 +1,8 @@ use smallstr::SmallString; +use super::ShortId; + pub type StateKey = SmallString<[u8; INLINE_SIZE]>; +pub type ShortStateKey = ShortId; const INLINE_SIZE: usize = 48; diff --git a/src/core/matrix/pdu/strip.rs b/src/core/matrix/pdu/strip.rs new file mode 100644 index 00000000..a39e7d35 --- /dev/null +++ b/src/core/matrix/pdu/strip.rs @@ -0,0 +1,257 @@ +use ruma::{ + events::{ + AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, + AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, room::member::RoomMemberEventContent, + space::child::HierarchySpaceChildEvent, + }, + serde::Raw, +}; +use serde_json::{json, value::Value as JsonValue}; + +use crate::implement; + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_room_event(self) -> Raw { self.to_room_event() } + +#[implement(super::Pdu)] +#[must_use] +pub fn to_room_event(&self) -> Raw { + let value = self.to_room_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn to_room_event_value(&self) -> JsonValue { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + json +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_message_like_event(self) -> Raw { self.to_message_like_event() } + +#[implement(super::Pdu)] +#[must_use] +pub fn to_message_like_event(&self) -> Raw { + let value = self.to_message_like_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn to_message_like_event_value(&self) -> JsonValue { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + json +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_sync_room_event(self) -> Raw { self.to_sync_room_event() } + +#[implement(super::Pdu)] +#[must_use] +pub fn to_sync_room_event(&self) -> Raw { + let value = self.to_sync_room_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn to_sync_room_event_value(&self) -> JsonValue { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + json +} + +#[implement(super::Pdu)] +#[must_use] +pub fn into_state_event(self) -> Raw { + let value = self.into_state_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_state_event_value(self) -> JsonValue { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + "state_key": self.state_key, + }); + + if let Some(unsigned) = self.unsigned { + json["unsigned"] = json!(unsigned); + } + + json +} + +#[implement(super::Pdu)] +#[must_use] +pub fn into_sync_state_event(self) -> Raw { + let value = self.into_sync_state_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_sync_state_event_value(self) -> JsonValue { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + json +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_stripped_state_event(self) -> Raw { + self.to_stripped_state_event() +} + +#[implement(super::Pdu)] +#[must_use] +pub fn to_stripped_state_event(&self) -> Raw { + let value = self.to_stripped_state_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn to_stripped_state_event_value(&self) -> JsonValue { + json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + }) +} + +#[implement(super::Pdu)] +#[must_use] +pub fn into_stripped_spacechild_state_event(self) -> Raw { + let value = self.into_stripped_spacechild_state_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_stripped_spacechild_state_event_value(self) -> JsonValue { + json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + "origin_server_ts": self.origin_server_ts, + }) +} + +#[implement(super::Pdu)] +#[must_use] +pub fn into_member_event(self) -> Raw> { + let value = self.into_member_event_value(); + serde_json::from_value(value).expect("Failed to serialize Event value") +} + +#[implement(super::Pdu)] +#[must_use] +#[inline] +pub fn into_member_event_value(self) -> JsonValue { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "redacts": self.redacts, + "room_id": self.room_id, + "state_key": self.state_key, + }); + + if let Some(unsigned) = self.unsigned { + json["unsigned"] = json!(unsigned); + } + + json +} diff --git a/src/core/matrix/pdu/unsigned.rs b/src/core/matrix/pdu/unsigned.rs index 0c58bb68..2726a292 100644 --- a/src/core/matrix/pdu/unsigned.rs +++ b/src/core/matrix/pdu/unsigned.rs @@ -1,10 +1,24 @@ -use std::collections::BTreeMap; +use std::{borrow::Borrow, collections::BTreeMap}; use ruma::MilliSecondsSinceUnixEpoch; +use serde::Deserialize; use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue, to_raw_value}; use super::Pdu; -use crate::{Result, err, implement}; +use crate::{Result, err, implement, is_true, result::LogErr}; + +/// Set the `unsigned` field of the PDU using only information in the PDU. +/// Some unsigned data is already set within the database (eg. prev events, +/// threads). Once this is done, other data must be calculated from the database +/// (eg. relations) This is for server-to-client events. +/// Backfill handles this itself. +#[implement(Pdu)] +pub fn set_unsigned(&mut self, user_id: Option<&ruma::UserId>) { + if Some(self.sender.borrow()) != user_id { + self.remove_transaction_id().log_err().ok(); + } + self.add_age().log_err().ok(); +} #[implement(Pdu)] pub fn remove_transaction_id(&mut self) -> Result { @@ -73,3 +87,43 @@ pub fn add_relation(&mut self, name: &str, pdu: Option<&Pdu>) -> Result { Ok(()) } + +#[implement(Pdu)] +pub fn contains_unsigned_property(&self, property: &str, is_type: F) -> bool +where + F: FnOnce(&JsonValue) -> bool, +{ + self.get_unsigned_as_value() + .get(property) + .map(is_type) + .is_some_and(is_true!()) +} + +#[implement(Pdu)] +pub fn get_unsigned_property(&self, property: &str) -> Result +where + T: for<'de> Deserialize<'de>, +{ + self.get_unsigned_as_value() + .get_mut(property) + .map(JsonValue::take) + .map(serde_json::from_value) + .ok_or(err!(Request(NotFound("property not found in unsigned object"))))? + .map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}"))) +} + +#[implement(Pdu)] +#[must_use] +pub fn get_unsigned_as_value(&self) -> JsonValue { + self.get_unsigned::().unwrap_or_default() +} + +#[implement(Pdu)] +pub fn get_unsigned(&self) -> Result { + self.unsigned + .as_ref() + .map(|raw| raw.get()) + .map(serde_json::from_str) + .ok_or(err!(Request(NotFound("\"unsigned\" property not found in pdu"))))? + .map_err(|e| err!(Database("Failed to deserialize \"unsigned\" into value: {e}"))) +} diff --git a/src/core/matrix/state_res/benches.rs b/src/core/matrix/state_res/benches.rs index 69088369..12eeab9d 100644 --- a/src/core/matrix/state_res/benches.rs +++ b/src/core/matrix/state_res/benches.rs @@ -13,6 +13,7 @@ use ruma::{ EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, RoomVersionId, Signatures, UserId, events::{ StateEventType, TimelineEventType, + pdu::{EventHash, Pdu, RoomV3Pdu}, room::{ join_rules::{JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, @@ -25,10 +26,8 @@ use serde_json::{ value::{RawValue as RawJsonValue, to_raw_value as to_raw_json_value}, }; -use crate::{ - matrix::{Event, Pdu, pdu::EventHash}, - state_res::{self as state_res, Error, Result, StateMap}, -}; +use self::event::PduEvent; +use crate::state_res::{self as state_res, Error, Event, Result, StateMap}; static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); @@ -61,7 +60,7 @@ fn resolution_shallow_auth_chain(c: &mut test::Bencher) { c.iter(|| async { let ev_map = store.0.clone(); let state_sets = [&state_at_bob, &state_at_charlie]; - let fetch = |id: OwnedEventId| ready(ev_map.get(&id).map(ToOwned::to_owned)); + let fetch = |id: OwnedEventId| ready(ev_map.get(&id).clone()); let exists = |id: OwnedEventId| ready(ev_map.get(&id).is_some()); let auth_chain_sets: Vec> = state_sets .iter() @@ -143,7 +142,7 @@ fn resolve_deeper_event_set(c: &mut test::Bencher) { }) .collect(); - let fetch = |id: OwnedEventId| ready(inner.get(&id).map(ToOwned::to_owned)); + let fetch = |id: OwnedEventId| ready(inner.get(&id).clone()); let exists = |id: OwnedEventId| ready(inner.get(&id).is_some()); let _ = match state_res::resolve( &RoomVersionId::V6, @@ -247,7 +246,7 @@ impl TestStore { } } -impl TestStore { +impl TestStore { #[allow(clippy::type_complexity)] fn set_up( &mut self, @@ -381,7 +380,7 @@ fn to_pdu_event( content: Box, auth_events: &[S], prev_events: &[S], -) -> Pdu +) -> PduEvent where S: AsRef, { @@ -404,28 +403,30 @@ where .map(event_id) .collect::>(); - Pdu { + let state_key = state_key.map(ToOwned::to_owned); + PduEvent { event_id: id.try_into().unwrap(), - room_id: room_id().to_owned(), - sender: sender.to_owned(), - origin_server_ts: ts.try_into().unwrap(), - state_key: state_key.map(Into::into), - kind: ev_type, - content, - origin: None, - redacts: None, - unsigned: None, - auth_events, - prev_events, - depth: uint!(0), - hashes: EventHash { sha256: String::new() }, - signatures: None, + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id().to_owned(), + sender: sender.to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch(ts.try_into().unwrap()), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: btreemap! {}, + auth_events, + prev_events, + depth: uint!(0), + hashes: EventHash::new(String::new()), + signatures: Signatures::new(), + }), } } // all graphs start with these input events #[allow(non_snake_case)] -fn INITIAL_EVENTS() -> HashMap { +fn INITIAL_EVENTS() -> HashMap { vec![ to_pdu_event::<&EventId>( "CREATE", @@ -507,7 +508,7 @@ fn INITIAL_EVENTS() -> HashMap { // all graphs start with these input events #[allow(non_snake_case)] -fn BAN_STATE_SET() -> HashMap { +fn BAN_STATE_SET() -> HashMap { vec![ to_pdu_event( "PA", @@ -550,3 +551,119 @@ fn BAN_STATE_SET() -> HashMap { .map(|ev| (ev.event_id().to_owned(), ev)) .collect() } + +/// Convenience trait for adding event type plus state key to state maps. +trait EventTypeExt { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); +} + +impl EventTypeExt for &TimelineEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self.to_string().into(), state_key.into()) + } +} + +mod event { + use ruma::{ + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId, + events::{TimelineEventType, pdu::Pdu}, + }; + use serde::{Deserialize, Serialize}; + use serde_json::value::RawValue as RawJsonValue; + + use super::Event; + + impl Event for PduEvent { + fn event_id(&self) -> &EventId { &self.event_id } + + fn room_id(&self) -> &RoomId { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.room_id, + | Pdu::RoomV3Pdu(ev) => &ev.room_id, + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn sender(&self) -> &UserId { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.sender, + | Pdu::RoomV3Pdu(ev) => &ev.sender, + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn event_type(&self) -> &TimelineEventType { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.kind, + | Pdu::RoomV3Pdu(ev) => &ev.kind, + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn content(&self) -> &RawJsonValue { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.content, + | Pdu::RoomV3Pdu(ev) => &ev.content, + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.origin_server_ts, + | Pdu::RoomV3Pdu(ev) => ev.origin_server_ts, + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn state_key(&self) -> Option<&str> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.state_key.as_deref(), + | Pdu::RoomV3Pdu(ev) => ev.state_key.as_deref(), + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn prev_events(&self) -> Box + Send + '_> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => + Box::new(ev.prev_events.iter().map(|(id, _)| id.as_ref())), + | Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter().map(AsRef::as_ref)), + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn auth_events(&self) -> Box + Send + '_> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => + Box::new(ev.auth_events.iter().map(|(id, _)| id.as_ref())), + | Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter().map(AsRef::as_ref)), + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + + fn redacts(&self) -> Option<&EventId> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.redacts.as_deref(), + | Pdu::RoomV3Pdu(ev) => ev.redacts.as_deref(), + #[cfg(not(feature = "unstable-exhaustive-types"))] + | _ => unreachable!("new PDU version"), + } + } + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + pub(crate) struct PduEvent { + pub(crate) event_id: OwnedEventId, + #[serde(flatten)] + pub(crate) rest: Pdu, + } +} diff --git a/src/core/matrix/state_res/event_auth.rs b/src/core/matrix/state_res/event_auth.rs index 8833cbfb..67283b6a 100644 --- a/src/core/matrix/state_res/event_auth.rs +++ b/src/core/matrix/state_res/event_auth.rs @@ -136,17 +136,17 @@ pub fn auth_types_for_event( event_id = incoming_event.event_id().as_str(), ) )] -pub async fn auth_check( +pub async fn auth_check( room_version: &RoomVersion, - incoming_event: &E, - current_third_party_invite: Option<&E>, + incoming_event: &Incoming, + current_third_party_invite: Option<&Incoming>, fetch_state: F, ) -> Result where F: Fn(&StateEventType, &str) -> Fut + Send, - Fut: Future> + Send, - E: Event + Send + Sync, - for<'a> &'a E: Event + Send, + Fut: Future> + Send, + Fetched: Event + Send, + Incoming: Event + Send + Sync, { debug!( event_id = format!("{}", incoming_event.event_id()), @@ -541,24 +541,20 @@ where /// event and the current State. #[allow(clippy::too_many_arguments)] #[allow(clippy::cognitive_complexity)] -fn valid_membership_change( +fn valid_membership_change( room_version: &RoomVersion, target_user: &UserId, - target_user_membership_event: Option<&E>, + target_user_membership_event: Option<&impl Event>, sender: &UserId, - sender_membership_event: Option<&E>, - current_event: &E, - current_third_party_invite: Option<&E>, - power_levels_event: Option<&E>, - join_rules_event: Option<&E>, + sender_membership_event: Option<&impl Event>, + current_event: impl Event, + current_third_party_invite: Option<&impl Event>, + power_levels_event: Option<&impl Event>, + join_rules_event: Option<&impl Event>, user_for_join_auth: Option<&UserId>, user_for_join_auth_membership: &MembershipState, - create_room: &E, -) -> Result -where - E: Event + Send + Sync, - for<'a> &'a E: Event + Send, -{ + create_room: &impl Event, +) -> Result { #[derive(Deserialize)] struct GetThirdPartyInvite { third_party_invite: Option>, @@ -851,7 +847,7 @@ where /// /// Does the event have the correct userId as its state_key if it's not the "" /// state_key. -fn can_send_event(event: &impl Event, ple: Option<&impl Event>, user_level: Int) -> bool { +fn can_send_event(event: impl Event, ple: Option, user_level: Int) -> bool { let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple); debug!( @@ -877,8 +873,8 @@ fn can_send_event(event: &impl Event, ple: Option<&impl Event>, user_level: Int) /// Confirm that the event sender has the required power levels. fn check_power_levels( room_version: &RoomVersion, - power_event: &impl Event, - previous_power_event: Option<&impl Event>, + power_event: impl Event, + previous_power_event: Option, user_level: Int, ) -> Option { match power_event.state_key() { @@ -1041,7 +1037,7 @@ fn get_deserialize_levels( /// given event. fn check_redaction( _room_version: &RoomVersion, - redaction_event: &impl Event, + redaction_event: impl Event, user_level: Int, redact_level: Int, ) -> Result { @@ -1070,7 +1066,7 @@ fn check_redaction( fn get_send_level( e_type: &TimelineEventType, state_key: Option<&str>, - power_lvl: Option<&impl Event>, + power_lvl: Option, ) -> Int { power_lvl .and_then(|ple| { @@ -1093,7 +1089,7 @@ fn verify_third_party_invite( target_user: Option<&UserId>, sender: &UserId, tp_id: &ThirdPartyInvite, - current_third_party_invite: Option<&impl Event>, + current_third_party_invite: Option, ) -> bool { // 1. Check for user being banned happens before this is called // checking for mxid and token keys is done by ruma when deserializing @@ -1159,15 +1155,12 @@ mod tests { }; use serde_json::value::to_raw_value as to_raw_json_value; - use crate::{ - matrix::{Event, EventTypeExt, Pdu as PduEvent}, - state_res::{ - RoomVersion, StateMap, - event_auth::valid_membership_change, - test_utils::{ - INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, alice, charlie, ella, event_id, - member_content_ban, member_content_join, room_id, to_pdu_event, - }, + use crate::state_res::{ + Event, EventTypeExt, RoomVersion, StateMap, + event_auth::valid_membership_change, + test_utils::{ + INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, PduEvent, alice, charlie, ella, event_id, + member_content_ban, member_content_join, room_id, to_pdu_event, }, }; diff --git a/src/core/matrix/state_res/mod.rs b/src/core/matrix/state_res/mod.rs index 771e364f..ea49b29e 100644 --- a/src/core/matrix/state_res/mod.rs +++ b/src/core/matrix/state_res/mod.rs @@ -37,7 +37,7 @@ pub use self::{ }; use crate::{ debug, debug_error, - matrix::{Event, StateKey}, + matrix::{event::Event, pdu::StateKey}, trace, utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, WidebandExt}, warn, @@ -74,7 +74,7 @@ type Result = crate::Result; /// event is part of the same room. //#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets, //#[tracing::instrument(level event_fetch))] -pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>( +pub async fn resolve<'a, E, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>( room_version: &RoomVersionId, state_sets: Sets, auth_chain_sets: &'a [HashSet], @@ -83,14 +83,14 @@ pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, Ex ) -> Result> where Fetch: Fn(OwnedEventId) -> FetchFut + Sync, - FetchFut: Future> + Send, + FetchFut: Future> + Send, Exists: Fn(OwnedEventId) -> ExistsFut + Sync, ExistsFut: Future + Send, Sets: IntoIterator + Send, SetIter: Iterator> + Clone + Send, Hasher: BuildHasher + Send + Sync, - Pdu: Event + Clone + Send + Sync, - for<'b> &'b Pdu: Event + Send, + E: Event + Clone + Send + Sync, + for<'b> &'b E: Send, { debug!("State resolution starting"); @@ -221,7 +221,6 @@ where let state_sets_iter = state_sets_iter.inspect(|_| state_set_count = state_set_count.saturating_add(1)); - for (k, v) in state_sets_iter.flatten() { occurrences .entry(k) @@ -306,7 +305,6 @@ where let pl = get_power_level_for_sender(&event_id, fetch_event) .await .ok()?; - Some((event_id, pl)) }) .inspect(|(event_id, pl)| { @@ -524,7 +522,6 @@ where Fut: Future> + Send, S: Stream + Send + 'a, E: Event + Clone + Send + Sync, - for<'b> &'b E: Event + Send, { debug!("starting iterative auth check"); @@ -555,7 +552,7 @@ where let auth_events = &auth_events; let mut resolved_state = unconflicted_state; - for event in events_to_check { + for event in &events_to_check { let state_key = event .state_key() .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; @@ -610,15 +607,11 @@ where }); let fetch_state = |ty: &StateEventType, key: &str| { - future::ready( - auth_state - .get(&ty.with_state_key(key)) - .map(ToOwned::to_owned), - ) + future::ready(auth_state.get(&ty.with_state_key(key))) }; debug!("running auth check on {:?}", event.event_id()); let auth_result = - auth_check(room_version, &event, current_third_party, fetch_state).await; + auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await; match auth_result { | Ok(true) => { @@ -805,11 +798,11 @@ where } } -fn is_type_and_key(ev: &impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool { +fn is_type_and_key(ev: impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool { ev.event_type() == ev_type && ev.state_key() == Some(state_key) } -fn is_power_event(event: &impl Event) -> bool { +fn is_power_event(event: impl Event) -> bool { match event.event_type() { | TimelineEventType::RoomPowerLevels | TimelineEventType::RoomJoinRules @@ -870,19 +863,15 @@ mod tests { use serde_json::{json, value::to_raw_value as to_raw_json_value}; use super::{ - StateMap, is_power_event, + Event, EventTypeExt, StateMap, is_power_event, room_version::RoomVersion, test_utils::{ - INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id, + INITIAL_EVENTS, PduEvent, TestStore, alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join, room_id, to_init_pdu_event, to_pdu_event, zara, }, }; - use crate::{ - debug, - matrix::{Event, EventTypeExt, Pdu as PduEvent}, - utils::stream::IterStream, - }; + use crate::{debug, utils::stream::IterStream}; async fn test_event_sort() { use futures::future::ready; diff --git a/src/core/matrix/state_res/test_utils.rs b/src/core/matrix/state_res/test_utils.rs index 9f24c51b..c6945f66 100644 --- a/src/core/matrix/state_res/test_utils.rs +++ b/src/core/matrix/state_res/test_utils.rs @@ -10,6 +10,7 @@ use ruma::{ UserId, event_id, events::{ TimelineEventType, + pdu::{EventHash, Pdu, RoomV3Pdu}, room::{ join_rules::{JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, @@ -22,16 +23,17 @@ use serde_json::{ value::{RawValue as RawJsonValue, to_raw_value as to_raw_json_value}, }; +pub(crate) use self::event::PduEvent; use super::auth_types_for_event; use crate::{ Result, info, - matrix::{Event, EventTypeExt, Pdu, StateMap, pdu::EventHash}, + matrix::{Event, EventTypeExt, StateMap}, }; static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); pub(crate) async fn do_check( - events: &[Pdu], + events: &[PduEvent], edges: Vec>, expected_state_ids: Vec, ) { @@ -79,8 +81,8 @@ pub(crate) async fn do_check( } } - // event_id -> Pdu - let mut event_map: HashMap = HashMap::new(); + // event_id -> PduEvent + let mut event_map: HashMap = HashMap::new(); // event_id -> StateMap let mut state_at_event: HashMap> = HashMap::new(); @@ -263,7 +265,7 @@ impl TestStore { // A StateStore implementation for testing #[allow(clippy::type_complexity)] -impl TestStore { +impl TestStore { pub(crate) fn set_up( &mut self, ) -> (StateMap, StateMap, StateMap) { @@ -388,7 +390,7 @@ pub(crate) fn to_init_pdu_event( ev_type: TimelineEventType, state_key: Option<&str>, content: Box, -) -> Pdu { +) -> PduEvent { let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst); let id = if id.contains('$') { id.to_owned() @@ -396,22 +398,24 @@ pub(crate) fn to_init_pdu_event( format!("${id}:foo") }; - Pdu { + let state_key = state_key.map(ToOwned::to_owned); + PduEvent { event_id: id.try_into().unwrap(), - room_id: room_id().to_owned(), - sender: sender.to_owned(), - origin_server_ts: ts.try_into().unwrap(), - state_key: state_key.map(Into::into), - kind: ev_type, - content, - origin: None, - redacts: None, - unsigned: None, - auth_events: vec![], - prev_events: vec![], - depth: uint!(0), - hashes: EventHash { sha256: "".to_owned() }, - signatures: None, + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id().to_owned(), + sender: sender.to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch(ts.try_into().unwrap()), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: BTreeMap::new(), + auth_events: vec![], + prev_events: vec![], + depth: uint!(0), + hashes: EventHash::new("".to_owned()), + signatures: ServerSignatures::default(), + }), } } @@ -423,7 +427,7 @@ pub(crate) fn to_pdu_event( content: Box, auth_events: &[S], prev_events: &[S], -) -> Pdu +) -> PduEvent where S: AsRef, { @@ -444,28 +448,30 @@ where .map(event_id) .collect::>(); - Pdu { + let state_key = state_key.map(ToOwned::to_owned); + PduEvent { event_id: id.try_into().unwrap(), - room_id: room_id().to_owned(), - sender: sender.to_owned(), - origin_server_ts: ts.try_into().unwrap(), - state_key: state_key.map(Into::into), - kind: ev_type, - content, - origin: None, - redacts: None, - unsigned: None, - auth_events, - prev_events, - depth: uint!(0), - hashes: EventHash { sha256: "".to_owned() }, - signatures: None, + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id().to_owned(), + sender: sender.to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch(ts.try_into().unwrap()), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: BTreeMap::new(), + auth_events, + prev_events, + depth: uint!(0), + hashes: EventHash::new("".to_owned()), + signatures: ServerSignatures::default(), + }), } } // all graphs start with these input events #[allow(non_snake_case)] -pub(crate) fn INITIAL_EVENTS() -> HashMap { +pub(crate) fn INITIAL_EVENTS() -> HashMap { vec![ to_pdu_event::<&EventId>( "CREATE", @@ -547,7 +553,7 @@ pub(crate) fn INITIAL_EVENTS() -> HashMap { // all graphs start with these input events #[allow(non_snake_case)] -pub(crate) fn INITIAL_EVENTS_CREATE_ROOM() -> HashMap { +pub(crate) fn INITIAL_EVENTS_CREATE_ROOM() -> HashMap { vec![to_pdu_event::<&EventId>( "CREATE", alice(), @@ -569,3 +575,111 @@ pub(crate) fn INITIAL_EDGES() -> Vec { .map(event_id) .collect::>() } + +pub(crate) mod event { + use ruma::{ + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId, + events::{TimelineEventType, pdu::Pdu}, + }; + use serde::{Deserialize, Serialize}; + use serde_json::value::RawValue as RawJsonValue; + + use crate::Event; + + impl Event for PduEvent { + fn event_id(&self) -> &EventId { &self.event_id } + + fn room_id(&self) -> &RoomId { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.room_id, + | Pdu::RoomV3Pdu(ev) => &ev.room_id, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn sender(&self) -> &UserId { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.sender, + | Pdu::RoomV3Pdu(ev) => &ev.sender, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn event_type(&self) -> &TimelineEventType { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.kind, + | Pdu::RoomV3Pdu(ev) => &ev.kind, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn content(&self) -> &RawJsonValue { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.content, + | Pdu::RoomV3Pdu(ev) => &ev.content, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.origin_server_ts, + | Pdu::RoomV3Pdu(ev) => ev.origin_server_ts, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn state_key(&self) -> Option<&str> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.state_key.as_deref(), + | Pdu::RoomV3Pdu(ev) => ev.state_key.as_deref(), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + #[allow(refining_impl_trait)] + fn prev_events(&self) -> Box + Send + '_> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => + Box::new(ev.prev_events.iter().map(|(id, _)| id.as_ref())), + | Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter().map(AsRef::as_ref)), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + #[allow(refining_impl_trait)] + fn auth_events(&self) -> Box + Send + '_> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => + Box::new(ev.auth_events.iter().map(|(id, _)| id.as_ref())), + | Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter().map(AsRef::as_ref)), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn redacts(&self) -> Option<&EventId> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.redacts.as_deref(), + | Pdu::RoomV3Pdu(ev) => ev.redacts.as_deref(), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + #[allow(clippy::exhaustive_structs)] + pub(crate) struct PduEvent { + pub(crate) event_id: OwnedEventId, + #[serde(flatten)] + pub(crate) rest: Pdu, + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index d99139be..aaacd4d8 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -25,9 +25,7 @@ pub use info::{ rustc_flags_capture, version, version::{name, version}, }; -pub use matrix::{ - Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res, -}; +pub use matrix::{Event, EventTypeExt, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res}; pub use server::Server; pub use utils::{ctor, dtor, implement, result, result::Result}; diff --git a/src/core/mods/module.rs b/src/core/mods/module.rs index bcadf5aa..b65bbca2 100644 --- a/src/core/mods/module.rs +++ b/src/core/mods/module.rs @@ -44,7 +44,6 @@ impl Module { .handle .as_ref() .expect("backing library loaded by this instance"); - // SAFETY: Calls dlsym(3) on unix platforms. This might not have to be unsafe // if wrapped in libloading with_dlerror(). let sym = unsafe { handle.get::(cname.as_bytes()) }; diff --git a/src/core/mods/path.rs b/src/core/mods/path.rs index b792890b..cde251b3 100644 --- a/src/core/mods/path.rs +++ b/src/core/mods/path.rs @@ -27,7 +27,6 @@ pub fn to_name(path: &OsStr) -> Result { .expect("path file stem") .to_str() .expect("name string"); - let name = name.strip_prefix("lib").unwrap_or(name).to_owned(); Ok(name) diff --git a/src/core/utils/html.rs b/src/core/utils/html.rs index eac4c47f..f2b6d861 100644 --- a/src/core/utils/html.rs +++ b/src/core/utils/html.rs @@ -23,10 +23,8 @@ impl fmt::Display for Escape<'_> { | '"' => """, | _ => continue, }; - fmt.write_str(&pile_o_bits[last..i])?; fmt.write_str(s)?; - // NOTE: we only expect single byte characters here - which is fine as long as // we only match single byte characters last = i.saturating_add(1); diff --git a/src/core/utils/json.rs b/src/core/utils/json.rs index df4ccd13..3f2f225e 100644 --- a/src/core/utils/json.rs +++ b/src/core/utils/json.rs @@ -1,4 +1,4 @@ -use std::{fmt, marker::PhantomData, str::FromStr}; +use std::{fmt, str::FromStr}; use ruma::{CanonicalJsonError, CanonicalJsonObject, canonical_json::try_from_json_map}; @@ -11,28 +11,25 @@ use crate::Result; pub fn to_canonical_object( value: T, ) -> Result { - use CanonicalJsonError::SerDe; use serde::ser::Error; - match serde_json::to_value(value).map_err(SerDe)? { + match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? { | serde_json::Value::Object(map) => try_from_json_map(map), - | _ => Err(SerDe(serde_json::Error::custom("Value must be an object"))), + | _ => + Err(CanonicalJsonError::SerDe(serde_json::Error::custom("Value must be an object"))), } } -pub fn deserialize_from_str<'de, D, T, E>(deserializer: D) -> Result -where +pub fn deserialize_from_str< + 'de, D: serde::de::Deserializer<'de>, T: FromStr, E: fmt::Display, -{ - struct Visitor, E>(PhantomData); - - impl serde::de::Visitor<'_> for Visitor - where - T: FromStr, - Err: fmt::Display, - { +>( + deserializer: D, +) -> Result { + struct Visitor, E>(std::marker::PhantomData); + impl, Err: fmt::Display> serde::de::Visitor<'_> for Visitor { type Value = T; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -46,6 +43,5 @@ where v.parse().map_err(serde::de::Error::custom) } } - - deserializer.deserialize_str(Visitor(PhantomData)) + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } diff --git a/src/core/utils/time.rs b/src/core/utils/time.rs index 394e08cb..73f73971 100644 --- a/src/core/utils/time.rs +++ b/src/core/utils/time.rs @@ -105,11 +105,14 @@ pub fn whole_unit(d: Duration) -> Unit { | 86_400.. => Days(d.as_secs() / 86_400), | 3_600..=86_399 => Hours(d.as_secs() / 3_600), | 60..=3_599 => Mins(d.as_secs() / 60), + | _ => match d.as_micros() { | 1_000_000.. => Secs(d.as_secs()), | 1_000..=999_999 => Millis(d.subsec_millis().into()), + | _ => match d.as_nanos() { | 1_000.. => Micros(d.subsec_micros().into()), + | _ => Nanos(d.subsec_nanos().into()), }, }, diff --git a/src/database/watchers.rs b/src/database/watchers.rs index efb939d7..b3907833 100644 --- a/src/database/watchers.rs +++ b/src/database/watchers.rs @@ -37,6 +37,7 @@ impl Watchers { pub(crate) fn wake(&self, key: &[u8]) { let watchers = self.watchers.read().unwrap(); let mut triggered = Vec::new(); + for length in 0..=key.len() { if watchers.contains_key(&key[..length]) { triggered.push(&key[..length]); diff --git a/src/main/logging.rs b/src/main/logging.rs index 36a8896c..aec50bd4 100644 --- a/src/main/logging.rs +++ b/src/main/logging.rs @@ -22,12 +22,10 @@ pub(crate) fn init( let reload_handles = LogLevelReloadHandles::default(); let console_span_events = fmt_span::from_str(&config.log_span_events).unwrap_or_err(); - let console_filter = EnvFilter::builder() .with_regex(config.log_filter_regex) .parse(&config.log) .map_err(|e| err!(Config("log", "{e}.")))?; - let console_layer = fmt::Layer::new() .with_span_events(console_span_events) .event_format(ConsoleFormat::new(config)) @@ -36,7 +34,6 @@ pub(crate) fn init( let (console_reload_filter, console_reload_handle) = reload::Layer::new(console_filter.clone()); - reload_handles.add("console", Box::new(console_reload_handle)); let cap_state = Arc::new(capture::State::new()); @@ -50,10 +47,8 @@ pub(crate) fn init( let subscriber = { let sentry_filter = EnvFilter::try_new(&config.sentry_filter) .map_err(|e| err!(Config("sentry_filter", "{e}.")))?; - let sentry_layer = sentry_tracing::layer(); let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(sentry_filter); - reload_handles.add("sentry", Box::new(sentry_reload_handle)); subscriber.with(sentry_layer.with_filter(sentry_reload_filter)) }; @@ -63,15 +58,12 @@ pub(crate) fn init( let (flame_layer, flame_guard) = if config.tracing_flame { let flame_filter = EnvFilter::try_new(&config.tracing_flame_filter) .map_err(|e| err!(Config("tracing_flame_filter", "{e}.")))?; - let (flame_layer, flame_guard) = tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) .map_err(|e| err!(Config("tracing_flame_output_path", "{e}.")))?; - let flame_layer = flame_layer .with_empty_samples(false) .with_filter(flame_filter); - (Some(flame_layer), Some(flame_guard)) } else { (None, None) @@ -79,24 +71,19 @@ pub(crate) fn init( let jaeger_filter = EnvFilter::try_new(&config.jaeger_filter) .map_err(|e| err!(Config("jaeger_filter", "{e}.")))?; - let jaeger_layer = config.allow_jaeger.then(|| { opentelemetry::global::set_text_map_propagator( opentelemetry_jaeger::Propagator::new(), ); - let tracer = opentelemetry_jaeger::new_agent_pipeline() .with_auto_split_batch(true) .with_service_name(conduwuit_core::name()) .install_batch(opentelemetry_sdk::runtime::Tokio) .expect("jaeger agent pipeline"); - let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(jaeger_filter.clone()); reload_handles.add("jaeger", Box::new(jaeger_reload_handle)); - Some(telemetry.with_filter(jaeger_reload_filter)) }); diff --git a/src/main/mods.rs b/src/main/mods.rs index 6140cc6e..d585a381 100644 --- a/src/main/mods.rs +++ b/src/main/mods.rs @@ -51,9 +51,7 @@ pub(crate) async fn run(server: &Arc, starts: bool) -> Result<(bool, boo }, }; } - server.server.stopping.store(false, Ordering::Release); - let run = main_mod.get::("run")?; if let Err(error) = run(server .services @@ -66,9 +64,7 @@ pub(crate) async fn run(server: &Arc, starts: bool) -> Result<(bool, boo error!("Running server: {error}"); return Err(error); } - let reloads = server.server.reloading.swap(false, Ordering::AcqRel); - let stops = !reloads || stale(server).await? <= restart_thresh(); let starts = reloads && stops; if stops { diff --git a/src/main/sentry.rs b/src/main/sentry.rs index 2a09f415..68f12eb7 100644 --- a/src/main/sentry.rs +++ b/src/main/sentry.rs @@ -35,13 +35,11 @@ fn options(config: &Config) -> ClientOptions { .expect("init_sentry should only be called if sentry is enabled and this is not None") .as_str(); - let server_name = config - .sentry_send_server_name - .then(|| config.server_name.to_string().into()); - ClientOptions { dsn: Some(Dsn::from_str(dsn).expect("sentry_endpoint must be a valid URL")), - server_name, + server_name: config + .sentry_send_server_name + .then(|| config.server_name.to_string().into()), traces_sample_rate: config.sentry_traces_sample_rate, debug: cfg!(debug_assertions), release: sentry::release_name!(), diff --git a/src/router/request.rs b/src/router/request.rs index 3bbeae03..dba90324 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -98,8 +98,8 @@ async fn execute( fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result { let status = result.status(); - let code = status.as_u16(); let reason = status.canonical_reason().unwrap_or("Unknown Reason"); + let code = status.as_u16(); if status.is_server_error() { error!(method = ?method, uri = ?uri, "{code} {reason}"); diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index 0d0e3fc1..2d90ea52 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -170,56 +170,3 @@ async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> R ) .await } - -/// Demote an admin, removing its rights. -#[implement(super::Service)] -pub async fn revoke_admin(&self, user_id: &UserId) -> Result { - use MembershipState::{Invite, Join, Knock, Leave}; - - let Ok(room_id) = self.get_admin_room().await else { - return Err!(error!("No admin room available or created.")); - }; - - let state_lock = self.services.state.mutex.lock(&room_id).await; - - let event = match self - .services - .state_accessor - .get_member(&room_id, user_id) - .await - { - | Err(e) if e.is_not_found() => return Err!("{user_id} was never an admin."), - - | Err(e) => return Err!(error!(?e, "Failure occurred while attempting revoke.")), - - | Ok(event) if !matches!(event.membership, Invite | Knock | Join) => - return Err!("Cannot revoke {user_id} in membership state {:?}.", event.membership), - - | Ok(event) => { - assert!( - matches!(event.membership, Invite | Knock | Join), - "Incorrect membership state to remove user." - ); - - event - }, - }; - - self.services - .timeline - .build_and_append_pdu( - PduBuilder::state(user_id.to_string(), &RoomMemberEventContent { - membership: Leave, - reason: Some("Admin Revoked".into()), - is_direct: None, - join_authorized_via_users_server: None, - third_party_invite: None, - ..event - }), - self.services.globals.server_user.as_ref(), - &room_id, - &state_lock, - ) - .await - .map(|_| ()) -} diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index d971ce95..c8d8b3b8 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -9,18 +9,15 @@ use std::{ }; use async_trait::async_trait; -use conduwuit_core::{ - Error, Event, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder, +use conduwuit::{ + Error, PduEvent, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder, }; pub use create::create_admin_room; use futures::{Future, FutureExt, TryFutureExt}; use loole::{Receiver, Sender}; use ruma::{ OwnedEventId, OwnedRoomId, RoomId, UserId, - events::{ - Mentions, - room::message::{Relation, RoomMessageEventContent}, - }, + events::room::message::{Relation, RoomMessageEventContent}, }; use tokio::sync::RwLock; @@ -142,13 +139,6 @@ impl crate::Service for Service { } impl Service { - /// Sends markdown notice to the admin room as the admin user. - pub async fn notice(&self, body: &str) { - self.send_message(RoomMessageEventContent::notice_markdown(body)) - .await - .ok(); - } - /// Sends markdown message (not an m.notice for notification reasons) to the /// admin room as the admin user. pub async fn send_text(&self, body: &str) { @@ -167,17 +157,6 @@ impl Service { .await } - /// Sends a message, the same as send_message() but with an @room ping to - /// notify all users in the room. - pub async fn send_loud_message( - &self, - mut message_content: RoomMessageEventContent, - ) -> Result<()> { - // Add @room ping - message_content = message_content.add_mentions(Mentions::with_room_mention()); - self.send_message(message_content).await - } - /// Posts a command to the command processor queue and returns. Processing /// will take place on the service worker's task asynchronously. Errors if /// the queue is full. @@ -305,13 +284,13 @@ impl Service { return Ok(()); }; - let response_sender = if self.is_admin_room(pdu.room_id()).await { + let response_sender = if self.is_admin_room(&pdu.room_id).await { &self.services.globals.server_user } else { - pdu.sender() + &pdu.sender }; - self.respond_to_room(content, pdu.room_id(), response_sender) + self.respond_to_room(content, &pdu.room_id, response_sender) .boxed() .await } @@ -361,10 +340,7 @@ impl Service { Ok(()) } - pub async fn is_admin_command(&self, event: &E, body: &str) -> bool - where - E: Event + Send + Sync, - { + pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { // Server-side command-escape with public echo let is_escape = body.starts_with('\\'); let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin"); @@ -379,10 +355,8 @@ impl Service { return false; } - let user_is_local = self.services.globals.user_is_local(event.sender()); - // only allow public escaped commands by local admins - if is_public_escape && !user_is_local { + if is_public_escape && !self.services.globals.user_is_local(&pdu.sender) { return false; } @@ -392,20 +366,20 @@ impl Service { } // Prevent unescaped !admin from being used outside of the admin room - if is_public_prefix && !self.is_admin_room(event.room_id()).await { + if is_public_prefix && !self.is_admin_room(&pdu.room_id).await { return false; } // Only senders who are admin can proceed - if !self.user_is_admin(event.sender()).await { + if !self.user_is_admin(&pdu.sender).await { return false; } // This will evaluate to false if the emergency password is set up so that // the administrator can execute commands as the server user let emergency_password_set = self.services.server.config.emergency_password.is_some(); - let from_server = event.sender() == server_user && !emergency_password_set; - if from_server && self.is_admin_room(event.room_id()).await { + let from_server = pdu.sender == *server_user && !emergency_password_set; + if from_server && self.is_admin_room(&pdu.room_id).await { return false; } diff --git a/src/service/migrations.rs b/src/service/migrations.rs index cee638ba..512a7867 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -242,14 +242,12 @@ async fn db_lt_12(services: &Services) -> Result<()> { [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; let rule = rules_list.content.get(content_rule_transformation[0]); - - if let Some(rule) = rule { - let mut rule = rule.clone(); + if rule.is_some() { + let mut rule = rule.unwrap().clone(); content_rule_transformation[1].clone_into(&mut rule.rule_id); rules_list .content .shift_remove(content_rule_transformation[0]); - rules_list.content.insert(rule); } } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index baa7a72e..27490fb8 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,12 +1,12 @@ use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduwuit_core::{ - Err, Event, Result, debug_warn, err, trace, +use conduwuit::{ + Err, PduEvent, Result, debug_warn, err, trace, utils::{stream::TryIgnore, string_from_bytes}, warn, }; -use conduwuit_database::{Deserialized, Ignore, Interfix, Json, Map}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ @@ -272,33 +272,32 @@ impl Service { } } - #[tracing::instrument(skip(self, user, unread, pusher, ruleset, event))] - pub async fn send_push_notice( + #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] + pub async fn send_push_notice( &self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, - event: &E, - ) -> Result - where - E: Event + Send + Sync, - for<'a> &'a E: Event + Send, - { + pdu: &PduEvent, + ) -> Result<()> { let mut notify = None; let mut tweaks = Vec::new(); let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(event.room_id(), &StateEventType::RoomPowerLevels, "") + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "") .await - .and_then(|event| event.get_content()) + .and_then(|ev| { + serde_json::from_str(ev.content.get()).map_err(|e| { + err!(Database(error!("invalid m.room.power_levels event: {e:?}"))) + }) + }) .unwrap_or_default(); - let serialized = event.to_format(); for action in self - .get_actions(user, &ruleset, &power_levels, &serialized, event.room_id()) + .get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id) .await { let n = match action { @@ -320,7 +319,7 @@ impl Service { } if notify == Some(true) { - self.send_notice(unread, pusher, tweaks, event).await?; + self.send_notice(unread, pusher, tweaks, pdu).await?; } // Else the event triggered no actions @@ -370,16 +369,13 @@ impl Service { } #[tracing::instrument(skip(self, unread, pusher, tweaks, event))] - async fn send_notice( + async fn send_notice( &self, unread: UInt, pusher: &Pusher, tweaks: Vec, - event: &E, - ) -> Result - where - E: Event + Send + Sync, - { + event: &PduEvent, + ) -> Result { // TODO: email match &pusher.kind { | PusherKind::Http(http) => { @@ -425,8 +421,8 @@ impl Service { let d = vec![device]; let mut notifi = Notification::new(d); - notifi.event_id = Some(event.event_id().to_owned()); - notifi.room_id = Some(event.room_id().to_owned()); + notifi.event_id = Some((*event.event_id).to_owned()); + notifi.room_id = Some((*event.room_id).to_owned()); if http .data .get("org.matrix.msc4076.disable_badge_count") @@ -446,7 +442,7 @@ impl Service { ) .await?; } else { - if *event.kind() == TimelineEventType::RoomEncrypted + if event.kind == TimelineEventType::RoomEncrypted || tweaks .iter() .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) @@ -455,29 +451,29 @@ impl Service { } else { notifi.prio = NotificationPriority::Low; } - notifi.sender = Some(event.sender().to_owned()); - notifi.event_type = Some(event.kind().to_owned()); - notifi.content = serde_json::value::to_raw_value(event.content()).ok(); + notifi.sender = Some(event.sender.clone()); + notifi.event_type = Some(event.kind.clone()); + notifi.content = serde_json::value::to_raw_value(&event.content).ok(); - if *event.kind() == TimelineEventType::RoomMember { + if event.kind == TimelineEventType::RoomMember { notifi.user_is_target = - event.state_key() == Some(event.sender().as_str()); + event.state_key.as_deref() == Some(event.sender.as_str()); } notifi.sender_display_name = - self.services.users.displayname(event.sender()).await.ok(); + self.services.users.displayname(&event.sender).await.ok(); notifi.room_name = self .services .state_accessor - .get_name(event.room_id()) + .get_name(&event.room_id) .await .ok(); notifi.room_alias = self .services .state_accessor - .get_canonical_alias(event.room_id()) + .get_canonical_alias(&event.room_id) .await .ok(); diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 7675efd4..866e45a9 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -3,7 +3,7 @@ mod remote; use std::sync::Arc; use conduwuit::{ - Err, Event, Result, Server, err, + Err, Result, Server, err, utils::{ReadyExt, stream::TryIgnore}, }; use database::{Deserialized, Ignore, Interfix, Map}; @@ -241,7 +241,7 @@ impl Service { .room_state_get(&room_id, &StateEventType::RoomCreate, "") .await { - return Ok(event.sender() == user_id); + return Ok(event.sender == user_id); } Err!(Database("Room has no m.room.create event")) diff --git a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs index 44027e04..b0a7d827 100644 --- a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs +++ b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs @@ -4,13 +4,11 @@ use std::{ }; use conduwuit::{ - Event, PduEvent, debug, debug_error, debug_warn, implement, - matrix::event::gen_event_id_canonical_json, trace, utils::continue_exponential_backoff_secs, - warn, + PduEvent, debug, debug_error, debug_warn, implement, pdu, trace, + utils::continue_exponential_backoff_secs, warn, }; use ruma::{ - CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName, - api::federation::event::get_event, + CanonicalJsonValue, OwnedEventId, RoomId, ServerName, api::federation::event::get_event, }; use super::get_room_version_id; @@ -25,17 +23,13 @@ use super::get_room_version_id; /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? #[implement(super::Service)] -pub(super) async fn fetch_and_handle_outliers<'a, Pdu, Events>( +pub(super) async fn fetch_and_handle_outliers<'a>( &self, origin: &'a ServerName, - events: Events, - create_event: &'a Pdu, + events: &'a [OwnedEventId], + create_event: &'a PduEvent, room_id: &'a RoomId, -) -> Vec<(PduEvent, Option>)> -where - Pdu: Event + Send + Sync, - Events: Iterator + Clone + Send, -{ +) -> Vec<(PduEvent, Option>)> { let back_off = |id| match self .services .globals @@ -52,23 +46,22 @@ where }, }; - let mut events_with_auth_events = Vec::with_capacity(events.clone().count()); - + let mut events_with_auth_events = Vec::with_capacity(events.len()); for id in events { // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { - events_with_auth_events.push((id.to_owned(), Some(local_pdu), vec![])); + trace!("Found {id} in db"); + events_with_auth_events.push((id, Some(local_pdu), vec![])); continue; } // c. Ask origin server over federation // We also handle its auth chain here so we don't get a stack overflow in // handle_outlier_pdu. - let mut todo_auth_events: VecDeque<_> = [id.to_owned()].into(); + let mut todo_auth_events: VecDeque<_> = [id.clone()].into(); let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); - let mut events_all = HashSet::with_capacity(todo_auth_events.len()); while let Some(next_id) = todo_auth_events.pop_front() { if let Some((time, tries)) = self @@ -124,7 +117,7 @@ where }; let Ok((calculated_event_id, value)) = - gen_event_id_canonical_json(&res.pdu, &room_version_id) + pdu::gen_event_id_canonical_json(&res.pdu, &room_version_id) else { back_off((*next_id).to_owned()); continue; @@ -167,8 +160,7 @@ where }, } } - - events_with_auth_events.push((id.to_owned(), None, events_in_reverse_order)); + events_with_auth_events.push((id, None, events_in_reverse_order)); } let mut pdus = Vec::with_capacity(events_with_auth_events.len()); @@ -225,6 +217,5 @@ where } } } - pdus } diff --git a/src/service/rooms/event_handler/fetch_prev.rs b/src/service/rooms/event_handler/fetch_prev.rs index efc7a434..0f92d6e6 100644 --- a/src/service/rooms/event_handler/fetch_prev.rs +++ b/src/service/rooms/event_handler/fetch_prev.rs @@ -1,16 +1,13 @@ -use std::{ - collections::{BTreeMap, HashMap, HashSet, VecDeque}, - iter::once, -}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use conduwuit::{ - Event, PduEvent, Result, debug_warn, err, implement, + PduEvent, Result, debug_warn, err, implement, state_res::{self}, }; use futures::{FutureExt, future}; use ruma::{ - CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName, - int, uint, + CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName, UInt, int, + uint, }; use super::check_room_id; @@ -22,26 +19,20 @@ use super::check_room_id; fields(%origin), )] #[allow(clippy::type_complexity)] -pub(super) async fn fetch_prev<'a, Pdu, Events>( +pub(super) async fn fetch_prev( &self, origin: &ServerName, - create_event: &Pdu, + create_event: &PduEvent, room_id: &RoomId, - first_ts_in_room: MilliSecondsSinceUnixEpoch, - initial_set: Events, + first_ts_in_room: UInt, + initial_set: Vec, ) -> Result<( Vec, HashMap)>, -)> -where - Pdu: Event + Send + Sync, - Events: Iterator + Clone + Send, -{ - let num_ids = initial_set.clone().count(); +)> { + let mut graph: HashMap = HashMap::with_capacity(initial_set.len()); let mut eventid_info = HashMap::new(); - let mut graph: HashMap = HashMap::with_capacity(num_ids); - let mut todo_outlier_stack: VecDeque = - initial_set.map(ToOwned::to_owned).collect(); + let mut todo_outlier_stack: VecDeque = initial_set.into(); let mut amount = 0; @@ -49,12 +40,7 @@ where self.services.server.check_running()?; match self - .fetch_and_handle_outliers( - origin, - once(prev_event_id.as_ref()), - create_event, - room_id, - ) + .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id) .boxed() .await .pop() @@ -79,17 +65,17 @@ where } if let Some(json) = json_opt { - if pdu.origin_server_ts() > first_ts_in_room { + if pdu.origin_server_ts > first_ts_in_room { amount = amount.saturating_add(1); - for prev_prev in pdu.prev_events() { + for prev_prev in &pdu.prev_events { if !graph.contains_key(prev_prev) { - todo_outlier_stack.push_back(prev_prev.to_owned()); + todo_outlier_stack.push_back(prev_prev.clone()); } } graph.insert( prev_event_id.clone(), - pdu.prev_events().map(ToOwned::to_owned).collect(), + pdu.prev_events.iter().cloned().collect(), ); } else { // Time based check failed @@ -112,7 +98,8 @@ where let event_fetch = |event_id| { let origin_server_ts = eventid_info .get(&event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get()); + .cloned() + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts); // This return value is the key used for sorting events, // events are then sorted by power level, time, diff --git a/src/service/rooms/event_handler/fetch_state.rs b/src/service/rooms/event_handler/fetch_state.rs index d68a3542..0f9e093b 100644 --- a/src/service/rooms/event_handler/fetch_state.rs +++ b/src/service/rooms/event_handler/fetch_state.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, hash_map}; -use conduwuit::{Err, Event, Result, debug, debug_warn, err, implement}; +use conduwuit::{Err, Error, PduEvent, Result, debug, debug_warn, implement}; use futures::FutureExt; use ruma::{ EventId, OwnedEventId, RoomId, ServerName, api::federation::event::get_room_state_ids, @@ -18,16 +18,13 @@ use crate::rooms::short::ShortStateKey; skip_all, fields(%origin), )] -pub(super) async fn fetch_state( +pub(super) async fn fetch_state( &self, origin: &ServerName, - create_event: &Pdu, + create_event: &PduEvent, room_id: &RoomId, event_id: &EventId, -) -> Result>> -where - Pdu: Event + Send + Sync, -{ +) -> Result>> { let res = self .services .sending @@ -39,27 +36,27 @@ where .inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?; debug!("Fetching state events"); - let state_ids = res.pdu_ids.iter().map(AsRef::as_ref); let state_vec = self - .fetch_and_handle_outliers(origin, state_ids, create_event, room_id) + .fetch_and_handle_outliers(origin, &res.pdu_ids, create_event, room_id) .boxed() .await; let mut state: HashMap = HashMap::with_capacity(state_vec.len()); for (pdu, _) in state_vec { let state_key = pdu - .state_key() - .ok_or_else(|| err!(Database("Found non-state pdu in state events.")))?; + .state_key + .clone() + .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; let shortstatekey = self .services .short - .get_or_create_shortstatekey(&pdu.kind().to_string().into(), state_key) + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) .await; match state.entry(shortstatekey) { | hash_map::Entry::Vacant(v) => { - v.insert(pdu.event_id().to_owned()); + v.insert(pdu.event_id.clone()); }, | hash_map::Entry::Occupied(_) => { return Err!(Database( @@ -76,7 +73,7 @@ where .get_shortstatekey(&StateEventType::RoomCreate, "") .await?; - if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(create_event.event_id()) { + if state.get(&create_shortstatekey) != Some(&create_event.event_id) { return Err!(Database("Incoming event refers to wrong create event.")); } diff --git a/src/service/rooms/event_handler/handle_incoming_pdu.rs b/src/service/rooms/event_handler/handle_incoming_pdu.rs index 86a05e0a..77cae41d 100644 --- a/src/service/rooms/event_handler/handle_incoming_pdu.rs +++ b/src/service/rooms/event_handler/handle_incoming_pdu.rs @@ -4,7 +4,7 @@ use std::{ }; use conduwuit::{ - Err, Event, Result, debug::INFO_SPAN_LEVEL, defer, err, implement, utils::stream::IterStream, + Err, Result, debug, debug::INFO_SPAN_LEVEL, defer, err, implement, utils::stream::IterStream, warn, }; use futures::{ @@ -12,7 +12,6 @@ use futures::{ future::{OptionFuture, try_join5}, }; use ruma::{CanonicalJsonValue, EventId, RoomId, ServerName, UserId, events::StateEventType}; -use tracing::debug; use crate::rooms::timeline::RawPduId; @@ -122,16 +121,22 @@ pub async fn handle_incoming_pdu<'a>( .timeline .first_pdu_in_room(room_id) .await? - .origin_server_ts(); + .origin_server_ts; - if incoming_pdu.origin_server_ts() < first_ts_in_room { + if incoming_pdu.origin_server_ts < first_ts_in_room { return Ok(None); } // 9. Fetch any missing prev events doing all checks listed here starting at 1. // These are timeline events let (sorted_prev_events, mut eventid_info) = self - .fetch_prev(origin, create_event, room_id, first_ts_in_room, incoming_pdu.prev_events()) + .fetch_prev( + origin, + create_event, + room_id, + first_ts_in_room, + incoming_pdu.prev_events.clone(), + ) .await?; debug!( diff --git a/src/service/rooms/event_handler/handle_outlier_pdu.rs b/src/service/rooms/event_handler/handle_outlier_pdu.rs index 0930c96f..87b76222 100644 --- a/src/service/rooms/event_handler/handle_outlier_pdu.rs +++ b/src/service/rooms/event_handler/handle_outlier_pdu.rs @@ -1,29 +1,27 @@ use std::collections::{BTreeMap, HashMap, hash_map}; use conduwuit::{ - Err, Event, PduEvent, Result, debug, debug_info, err, implement, state_res, trace, warn, + Err, Error, PduEvent, Result, debug, debug_info, err, implement, state_res, trace, warn, }; use futures::future::ready; use ruma::{ - CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName, events::StateEventType, + CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName, + api::client::error::ErrorKind, events::StateEventType, }; use super::{check_room_id, get_room_version_id, to_room_version}; #[implement(super::Service)] #[allow(clippy::too_many_arguments)] -pub(super) async fn handle_outlier_pdu<'a, Pdu>( +pub(super) async fn handle_outlier_pdu<'a>( &self, origin: &'a ServerName, - create_event: &'a Pdu, + create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, mut value: CanonicalJsonObject, auth_events_known: bool, -) -> Result<(PduEvent, BTreeMap)> -where - Pdu: Event + Send + Sync, -{ +) -> Result<(PduEvent, BTreeMap)> { // 1. Remove unsigned field value.remove("unsigned"); @@ -32,7 +30,7 @@ where // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match let room_version_id = get_room_version_id(create_event)?; - let mut incoming_pdu = match self + let mut val = match self .services .server_keys .verify_event(&value, Some(&room_version_id)) @@ -64,15 +62,13 @@ where // Now that we have checked the signature and hashes we can add the eventID and // convert to our PduEvent type - incoming_pdu - .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - - let pdu_event = serde_json::from_value::( - serde_json::to_value(&incoming_pdu).expect("CanonicalJsonObj is a valid JsonValue"), + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = serde_json::from_value::( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), ) .map_err(|e| err!(Request(BadJson(debug_warn!("Event is not a valid PDU: {e}")))))?; - check_room_id(room_id, &pdu_event)?; + check_room_id(room_id, &incoming_pdu)?; if !auth_events_known { // 4. fetch any missing auth events doing all checks listed here starting at 1. @@ -83,7 +79,7 @@ where debug!("Fetching auth events for {}", incoming_pdu.event_id); Box::pin(self.fetch_and_handle_outliers( origin, - pdu_event.auth_events(), + &incoming_pdu.auth_events, create_event, room_id, )) @@ -94,8 +90,8 @@ where // auth events debug!("Checking {} based on auth events", incoming_pdu.event_id); // Build map of auth events - let mut auth_events = HashMap::with_capacity(pdu_event.auth_events().count()); - for id in pdu_event.auth_events() { + let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); + for id in &incoming_pdu.auth_events { let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { warn!("Could not find auth event {id} for {}", incoming_pdu.event_id); continue; @@ -114,9 +110,10 @@ where v.insert(auth_event); }, | hash_map::Entry::Occupied(_) => { - return Err!(Request(InvalidParam( + return Err(Error::BadRequest( + ErrorKind::InvalidParam, "Auth event's type and state_key combination exists multiple times.", - ))); + )); }, } } @@ -128,13 +125,13 @@ where let state_fetch = |ty: &StateEventType, sk: &str| { let key = (ty.to_owned(), sk.into()); - ready(auth_events.get(&key).map(ToOwned::to_owned)) + ready(auth_events.get(&key)) }; debug!("running auth check to handle outlier pdu {:?}", incoming_pdu.event_id); let auth_check = state_res::event_auth::auth_check( &to_room_version(&room_version_id), - &pdu_event, + &incoming_pdu, None, // TODO: third party invite state_fetch, ) @@ -150,9 +147,9 @@ where // 7. Persist the event as an outlier. self.services .outlier - .add_pdu_outlier(pdu_event.event_id(), &incoming_pdu); + .add_pdu_outlier(&incoming_pdu.event_id, &val); trace!("Added pdu as outlier."); - Ok((pdu_event, incoming_pdu)) + Ok((incoming_pdu, val)) } diff --git a/src/service/rooms/event_handler/handle_prev_pdu.rs b/src/service/rooms/event_handler/handle_prev_pdu.rs index cd46310a..d612b2bf 100644 --- a/src/service/rooms/event_handler/handle_prev_pdu.rs +++ b/src/service/rooms/event_handler/handle_prev_pdu.rs @@ -1,11 +1,10 @@ use std::{collections::BTreeMap, time::Instant}; use conduwuit::{ - Err, Event, PduEvent, Result, debug::INFO_SPAN_LEVEL, defer, implement, + Err, PduEvent, Result, debug, debug::INFO_SPAN_LEVEL, defer, implement, utils::continue_exponential_backoff_secs, }; -use ruma::{CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName}; -use tracing::debug; +use ruma::{CanonicalJsonValue, EventId, RoomId, ServerName, UInt}; #[implement(super::Service)] #[allow(clippy::type_complexity)] @@ -16,19 +15,16 @@ use tracing::debug; skip_all, fields(%prev_id), )] -pub(super) async fn handle_prev_pdu<'a, Pdu>( +pub(super) async fn handle_prev_pdu<'a>( &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, eventid_info: Option<(PduEvent, BTreeMap)>, - create_event: &'a Pdu, - first_ts_in_room: MilliSecondsSinceUnixEpoch, + create_event: &'a PduEvent, + first_ts_in_room: UInt, prev_id: &'a EventId, -) -> Result -where - Pdu: Event + Send + Sync, -{ +) -> Result { // Check for disabled again because it might have changed if self.services.metadata.is_disabled(room_id).await { return Err!(Request(Forbidden(debug_warn!( @@ -63,7 +59,7 @@ where }; // Skip old events - if pdu.origin_server_ts() < first_ts_in_room { + if pdu.origin_server_ts < first_ts_in_room { return Ok(()); } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index aed38e1e..45675da8 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -18,7 +18,7 @@ use std::{ }; use async_trait::async_trait; -use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, utils::MutexMap}; +use conduwuit::{Err, PduEvent, Result, RoomVersion, Server, utils::MutexMap}; use ruma::{ OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, events::room::create::RoomCreateEventContent, @@ -104,11 +104,11 @@ impl Service { } } -fn check_room_id(room_id: &RoomId, pdu: &Pdu) -> Result { - if pdu.room_id() != room_id { +fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result { + if pdu.room_id != room_id { return Err!(Request(InvalidParam(error!( - pdu_event_id = ?pdu.event_id(), - pdu_room_id = ?pdu.room_id(), + pdu_event_id = ?pdu.event_id, + pdu_room_id = ?pdu.room_id, ?room_id, "Found event from room in room", )))); @@ -117,7 +117,7 @@ fn check_room_id(room_id: &RoomId, pdu: &Pdu) -> Result { Ok(()) } -fn get_room_version_id(create_event: &Pdu) -> Result { +fn get_room_version_id(create_event: &PduEvent) -> Result { let content: RoomCreateEventContent = create_event.get_content()?; let room_version = content.room_version; diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 65cf1752..a49fc541 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,6 +1,4 @@ -use conduwuit::{ - Result, err, implement, matrix::event::gen_event_id_canonical_json, result::FlatOk, -}; +use conduwuit::{Result, err, implement, pdu::gen_event_id_canonical_json, result::FlatOk}; use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId}; use serde_json::value::RawValue as RawJsonValue; diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index d3bb8f79..eb38c2c3 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -6,7 +6,7 @@ use std::{ use conduwuit::{ Result, debug, err, implement, - matrix::{Event, StateMap}, + matrix::{PduEvent, StateMap}, trace, utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt}, }; @@ -19,18 +19,11 @@ use crate::rooms::short::ShortStateHash; #[implement(super::Service)] // request and build the state from a known point and resolve if > 1 prev_event #[tracing::instrument(name = "state", level = "debug", skip_all)] -pub(super) async fn state_at_incoming_degree_one( +pub(super) async fn state_at_incoming_degree_one( &self, - incoming_pdu: &Pdu, -) -> Result>> -where - Pdu: Event + Send + Sync, -{ - let prev_event = incoming_pdu - .prev_events() - .next() - .expect("at least one prev_event"); - + incoming_pdu: &PduEvent, +) -> Result>> { + let prev_event = &incoming_pdu.prev_events[0]; let Ok(prev_event_sstatehash) = self .services .state_accessor @@ -62,7 +55,7 @@ where .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) .await; - state.insert(shortstatekey, prev_event.to_owned()); + state.insert(shortstatekey, prev_event.clone()); // Now it's the state after the pdu } @@ -73,18 +66,16 @@ where #[implement(super::Service)] #[tracing::instrument(name = "state", level = "debug", skip_all)] -pub(super) async fn state_at_incoming_resolved( +pub(super) async fn state_at_incoming_resolved( &self, - incoming_pdu: &Pdu, + incoming_pdu: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, -) -> Result>> -where - Pdu: Event + Send + Sync, -{ +) -> Result>> { trace!("Calculating extremity statehashes..."); let Ok(extremity_sstatehashes) = incoming_pdu - .prev_events() + .prev_events + .iter() .try_stream() .broad_and_then(|prev_eventid| { self.services @@ -142,15 +133,12 @@ where } #[implement(super::Service)] -async fn state_at_incoming_fork( +async fn state_at_incoming_fork( &self, room_id: &RoomId, sstatehash: ShortStateHash, - prev_event: Pdu, -) -> Result<(StateMap, HashSet)> -where - Pdu: Event, -{ + prev_event: PduEvent, +) -> Result<(StateMap, HashSet)> { let mut leaf_state: HashMap<_, _> = self .services .state_accessor @@ -158,15 +146,15 @@ where .collect() .await; - if let Some(state_key) = prev_event.state_key() { + if let Some(state_key) = &prev_event.state_key { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&prev_event.kind().to_string().into(), state_key) + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) .await; - let event_id = prev_event.event_id(); - leaf_state.insert(shortstatekey, event_id.to_owned()); + let event_id = &prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); // Now it's the state after the pdu } diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 61f081a9..346314d1 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -27,7 +27,7 @@ pub trait Options: Send + Sync { #[derive(Clone, Debug)] pub struct Context<'a> { pub user_id: &'a UserId, - pub device_id: Option<&'a DeviceId>, + pub device_id: &'a DeviceId, pub room_id: &'a RoomId, pub token: Option, pub options: Option<&'a LazyLoadOptions>, @@ -40,7 +40,7 @@ pub enum Status { } pub type Witness = HashSet; -type Key<'a> = (&'a UserId, Option<&'a DeviceId>, &'a RoomId, &'a UserId); +type Key<'a> = (&'a UserId, &'a DeviceId, &'a RoomId, &'a UserId); impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 6ab2c026..12b56935 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use conduwuit::{Result, implement, matrix::PduEvent}; -use database::{Deserialized, Json, Map}; +use conduwuit::{Result, implement, matrix::pdu::PduEvent}; +use conduwuit_database::{Deserialized, Json, Map}; use ruma::{CanonicalJsonObject, EventId}; pub struct Service { diff --git a/src/service/rooms/pdu_metadata/bundled_aggregations.rs b/src/service/rooms/pdu_metadata/bundled_aggregations.rs new file mode 100644 index 00000000..c47f637a --- /dev/null +++ b/src/service/rooms/pdu_metadata/bundled_aggregations.rs @@ -0,0 +1,765 @@ +use conduwuit::{Event, PduEvent, Result, err}; +use ruma::{ + EventId, RoomId, UserId, + api::Direction, + events::relation::{BundledMessageLikeRelations, BundledReference, ReferenceChunk}, +}; + +use super::PdusIterItem; + +const MAX_BUNDLED_RELATIONS: usize = 50; + +impl super::Service { + /// Gets bundled aggregations for an event according to the Matrix + /// specification. + /// - m.replace relations are bundled to include the most recent replacement + /// event. + /// - m.reference relations are bundled to include a chunk of event IDs. + #[tracing::instrument(skip(self), level = "debug")] + pub async fn get_bundled_aggregations( + &self, + user_id: &UserId, + room_id: &RoomId, + event_id: &EventId, + ) -> Result>>> { + let relations = self + .get_relations( + user_id, + room_id, + event_id, + conduwuit::PduCount::max(), + MAX_BUNDLED_RELATIONS, + 0, + Direction::Backward, + ) + .await; + // The relations database code still handles the basic unsigned data + // We don't want to recursively fetch relations + + // TODO: Event visibility check + // TODO: ignored users? + + if relations.is_empty() { + return Ok(None); + } + + // Get the original event for validation of replacement events + let original_event = self.services.timeline.get_pdu(event_id).await?; + + let mut replace_events = Vec::with_capacity(relations.len()); + let mut reference_events = Vec::with_capacity(relations.len()); + + for relation in &relations { + let pdu = &relation.1; + + let content = pdu.get_content_as_value(); + if let Some(relates_to) = content.get("m.relates_to") { + // We don't check that the event relates back, because we assume the database is + // good. + if let Some(rel_type) = relates_to.get("rel_type") { + match rel_type.as_str() { + | Some("m.replace") => { + // Only consider valid replacements + if Self::is_valid_replacement_event(&original_event, pdu).await? { + replace_events.push(relation); + } + }, + | Some("m.reference") => { + reference_events.push(relation); + }, + | _ => { + // Ignore other relation types for now + // Threads are in the database but not handled here + // Other types are not specified AFAICT. + }, + } + } + } + } + + // If no relations to bundle, return None + if replace_events.is_empty() && reference_events.is_empty() { + return Ok(None); + } + + let mut bundled = BundledMessageLikeRelations::new(); + + // Handle m.replace relations - find the most recent one + if !replace_events.is_empty() { + let most_recent_replacement = Self::find_most_recent_replacement(&replace_events)?; + + // Convert the replacement event to the bundled format + if let Some(replacement_pdu) = most_recent_replacement { + // According to the Matrix spec, we should include the full event as raw JSON + let replacement_json = serde_json::to_string(replacement_pdu) + .map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?; + let raw_value = serde_json::value::RawValue::from_string(replacement_json) + .map_err(|e| err!(Database("Failed to create RawValue: {e}")))?; + bundled.replace = Some(Box::new(raw_value)); + } + } + + // Handle m.reference relations - collect event IDs + if !reference_events.is_empty() { + let reference_chunk = Self::build_reference_chunk(&reference_events)?; + if !reference_chunk.is_empty() { + bundled.reference = Some(Box::new(ReferenceChunk::new(reference_chunk))); + } + } + + // TODO: Handle other relation types (m.annotation, etc.) when specified + + Ok(Some(bundled)) + } + + /// Build reference chunk for m.reference bundled aggregations + fn build_reference_chunk( + reference_events: &[&PdusIterItem], + ) -> Result> { + let mut chunk = Vec::with_capacity(reference_events.len()); + + for relation in reference_events { + let pdu = &relation.1; + + let reference_entry = BundledReference::new(pdu.event_id().to_owned()); + chunk.push(reference_entry); + } + + // Don't sort, order is unspecified + + Ok(chunk) + } + + /// Find the most recent replacement event based on origin_server_ts and + /// lexicographic event_id ordering + fn find_most_recent_replacement<'a>( + replacement_events: &'a [&'a PdusIterItem], + ) -> Result> { + if replacement_events.is_empty() { + return Ok(None); + } + + let mut most_recent: Option<&PduEvent> = None; + + // Jank, is there a better way to do this? + for relation in replacement_events { + let pdu = &relation.1; + + match most_recent { + | None => { + most_recent = Some(pdu); + }, + | Some(current_most_recent) => { + // Compare by origin_server_ts first + match pdu + .origin_server_ts() + .cmp(¤t_most_recent.origin_server_ts()) + { + | std::cmp::Ordering::Greater => { + most_recent = Some(pdu); + }, + | std::cmp::Ordering::Equal => { + // If timestamps are equal, use lexicographic ordering of event_id + if pdu.event_id() > current_most_recent.event_id() { + most_recent = Some(pdu); + } + }, + | std::cmp::Ordering::Less => { + // Keep current most recent + }, + } + }, + } + } + + Ok(most_recent) + } + + /// Adds bundled aggregations to a PDU's unsigned field + #[tracing::instrument(skip(self, pdu), level = "debug")] + pub async fn add_bundled_aggregations_to_pdu( + &self, + user_id: &UserId, + pdu: &mut PduEvent, + ) -> Result<()> { + if pdu.is_redacted() { + return Ok(()); + } + + let bundled_aggregations = self + .get_bundled_aggregations(user_id, pdu.room_id(), pdu.event_id()) + .await?; + + if let Some(aggregations) = bundled_aggregations { + let aggregations_json = serde_json::to_value(aggregations) + .map_err(|e| err!(Database("Failed to serialize bundled aggregations: {e}")))?; + + Self::add_bundled_aggregations_to_unsigned(pdu, aggregations_json)?; + } + + Ok(()) + } + + /// Helper method to add bundled aggregations to a PDU's unsigned + /// field + fn add_bundled_aggregations_to_unsigned( + pdu: &mut PduEvent, + aggregations_json: serde_json::Value, + ) -> Result<()> { + use serde_json::{ + Map, Value as JsonValue, + value::{RawValue as RawJsonValue, to_raw_value}, + }; + + let mut unsigned: Map = pdu + .unsigned + .as_deref() + .map(RawJsonValue::get) + .map_or_else(|| Ok(Map::new()), serde_json::from_str) + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + let relations = unsigned + .entry("m.relations") + .or_insert_with(|| JsonValue::Object(Map::new())) + .as_object_mut() + .ok_or_else(|| err!(Database("m.relations is not an object")))?; + + if let JsonValue::Object(aggregations_map) = aggregations_json { + for (rel_type, aggregation) in aggregations_map { + relations.insert(rel_type, aggregation); + } + } + + pdu.unsigned = Some(to_raw_value(&unsigned)?); + + Ok(()) + } + + /// Validates that an event is acceptable as a replacement for another event + /// See C/S spec "Validity of replacement events" + #[tracing::instrument(level = "debug")] + async fn is_valid_replacement_event( + original_event: &PduEvent, + replacement_event: &PduEvent, + ) -> Result { + // 1. Same room_id + if original_event.room_id() != replacement_event.room_id() { + return Ok(false); + } + + // 2. Same sender + if original_event.sender() != replacement_event.sender() { + return Ok(false); + } + + // 3. Same type + if original_event.event_type() != replacement_event.event_type() { + return Ok(false); + } + + // 4. Neither event should have a state_key property + if original_event.state_key().is_some() || replacement_event.state_key().is_some() { + return Ok(false); + } + + // 5. Original event must not have rel_type of m.replace + let original_content = original_event.get_content_as_value(); + if let Some(relates_to) = original_content.get("m.relates_to") { + if let Some(rel_type) = relates_to.get("rel_type") { + if rel_type.as_str() == Some("m.replace") { + return Ok(false); + } + } + } + + // 6. Replacement event must have m.new_content property + // Skip this check for encrypted events, as m.new_content would be inside the + // encrypted payload + if replacement_event.event_type() != &ruma::events::TimelineEventType::RoomEncrypted { + let replacement_content = replacement_event.get_content_as_value(); + if replacement_content.get("m.new_content").is_none() { + return Ok(false); + } + } + + Ok(true) + } +} + +#[cfg(test)] +mod tests { + use conduwuit_core::pdu::{EventHash, PduEvent}; + use ruma::{UInt, events::TimelineEventType, owned_event_id, owned_room_id, owned_user_id}; + use serde_json::{Value as JsonValue, json, value::to_raw_value}; + + fn create_test_pdu(unsigned_content: Option) -> PduEvent { + PduEvent { + event_id: owned_event_id!("$test:example.com"), + room_id: owned_room_id!("!test:example.com"), + sender: owned_user_id!("@test:example.com"), + origin_server_ts: UInt::try_from(1_234_567_890_u64).unwrap(), + kind: TimelineEventType::RoomMessage, + content: to_raw_value(&json!({"msgtype": "m.text", "body": "test"})).unwrap(), + state_key: None, + prev_events: vec![], + depth: UInt::from(1_u32), + auth_events: vec![], + redacts: None, + unsigned: unsigned_content.map(|content| to_raw_value(&content).unwrap()), + hashes: EventHash { sha256: "test_hash".to_owned() }, + signatures: None, + origin: None, + } + } + + fn create_bundled_aggregations() -> JsonValue { + json!({ + "m.replace": { + "event_id": "$replace:example.com", + "origin_server_ts": 1_234_567_890, + "sender": "@replacer:example.com" + }, + "m.reference": { + "count": 5, + "chunk": [ + "$ref1:example.com", + "$ref2:example.com" + ] + } + }) + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_no_existing_unsigned() { + let mut pdu = create_test_pdu(None); + let aggregations = create_bundled_aggregations(); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + aggregations.clone(), + ); + assert!(result.is_ok(), "Should succeed when no unsigned field exists"); + + assert!(pdu.unsigned.is_some(), "Unsigned field should be created"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + assert!(unsigned.get("m.relations").is_some(), "m.relations should exist"); + assert_eq!( + unsigned["m.relations"], aggregations, + "Relations should match the aggregations" + ); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_overwrite_same_relation_type() { + let existing_unsigned = json!({ + "m.relations": { + "m.replace": { + "event_id": "$old_replace:example.com", + "origin_server_ts": 1_111_111_111, + "sender": "@old_replacer:example.com" + } + } + }); + + let mut pdu = create_test_pdu(Some(existing_unsigned)); + let new_aggregations = create_bundled_aggregations(); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + new_aggregations.clone(), + ); + assert!(result.is_ok(), "Should succeed when overwriting same relation type"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + let relations = &unsigned["m.relations"]; + + assert_eq!( + relations["m.replace"], new_aggregations["m.replace"], + "m.replace should be updated" + ); + assert_eq!( + relations["m.replace"]["event_id"], "$replace:example.com", + "Should have new event_id" + ); + + assert!(relations.get("m.reference").is_some(), "New m.reference should be added"); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_preserve_other_unsigned_fields() { + // Test case: Other unsigned fields should be preserved + let existing_unsigned = json!({ + "age": 98765, + "prev_content": {"msgtype": "m.text", "body": "old message"}, + "redacted_because": {"event_id": "$redaction:example.com"}, + "m.relations": { + "m.annotation": {"count": 1} + } + }); + + let mut pdu = create_test_pdu(Some(existing_unsigned)); + let new_aggregations = json!({ + "m.replace": {"event_id": "$new:example.com"} + }); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + new_aggregations, + ); + assert!(result.is_ok(), "Should succeed while preserving other fields"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + // Verify all existing fields are preserved + assert_eq!(unsigned["age"], 98765, "age should be preserved"); + assert!(unsigned.get("prev_content").is_some(), "prev_content should be preserved"); + assert!( + unsigned.get("redacted_because").is_some(), + "redacted_because should be preserved" + ); + + // Verify relations were merged correctly + let relations = &unsigned["m.relations"]; + assert!( + relations.get("m.annotation").is_some(), + "Existing m.annotation should be preserved" + ); + assert!(relations.get("m.replace").is_some(), "New m.replace should be added"); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_invalid_existing_unsigned() { + // Test case: Invalid JSON in existing unsigned should result in error + let mut pdu = create_test_pdu(None); + // Manually set invalid unsigned data + pdu.unsigned = Some(to_raw_value(&"invalid json").unwrap()); + + let aggregations = create_bundled_aggregations(); + let result = + super::super::Service::add_bundled_aggregations_to_unsigned(&mut pdu, aggregations); + + assert!(result.is_err(), "fails when existing unsigned is invalid"); + // Should we ignore the error and overwrite anyway? + } + + // Test helper function to create test PDU events + fn create_test_event( + event_id: &str, + room_id: &str, + sender: &str, + event_type: TimelineEventType, + content: &JsonValue, + state_key: Option<&str>, + ) -> PduEvent { + PduEvent { + event_id: event_id.try_into().unwrap(), + room_id: room_id.try_into().unwrap(), + sender: sender.try_into().unwrap(), + origin_server_ts: UInt::try_from(1_234_567_890_u64).unwrap(), + kind: event_type, + content: to_raw_value(&content).unwrap(), + state_key: state_key.map(Into::into), + prev_events: vec![], + depth: UInt::from(1_u32), + auth_events: vec![], + redacts: None, + unsigned: None, + hashes: EventHash { sha256: "test_hash".to_owned() }, + signatures: None, + origin: None, + } + } + + /// Test that a valid replacement event passes validation + #[tokio::test] + async fn test_valid_replacement_event() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({"msgtype": "m.text", "body": "original message"}), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({ + "msgtype": "m.text", + "body": "* edited message", + "m.new_content": { + "msgtype": "m.text", + "body": "edited message" + }, + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$original:example.com" + } + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(result.unwrap(), "Valid replacement event should be accepted"); + } + + /// Test replacement event with different room ID is rejected + #[tokio::test] + async fn test_replacement_event_different_room() { + let original = create_test_event( + "$original:example.com", + "!room1:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({"msgtype": "m.text", "body": "original message"}), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room2:example.com", // Different room + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({ + "msgtype": "m.text", + "body": "* edited message", + "m.new_content": { + "msgtype": "m.text", + "body": "edited message" + } + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(!result.unwrap(), "Different room ID should be rejected"); + } + + /// Test replacement event with different sender is rejected + #[tokio::test] + async fn test_replacement_event_different_sender() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user1:example.com", + TimelineEventType::RoomMessage, + &json!({"msgtype": "m.text", "body": "original message"}), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user2:example.com", // Different sender + TimelineEventType::RoomMessage, + &json!({ + "msgtype": "m.text", + "body": "* edited message", + "m.new_content": { + "msgtype": "m.text", + "body": "edited message" + } + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(!result.unwrap(), "Different sender should be rejected"); + } + + /// Test replacement event with different type is rejected + #[tokio::test] + async fn test_replacement_event_different_type() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({"msgtype": "m.text", "body": "original message"}), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomTopic, // Different event type + &json!({ + "topic": "new topic", + "m.new_content": { + "topic": "new topic" + } + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(!result.unwrap(), "Different event type should be rejected"); + } + + /// Test replacement event with state key is rejected + #[tokio::test] + async fn test_replacement_event_with_state_key() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomName, + &json!({"name": "room name"}), + Some(""), // Has state key + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomName, + &json!({ + "name": "new room name", + "m.new_content": { + "name": "new room name" + } + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(!result.unwrap(), "Event with state key should be rejected"); + } + + /// Test replacement of an event that is already a replacement is rejected + #[tokio::test] + async fn test_replacement_event_original_is_replacement() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({ + "msgtype": "m.text", + "body": "* edited message", + "m.relates_to": { + "rel_type": "m.replace", // Original is already a replacement + "event_id": "$some_other:example.com" + } + }), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({ + "msgtype": "m.text", + "body": "* edited again", + "m.new_content": { + "msgtype": "m.text", + "body": "edited again" + } + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(!result.unwrap(), "Replacement of replacement should be rejected"); + } + + /// Test replacement event missing m.new_content is rejected + #[tokio::test] + async fn test_replacement_event_missing_new_content() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({"msgtype": "m.text", "body": "original message"}), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomMessage, + &json!({ + "msgtype": "m.text", + "body": "* edited message" + // Missing m.new_content + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!(!result.unwrap(), "Missing m.new_content should be rejected"); + } + + /// Test encrypted replacement event without m.new_content is accepted + #[tokio::test] + async fn test_replacement_event_encrypted_missing_new_content_is_valid() { + let original = create_test_event( + "$original:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomEncrypted, + &json!({ + "algorithm": "m.megolm.v1.aes-sha2", + "ciphertext": "encrypted_payload_base64", + "sender_key": "sender_key", + "session_id": "session_id" + }), + None, + ); + + let replacement = create_test_event( + "$replacement:example.com", + "!room:example.com", + "@user:example.com", + TimelineEventType::RoomEncrypted, + &json!({ + "algorithm": "m.megolm.v1.aes-sha2", + "ciphertext": "encrypted_replacement_payload_base64", + "sender_key": "sender_key", + "session_id": "session_id", + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$original:example.com" + } + // No m.new_content in cleartext - this is valid for encrypted events + }), + None, + ); + + let result = + super::super::Service::is_valid_replacement_event(&original, &replacement).await; + assert!(result.is_ok(), "Validation should succeed"); + assert!( + result.unwrap(), + "Encrypted replacement without cleartext m.new_content should be accepted" + ); + } +} diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index c1376cb0..c4b37b99 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,9 +1,8 @@ use std::{mem::size_of, sync::Arc}; use conduwuit::{ + PduCount, PduEvent, arrayvec::ArrayVec, - matrix::{Event, PduCount}, - result::LogErr, utils::{ ReadyExt, stream::{TryIgnore, WidebandExt}, @@ -33,6 +32,8 @@ struct Services { timeline: Dep, } +pub(super) type PdusIterItem = (PduCount, PduEvent); + impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = &args.db; @@ -60,7 +61,7 @@ impl Data { target: ShortEventId, from: PduCount, dir: Direction, - ) -> impl Stream + Send + '_ { + ) -> impl Stream + Send + '_ { let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes()); @@ -78,9 +79,7 @@ impl Data { let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; - if pdu.sender() != user_id { - pdu.as_mut_pdu().remove_transaction_id().log_err().ok(); - } + pdu.set_unsigned(Some(user_id)); Some((shorteventid, pdu)) }) diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index c8e863fa..2dff54d8 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,14 +1,12 @@ +mod bundled_aggregations; mod data; use std::sync::Arc; -use conduwuit::{ - Result, - matrix::{Event, PduCount}, -}; +use conduwuit::{PduCount, Result}; use futures::{StreamExt, future::try_join}; use ruma::{EventId, RoomId, UserId, api::Direction}; -use self::data::Data; +use self::data::{Data, PdusIterItem}; use crate::{Dep, rooms}; pub struct Service { @@ -47,16 +45,16 @@ impl Service { } #[allow(clippy::too_many_arguments)] - pub async fn get_relations<'a>( - &'a self, - user_id: &'a UserId, - room_id: &'a RoomId, - target: &'a EventId, + pub async fn get_relations( + &self, + user_id: &UserId, + room_id: &RoomId, + target: &EventId, from: PduCount, limit: usize, max_depth: u8, dir: Direction, - ) -> Vec<(PduCount, impl Event)> { + ) -> Vec { let room_id = self.services.short.get_shortroomid(room_id); let target = self.services.timeline.get_pdu_count(target); diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index 68ce9b7f..69e859c4 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -4,10 +4,7 @@ use std::{collections::BTreeMap, sync::Arc}; use conduwuit::{ Result, debug, err, - matrix::{ - Event, - pdu::{PduCount, PduId, RawPduId}, - }, + matrix::pdu::{PduCount, PduId, RawPduId}, warn, }; use futures::{Stream, TryFutureExt, try_join}; @@ -77,13 +74,14 @@ impl Service { let shortroomid = self.services.short.get_shortroomid(room_id).map_err(|e| { err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}"))) }); - let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?; + let shorteventid = PduCount::Normal(pdu_count); let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into(); + let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?; - let event_id: OwnedEventId = pdu.event_id().to_owned(); + let event_id: OwnedEventId = pdu.event_id; let user_id: OwnedUserId = user_id.to_owned(); let content: BTreeMap = BTreeMap::from_iter([( event_id, diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index afe3061b..7cef5dbf 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,10 +1,9 @@ use std::sync::Arc; use conduwuit::{ - PduCount, Result, + PduCount, PduEvent, Result, arrayvec::ArrayVec, implement, - matrix::event::{Event, Matches}, utils::{ ArrayVecExt, IterStream, ReadyExt, set, stream::{TryIgnore, WidebandExt}, @@ -104,10 +103,9 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b pub async fn search_pdus<'a>( &'a self, query: &'a RoomQuery<'a>, -) -> Result<(usize, impl Stream> + Send + '_)> { +) -> Result<(usize, impl Stream + Send + 'a)> { let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; - let filter = &query.criteria.filter; let count = pdu_ids.len(); let pdus = pdu_ids .into_iter() @@ -120,16 +118,21 @@ pub async fn search_pdus<'a>( .ok() }) .ready_filter(|pdu| !pdu.is_redacted()) - .ready_filter(move |pdu| filter.matches(pdu)) + .ready_filter(|pdu| pdu.matches(&query.criteria.filter)) .wide_filter_map(move |pdu| async move { self.services .state_accessor - .user_can_see_event(query.user_id?, pdu.room_id(), pdu.event_id()) + .user_can_see_event(query.user_id?, &pdu.room_id, &pdu.event_id) .await .then_some(pdu) }) .skip(query.skip) - .take(query.limit); + .take(query.limit) + .map(move |mut pdu| { + pdu.set_unsigned(query.user_id); + // TODO: bundled aggregation + pdu + }); Ok((count, pdus)) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index de2647ca..53d2b742 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -5,8 +5,8 @@ mod tests; use std::{fmt::Write, sync::Arc}; use async_trait::async_trait; -use conduwuit_core::{ - Err, Error, Event, PduEvent, Result, implement, +use conduwuit::{ + Err, Error, PduEvent, Result, implement, utils::{ IterStream, future::{BoolExt, TryExtExt}, @@ -142,7 +142,7 @@ pub async fn get_summary_and_children_local( let children_pdus: Vec<_> = self .get_space_child_events(current_room) - .map(Event::into_format) + .map(PduEvent::into_stripped_spacechild_state_event) .collect() .await; @@ -511,7 +511,7 @@ async fn cache_insert( room_id: room_id.clone(), children_state: self .get_space_child_events(&room_id) - .map(Event::into_format) + .map(PduEvent::into_stripped_spacechild_state_event) .collect() .await, encryption, diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 92881126..056da5e6 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc}; use async_trait::async_trait; -use conduwuit_core::{ - Event, PduEvent, Result, err, +use conduwuit::{ + PduEvent, Result, err, result::FlatOk, state_res::{self, StateMap}, utils::{ @@ -11,7 +11,7 @@ use conduwuit_core::{ }, warn, }; -use conduwuit_database::{Deserialized, Ignore, Interfix, Map}; +use database::{Deserialized, Ignore, Interfix, Map}; use futures::{ FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all, pin_mut, }; @@ -319,34 +319,30 @@ impl Service { } #[tracing::instrument(skip_all, level = "debug")] - pub async fn summary_stripped<'a, E>(&self, event: &'a E) -> Vec> - where - E: Event + Send + Sync, - &'a E: Event + Send, - { + pub async fn summary_stripped(&self, event: &PduEvent) -> Vec> { let cells = [ (&StateEventType::RoomCreate, ""), (&StateEventType::RoomJoinRules, ""), (&StateEventType::RoomCanonicalAlias, ""), (&StateEventType::RoomName, ""), (&StateEventType::RoomAvatar, ""), - (&StateEventType::RoomMember, event.sender().as_str()), // Add recommended events + (&StateEventType::RoomMember, event.sender.as_str()), // Add recommended events (&StateEventType::RoomEncryption, ""), (&StateEventType::RoomTopic, ""), ]; - let fetches = cells.into_iter().map(|(event_type, state_key)| { + let fetches = cells.iter().map(|(event_type, state_key)| { self.services .state_accessor - .room_state_get(event.room_id(), event_type, state_key) + .room_state_get(&event.room_id, event_type, state_key) }); join_all(fetches) .await .into_iter() .filter_map(Result::ok) - .map(Event::into_format) - .chain(once(event.to_format())) + .map(PduEvent::into_stripped_state_event) + .chain(once(event.to_stripped_state_event())) .collect() } @@ -356,8 +352,8 @@ impl Service { &self, room_id: &RoomId, shortstatehash: u64, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &RoomMutexGuard, + _mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room + * state mutex */ ) { const BUFSIZE: usize = size_of::(); diff --git a/src/service/rooms/state_accessor/room_state.rs b/src/service/rooms/state_accessor/room_state.rs index 89a66f0c..89fa2a83 100644 --- a/src/service/rooms/state_accessor/room_state.rs +++ b/src/service/rooms/state_accessor/room_state.rs @@ -2,7 +2,7 @@ use std::borrow::Borrow; use conduwuit::{ Result, err, implement, - matrix::{Event, StateKey}, + matrix::{PduEvent, StateKey}, }; use futures::{Stream, StreamExt, TryFutureExt}; use ruma::{EventId, RoomId, events::StateEventType}; @@ -30,7 +30,7 @@ where pub fn room_state_full<'a>( &'a self, room_id: &'a RoomId, -) -> impl Stream> + Send + 'a { +) -> impl Stream> + Send + 'a { self.services .state .get_room_shortstatehash(room_id) @@ -45,7 +45,7 @@ pub fn room_state_full<'a>( pub fn room_state_full_pdus<'a>( &'a self, room_id: &'a RoomId, -) -> impl Stream> + Send + 'a { +) -> impl Stream> + Send + 'a { self.services .state .get_room_shortstatehash(room_id) @@ -84,7 +84,7 @@ pub async fn room_state_get( room_id: &RoomId, event_type: &StateEventType, state_key: &str, -) -> Result { +) -> Result { self.services .state .get_room_shortstatehash(room_id) diff --git a/src/service/rooms/state_accessor/state.rs b/src/service/rooms/state_accessor/state.rs index a46ce380..169e69e9 100644 --- a/src/service/rooms/state_accessor/state.rs +++ b/src/service/rooms/state_accessor/state.rs @@ -2,14 +2,14 @@ use std::{borrow::Borrow, ops::Deref, sync::Arc}; use conduwuit::{ Result, at, err, implement, - matrix::{Event, StateKey}, + matrix::{PduEvent, StateKey}, pair_of, utils::{ result::FlatOk, stream::{BroadbandExt, IterStream, ReadyExt, TryIgnore}, }, }; -use database::Deserialized; +use conduwuit_database::Deserialized; use futures::{FutureExt, Stream, StreamExt, TryFutureExt, future::try_join, pin_mut}; use ruma::{ EventId, OwnedEventId, UserId, @@ -125,9 +125,11 @@ pub async fn state_get( shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, -) -> Result { +) -> Result { self.state_get_id(shortstatehash, event_type, state_key) - .and_then(async |event_id: OwnedEventId| self.services.timeline.get_pdu(&event_id).await) + .and_then(|event_id: OwnedEventId| async move { + self.services.timeline.get_pdu(&event_id).await + }) .await } @@ -314,16 +316,18 @@ pub fn state_added( pub fn state_full( &self, shortstatehash: ShortStateHash, -) -> impl Stream + Send + '_ { +) -> impl Stream + Send + '_ { self.state_full_pdus(shortstatehash) - .ready_filter_map(|pdu| Some(((pdu.kind().clone().into(), pdu.state_key()?.into()), pdu))) + .ready_filter_map(|pdu| { + Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)) + }) } #[implement(super::Service)] pub fn state_full_pdus( &self, shortstatehash: ShortStateHash, -) -> impl Stream + Send + '_ { +) -> impl Stream + Send + '_ { let short_ids = self .state_full_shortids(shortstatehash) .ignore_err() diff --git a/src/service/rooms/state_accessor/user_can.rs b/src/service/rooms/state_accessor/user_can.rs index 221263a8..67e0b52b 100644 --- a/src/service/rooms/state_accessor/user_can.rs +++ b/src/service/rooms/state_accessor/user_can.rs @@ -1,4 +1,4 @@ -use conduwuit::{Err, Result, implement, matrix::Event, pdu::PduBuilder}; +use conduwuit::{Err, Result, implement, pdu::PduBuilder}; use ruma::{ EventId, RoomId, UserId, events::{ @@ -29,14 +29,14 @@ pub async fn user_can_redact( if redacting_event .as_ref() - .is_ok_and(|pdu| *pdu.kind() == TimelineEventType::RoomCreate) + .is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomCreate) { return Err!(Request(Forbidden("Redacting m.room.create is not safe, forbidding."))); } if redacting_event .as_ref() - .is_ok_and(|pdu| *pdu.kind() == TimelineEventType::RoomServerAcl) + .is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomServerAcl) { return Err!(Request(Forbidden( "Redacting m.room.server_acl will result in the room being inaccessible for \ @@ -59,9 +59,9 @@ pub async fn user_can_redact( && match redacting_event { | Ok(redacting_event) => if federation { - redacting_event.sender().server_name() == sender.server_name() + redacting_event.sender.server_name() == sender.server_name() } else { - redacting_event.sender() == sender + redacting_event.sender == sender }, | _ => false, }) @@ -72,10 +72,10 @@ pub async fn user_can_redact( .room_state_get(room_id, &StateEventType::RoomCreate, "") .await { - | Ok(room_create) => Ok(room_create.sender() == sender + | Ok(room_create) => Ok(room_create.sender == sender || redacting_event .as_ref() - .is_ok_and(|redacting_event| redacting_event.sender() == sender)), + .is_ok_and(|redacting_event| redacting_event.sender == sender)), | _ => Err!(Database( "No m.room.power_levels or m.room.create events in database for room" )), diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 9429be79..d3dbc143 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,22 +1,30 @@ -mod update; -mod via; - use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, sync::{Arc, RwLock}, }; use conduwuit::{ - Result, implement, + Result, is_not_empty, result::LogErr, - utils::{ReadyExt, stream::TryIgnore}, + utils::{ReadyExt, StreamTools, stream::TryIgnore}, warn, }; -use database::{Deserialized, Ignore, Interfix, Map}; -use futures::{Stream, StreamExt, future::join5, pin_mut}; +use database::{Deserialized, Ignore, Interfix, Json, Map, serialize_key}; +use futures::{Stream, StreamExt, future::join5, pin_mut, stream::iter}; +use itertools::Itertools; use ruma::{ - OwnedRoomId, RoomId, ServerName, UserId, - events::{AnyStrippedStateEvent, AnySyncStateEvent, room::member::MembershipState}, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, + events::{ + AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, + RoomAccountDataEventType, StateEventType, + direct::DirectEvent, + room::{ + create::RoomCreateEventContent, + member::{MembershipState, RoomMemberEventContent}, + power_levels::RoomPowerLevelsEventContent, + }, + }, + int, serde::Raw, }; @@ -93,443 +101,901 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -#[implement(Service)] -#[tracing::instrument(level = "trace", skip_all)] -pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool { - if let Some(cached) = self - .appservice_in_room_cache - .read() - .expect("locked") - .get(room_id) - .and_then(|map| map.get(&appservice.registration.id)) - .copied() - { - return cached; +impl Service { + /// Update current membership data. + #[tracing::instrument( + level = "debug", + skip_all, + fields( + %room_id, + %user_id, + %sender, + ?membership_event, + ), + )] + #[allow(clippy::too_many_arguments)] + pub async fn update_membership( + &self, + room_id: &RoomId, + user_id: &UserId, + membership_event: RoomMemberEventContent, + sender: &UserId, + last_state: Option>>, + invite_via: Option>, + update_joined_count: bool, + ) -> Result<()> { + let membership = membership_event.membership; + + // Keep track what remote users exist by adding them as "deactivated" users + // + // TODO: use futures to update remote profiles without blocking the membership + // update + #[allow(clippy::collapsible_if)] + if !self.services.globals.user_is_local(user_id) { + if !self.services.users.exists(user_id).await { + self.services.users.create(user_id, None)?; + } + + /* + // Try to update our local copy of the user if ours does not match + if ((self.services.users.displayname(user_id)? != membership_event.displayname) + || (self.services.users.avatar_url(user_id)? != membership_event.avatar_url) + || (self.services.users.blurhash(user_id)? != membership_event.blurhash)) + && (membership != MembershipState::Leave) + { + let response = self.services + .sending + .send_federation_request( + user_id.server_name(), + federation::query::get_profile_information::v1::Request { + user_id: user_id.into(), + field: None, // we want the full user's profile to update locally too + }, + ) + .await; + + self.services.users.set_displayname(user_id, response.displayname.clone()).await?; + self.services.users.set_avatar_url(user_id, response.avatar_url).await?; + self.services.users.set_blurhash(user_id, response.blurhash).await?; + }; + */ + } + + match &membership { + | MembershipState::Join => { + // Check if the user never joined this room + if !self.once_joined(user_id, room_id).await { + // Add the user ID to the join list then + self.mark_as_once_joined(user_id, room_id); + + // Check if the room has a predecessor + if let Ok(Some(predecessor)) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.predecessor) + { + // Copy user settings from predecessor to the current room: + // - Push rules + // + // TODO: finish this once push rules are implemented. + // + // let mut push_rules_event_content: PushRulesEvent = account_data + // .get( + // None, + // user_id, + // EventType::PushRules, + // )?; + // + // NOTE: find where `predecessor.room_id` match + // and update to `room_id`. + // + // account_data + // .update( + // None, + // user_id, + // EventType::PushRules, + // &push_rules_event_content, + // globals, + // ) + // .ok(); + + // Copy old tags to new room + if let Ok(tag_event) = self + .services + .account_data + .get_room( + &predecessor.room_id, + user_id, + RoomAccountDataEventType::Tag, + ) + .await + { + self.services + .account_data + .update( + Some(room_id), + user_id, + RoomAccountDataEventType::Tag, + &tag_event, + ) + .await + .ok(); + } + + // Copy direct chat flag + if let Ok(mut direct_event) = self + .services + .account_data + .get_global::( + user_id, + GlobalAccountDataEventType::Direct, + ) + .await + { + let mut room_ids_updated = false; + for room_ids in direct_event.content.0.values_mut() { + if room_ids.iter().any(|r| r == &predecessor.room_id) { + room_ids.push(room_id.to_owned()); + room_ids_updated = true; + } + } + + if room_ids_updated { + self.services + .account_data + .update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event) + .expect("to json always works"), + ) + .await?; + } + } + } + } + + self.mark_as_joined(user_id, room_id); + }, + | MembershipState::Invite => { + // We want to know if the sender is ignored by the receiver + if self.services.users.user_is_ignored(sender, user_id).await { + return Ok(()); + } + + self.mark_as_invited(user_id, room_id, last_state, invite_via) + .await; + }, + | MembershipState::Leave | MembershipState::Ban => { + self.mark_as_left(user_id, room_id); + + if self.services.globals.user_is_local(user_id) + && (self.services.config.forget_forced_upon_leave + || self.services.metadata.is_banned(room_id).await + || self.services.metadata.is_disabled(room_id).await) + { + self.forget(room_id, user_id); + } + }, + | _ => {}, + } + + if update_joined_count { + self.update_joined_count(room_id).await; + } + + Ok(()) } - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - self.services.globals.server_name(), - ); + #[tracing::instrument(level = "trace", skip_all)] + pub async fn appservice_in_room( + &self, + room_id: &RoomId, + appservice: &RegistrationInfo, + ) -> bool { + if let Some(cached) = self + .appservice_in_room_cache + .read() + .expect("locked") + .get(room_id) + .and_then(|map| map.get(&appservice.registration.id)) + .copied() + { + return cached; + } - let Ok(bridge_user_id) = bridge_user_id.log_err() else { - return false; - }; + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + self.services.globals.server_name(), + ); - let in_room = self.is_joined(&bridge_user_id, room_id).await - || self + let Ok(bridge_user_id) = bridge_user_id.log_err() else { + return false; + }; + + let in_room = self.is_joined(&bridge_user_id, room_id).await + || self + .room_members(room_id) + .ready_any(|user_id| appservice.users.is_match(user_id.as_str())) + .await; + + self.appservice_in_room_cache + .write() + .expect("locked") + .entry(room_id.into()) + .or_default() + .insert(appservice.registration.id.clone(), in_room); + + in_room + } + + /// Direct DB function to directly mark a user as joined. It is not + /// recommended to use this directly. You most likely should use + /// `update_membership` instead + #[tracing::instrument(skip(self), level = "debug")] + pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); + + self.db.userroomid_joined.insert(&userroom_id, []); + self.db.roomuserid_joined.insert(&roomuser_id, []); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } + + /// Direct DB function to directly mark a user as left. It is not + /// recommended to use this directly. You most likely should use + /// `update_membership` instead + #[tracing::instrument(skip(self), level = "debug")] + pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); + + // (timo) TODO + let leftstate = Vec::>::new(); + + self.db + .userroomid_leftstate + .raw_put(&userroom_id, Json(leftstate)); + self.db + .roomuserid_leftcount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } + + /// Direct DB function to directly mark a user as knocked. It is not + /// recommended to use this directly. You most likely should use + /// `update_membership` instead + #[tracing::instrument(skip(self), level = "debug")] + pub fn mark_as_knocked( + &self, + user_id: &UserId, + room_id: &RoomId, + knocked_state: Option>>, + ) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); + + self.db + .userroomid_knockedstate + .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default())); + self.db + .roomuserid_knockedcount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } + + /// Makes a user forget a room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { + let userroom_id = (user_id, room_id); + let roomuser_id = (room_id, user_id); + + self.db.userroomid_leftstate.del(userroom_id); + self.db.roomuserid_leftcount.del(roomuser_id); + } + + /// Returns an iterator of all servers participating in this room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_servers<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomserverids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, server): (Ignore, &ServerName)| server) + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn server_in_room<'a>( + &'a self, + server: &'a ServerName, + room_id: &'a RoomId, + ) -> bool { + let key = (server, room_id); + self.db.serverroomids.qry(&key).await.is_ok() + } + + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). + #[tracing::instrument(skip(self), level = "debug")] + pub fn server_rooms<'a>( + &'a self, + server: &'a ServerName, + ) -> impl Stream + Send + 'a { + let prefix = (server, Interfix); + self.db + .serverroomids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) + } + + /// Returns true if server can see user by sharing at least one room. + #[tracing::instrument(skip(self), level = "trace")] + pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool { + self.server_rooms(server) + .any(|room_id| self.is_joined(user_id, room_id)) + .await + } + + /// Returns true if user_a and user_b share at least one room. + #[tracing::instrument(skip(self), level = "trace")] + pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool { + let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + + pin_mut!(get_shared_rooms); + get_shared_rooms.next().await.is_some() + } + + /// List the rooms common between two users + #[tracing::instrument(skip(self), level = "debug")] + pub fn get_shared_rooms<'a>( + &'a self, + user_a: &'a UserId, + user_b: &'a UserId, + ) -> impl Stream + Send + 'a { + use conduwuit::utils::set; + + let a = self.rooms_joined(user_a); + let b = self.rooms_joined(user_b); + set::intersection_sorted_stream2(a, b) + } + + /// Returns an iterator of all joined members of a room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_members<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_joined + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) + } + + /// Returns the number of users which are currently in a room + #[tracing::instrument(skip(self), level = "trace")] + pub async fn room_joined_count(&self, room_id: &RoomId) -> Result { + self.db.roomid_joinedcount.get(room_id).await.deserialized() + } + + #[tracing::instrument(skip(self), level = "debug")] + /// Returns an iterator of all our local users in the room, even if they're + /// deactivated/guests + pub fn local_users_in_room<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + self.room_members(room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + } + + /// Returns an iterator of all our local joined users in a room who are + /// active (not deactivated, not guest) + #[tracing::instrument(skip(self), level = "trace")] + pub fn active_local_users_in_room<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + self.local_users_in_room(room_id) + .filter(|user| self.services.users.is_active(user)) + } + + /// Returns the number of users which are currently invited to a room + #[tracing::instrument(skip(self), level = "trace")] + pub async fn room_invited_count(&self, room_id: &RoomId) -> Result { + self.db + .roomid_invitedcount + .get(room_id) + .await + .deserialized() + } + + /// Returns an iterator over all User IDs who ever joined a room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_useroncejoined<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuseroncejoinedids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) + } + + /// Returns an iterator over all invited members of a room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_members_invited<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_invitecount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) + } + + /// Returns an iterator over all knocked members of a room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_members_knocked<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_knockedcount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_invitecount + .qry(&key) + .await + .deserialized() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_knockedcount + .qry(&key) + .await + .deserialized() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db.roomuserid_leftcount.qry(&key).await.deserialized() + } + + /// Returns an iterator over all rooms this user joined. + #[tracing::instrument(skip(self), level = "debug")] + pub fn rooms_joined<'a>( + &'a self, + user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + self.db + .userroomid_joined + .keys_raw_prefix(user_id) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) + } + + /// Returns an iterator over all rooms a user was invited to. + #[tracing::instrument(skip(self), level = "debug")] + pub fn rooms_invited<'a>( + &'a self, + user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_invitestate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() + } + + /// Returns an iterator over all rooms a user is currently knocking. + #[tracing::instrument(skip(self), level = "trace")] + pub fn rooms_knocked<'a>( + &'a self, + user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_knockedstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn invite_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>> { + let key = (user_id, room_id); + self.db + .userroomid_invitestate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| { + val.deserialize_as().map_err(Into::into) + }) + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>> { + let key = (user_id, room_id); + self.db + .userroomid_knockedstate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| { + val.deserialize_as().map_err(Into::into) + }) + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn left_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>> { + let key = (user_id, room_id); + self.db + .userroomid_leftstate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| { + val.deserialize_as().map_err(Into::into) + }) + } + + /// Returns an iterator over all rooms a user left. + #[tracing::instrument(skip(self), level = "debug")] + pub fn rooms_left<'a>( + &'a self, + user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_leftstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() + } + + #[tracing::instrument(skip(self), level = "debug")] + pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.qry(&key).await.is_ok() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_joined.qry(&key).await.is_ok() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_knockedstate.qry(&key).await.is_ok() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_invitestate.qry(&key).await.is_ok() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_leftstate.qry(&key).await.is_ok() + } + + #[tracing::instrument(skip(self), level = "trace")] + pub async fn user_membership( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Option { + let states = join5( + self.is_joined(user_id, room_id), + self.is_left(user_id, room_id), + self.is_knocked(user_id, room_id), + self.is_invited(user_id, room_id), + self.once_joined(user_id, room_id), + ) + .await; + + match states { + | (true, ..) => Some(MembershipState::Join), + | (_, true, ..) => Some(MembershipState::Leave), + | (_, _, true, ..) => Some(MembershipState::Knock), + | (_, _, _, true, ..) => Some(MembershipState::Invite), + | (false, false, false, false, true) => Some(MembershipState::Ban), + | _ => None, + } + } + + #[tracing::instrument(skip(self), level = "debug")] + pub fn servers_invite_via<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Ignore, Vec<&'a ServerName>); + + self.db + .roomid_inviteviaservers + .stream_raw_prefix(room_id) + .ignore_err() + .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) + } + + /// Gets up to five servers that are likely to be in the room in the + /// distant future. + /// + /// See + #[tracing::instrument(skip(self), level = "trace")] + pub async fn servers_route_via(&self, room_id: &RoomId) -> Result> { + let most_powerful_user_server = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") + .await + .map(|content: RoomPowerLevelsEventContent| { + content + .users + .iter() + .max_by_key(|(_, power)| *power) + .and_then(|x| (x.1 >= &int!(50)).then_some(x)) + .map(|(user, _power)| user.server_name().to_owned()) + }); + + let mut servers: Vec = self .room_members(room_id) - .ready_any(|user_id| appservice.users.is_match(user_id.as_str())) + .counts_by(|user| user.server_name().to_owned()) + .await + .into_iter() + .sorted_by_key(|(_, users)| *users) + .map(|(server, _)| server) + .rev() + .take(5) + .collect(); + + if let Ok(Some(server)) = most_powerful_user_server { + servers.insert(0, server); + servers.truncate(5); + } + + Ok(servers) + } + + pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) { + let cache = self.appservice_in_room_cache.read().expect("locked"); + + (cache.len(), cache.capacity()) + } + + #[tracing::instrument(level = "debug", skip_all)] + pub fn clear_appservice_in_room_cache(&self) { + self.appservice_in_room_cache + .write() + .expect("locked") + .clear(); + } + + #[tracing::instrument(level = "debug", skip(self))] + pub async fn update_joined_count(&self, room_id: &RoomId) { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut knockedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + self.room_members(room_id) + .ready_for_each(|joined| { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount = joinedcount.saturating_add(1); + }) .await; - self.appservice_in_room_cache - .write() - .expect("locked") - .entry(room_id.into()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); + invitedcount = invitedcount.saturating_add( + self.room_members_invited(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); - in_room -} + knockedcount = knockedcount.saturating_add( + self.room_members_knocked(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); -#[implement(Service)] -pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) { - let cache = self.appservice_in_room_cache.read().expect("locked"); + self.db.roomid_joinedcount.raw_put(room_id, joinedcount); + self.db.roomid_invitedcount.raw_put(room_id, invitedcount); + self.db + .roomuserid_knockedcount + .raw_put(room_id, knockedcount); - (cache.len(), cache.capacity()) -} + self.room_servers(room_id) + .ready_for_each(|old_joined_server| { + if joined_servers.remove(old_joined_server) { + return; + } -#[implement(Service)] -#[tracing::instrument(level = "debug", skip_all)] -pub fn clear_appservice_in_room_cache(&self) { - self.appservice_in_room_cache - .write() - .expect("locked") - .clear(); -} + // Server not in room anymore + let roomserver_id = (room_id, old_joined_server); + let serverroom_id = (old_joined_server, room_id); -/// Returns an iterator of all servers participating in this room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn room_servers<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - let prefix = (room_id, Interfix); - self.db - .roomserverids - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, server): (Ignore, &ServerName)| server) -} + self.db.roomserverids.del(roomserver_id); + self.db.serverroomids.del(serverroom_id); + }) + .await; -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool { - let key = (server, room_id); - self.db.serverroomids.qry(&key).await.is_ok() -} + // Now only new servers are in joined_servers anymore + for server in &joined_servers { + let roomserver_id = (room_id, server); + let serverroom_id = (server, room_id); -/// Returns an iterator of all rooms a server participates in (as far as we -/// know). -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn server_rooms<'a>( - &'a self, - server: &'a ServerName, -) -> impl Stream + Send + 'a { - let prefix = (server, Interfix); - self.db - .serverroomids - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, room_id): (Ignore, &RoomId)| room_id) -} + self.db.roomserverids.put_raw(roomserver_id, []); + self.db.serverroomids.put_raw(serverroom_id, []); + } -/// Returns true if server can see user by sharing at least one room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool { - self.server_rooms(server) - .any(|room_id| self.is_joined(user_id, room_id)) - .await -} + self.appservice_in_room_cache + .write() + .expect("locked") + .remove(room_id); + } -/// Returns true if user_a and user_b share at least one room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool { - let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + #[tracing::instrument(level = "debug", skip(self))] + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.put_raw(key, []); + } - pin_mut!(get_shared_rooms); - get_shared_rooms.next().await.is_some() -} + #[tracing::instrument(level = "debug", skip(self, last_state, invite_via))] + pub async fn mark_as_invited( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option>>, + invite_via: Option>, + ) { + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); -/// List the rooms common between two users -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn get_shared_rooms<'a>( - &'a self, - user_a: &'a UserId, - user_b: &'a UserId, -) -> impl Stream + Send + 'a { - use conduwuit::utils::set; + let userroom_id = (user_id, room_id); + let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); - let a = self.rooms_joined(user_a); - let b = self.rooms_joined(user_b); - set::intersection_sorted_stream2(a, b) -} + self.db + .userroomid_invitestate + .raw_put(&userroom_id, Json(last_state.unwrap_or_default())); + self.db + .roomuserid_invitecount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); -/// Returns an iterator of all joined members of a room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn room_members<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - let prefix = (room_id, Interfix); - self.db - .roomuserid_joined - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, user_id): (Ignore, &UserId)| user_id) -} + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); -/// Returns the number of users which are currently in a room -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn room_joined_count(&self, room_id: &RoomId) -> Result { - self.db.roomid_joinedcount.get(room_id).await.deserialized() -} + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -/// Returns an iterator of all our local users in the room, even if they're -/// deactivated/guests -pub fn local_users_in_room<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - self.room_members(room_id) - .ready_filter(|user| self.services.globals.user_is_local(user)) -} + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); -/// Returns an iterator of all our local joined users in a room who are -/// active (not deactivated, not guest) -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub fn active_local_users_in_room<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - self.local_users_in_room(room_id) - .filter(|user| self.services.users.is_active(user)) -} + if let Some(servers) = invite_via.filter(is_not_empty!()) { + self.add_servers_invite_via(room_id, servers).await; + } + } -/// Returns the number of users which are currently invited to a room -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn room_invited_count(&self, room_id: &RoomId) -> Result { - self.db - .roomid_invitedcount - .get(room_id) - .await - .deserialized() -} + #[tracing::instrument(level = "debug", skip(self, servers))] + pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec) { + let mut servers: Vec<_> = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .chain(iter(servers.into_iter())) + .collect() + .await; -/// Returns an iterator over all User IDs who ever joined a room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn room_useroncejoined<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - let prefix = (room_id, Interfix); - self.db - .roomuseroncejoinedids - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, user_id): (Ignore, &UserId)| user_id) -} + servers.sort_unstable(); + servers.dedup(); -/// Returns an iterator over all invited members of a room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn room_members_invited<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - let prefix = (room_id, Interfix); - self.db - .roomuserid_invitecount - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, user_id): (Ignore, &UserId)| user_id) -} + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); -/// Returns an iterator over all knocked members of a room. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn room_members_knocked<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - let prefix = (room_id, Interfix); - self.db - .roomuserid_knockedcount - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, user_id): (Ignore, &UserId)| user_id) -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { - let key = (room_id, user_id); - self.db - .roomuserid_invitecount - .qry(&key) - .await - .deserialized() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { - let key = (room_id, user_id); - self.db - .roomuserid_knockedcount - .qry(&key) - .await - .deserialized() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { - let key = (room_id, user_id); - self.db.roomuserid_leftcount.qry(&key).await.deserialized() -} - -/// Returns an iterator over all rooms this user joined. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn rooms_joined<'a>( - &'a self, - user_id: &'a UserId, -) -> impl Stream + Send + 'a { - self.db - .userroomid_joined - .keys_raw_prefix(user_id) - .ignore_err() - .map(|(_, room_id): (Ignore, &RoomId)| room_id) -} - -/// Returns an iterator over all rooms a user was invited to. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn rooms_invited<'a>( - &'a self, - user_id: &'a UserId, -) -> impl Stream + Send + 'a { - type KeyVal<'a> = (Key<'a>, Raw>); - type Key<'a> = (&'a UserId, &'a RoomId); - - let prefix = (user_id, Interfix); - self.db - .userroomid_invitestate - .stream_prefix(&prefix) - .ignore_err() - .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) - .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) - .ignore_err() -} - -/// Returns an iterator over all rooms a user is currently knocking. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub fn rooms_knocked<'a>( - &'a self, - user_id: &'a UserId, -) -> impl Stream + Send + 'a { - type KeyVal<'a> = (Key<'a>, Raw>); - type Key<'a> = (&'a UserId, &'a RoomId); - - let prefix = (user_id, Interfix); - self.db - .userroomid_knockedstate - .stream_prefix(&prefix) - .ignore_err() - .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) - .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) - .ignore_err() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn invite_state( - &self, - user_id: &UserId, - room_id: &RoomId, -) -> Result>> { - let key = (user_id, room_id); - self.db - .userroomid_invitestate - .qry(&key) - .await - .deserialized() - .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn knock_state( - &self, - user_id: &UserId, - room_id: &RoomId, -) -> Result>> { - let key = (user_id, room_id); - self.db - .userroomid_knockedstate - .qry(&key) - .await - .deserialized() - .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn left_state( - &self, - user_id: &UserId, - room_id: &RoomId, -) -> Result>> { - let key = (user_id, room_id); - self.db - .userroomid_leftstate - .qry(&key) - .await - .deserialized() - .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) -} - -/// Returns an iterator over all rooms a user left. -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn rooms_left<'a>( - &'a self, - user_id: &'a UserId, -) -> impl Stream + Send + 'a { - type KeyVal<'a> = (Key<'a>, Raw>>); - type Key<'a> = (&'a UserId, &'a RoomId); - - let prefix = (user_id, Interfix); - self.db - .userroomid_leftstate - .stream_prefix(&prefix) - .ignore_err() - .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) - .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) - .ignore_err() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn user_membership( - &self, - user_id: &UserId, - room_id: &RoomId, -) -> Option { - let states = join5( - self.is_joined(user_id, room_id), - self.is_left(user_id, room_id), - self.is_knocked(user_id, room_id), - self.is_invited(user_id, room_id), - self.once_joined(user_id, room_id), - ) - .await; - - match states { - | (true, ..) => Some(MembershipState::Join), - | (_, true, ..) => Some(MembershipState::Leave), - | (_, _, true, ..) => Some(MembershipState::Knock), - | (_, _, _, true, ..) => Some(MembershipState::Invite), - | (false, false, false, false, true) => Some(MembershipState::Ban), - | _ => None, + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); } } - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool { - let key = (user_id, room_id); - self.db.roomuseroncejoinedids.qry(&key).await.is_ok() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { - let key = (user_id, room_id); - self.db.userroomid_joined.qry(&key).await.is_ok() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { - let key = (user_id, room_id); - self.db.userroomid_knockedstate.qry(&key).await.is_ok() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { - let key = (user_id, room_id); - self.db.userroomid_invitestate.qry(&key).await.is_ok() -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool { - let key = (user_id, room_id); - self.db.userroomid_leftstate.qry(&key).await.is_ok() -} diff --git a/src/service/rooms/state_cache/update.rs b/src/service/rooms/state_cache/update.rs deleted file mode 100644 index 02c6bec6..00000000 --- a/src/service/rooms/state_cache/update.rs +++ /dev/null @@ -1,369 +0,0 @@ -use std::collections::HashSet; - -use conduwuit::{Result, implement, is_not_empty, utils::ReadyExt, warn}; -use database::{Json, serialize_key}; -use futures::StreamExt; -use ruma::{ - OwnedServerName, RoomId, UserId, - events::{ - AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, - RoomAccountDataEventType, StateEventType, - direct::DirectEvent, - room::{ - create::RoomCreateEventContent, - member::{MembershipState, RoomMemberEventContent}, - }, - }, - serde::Raw, -}; - -/// Update current membership data. -#[implement(super::Service)] -#[tracing::instrument( - level = "debug", - skip_all, - fields( - %room_id, - %user_id, - %sender, - ?membership_event, - ), - )] -#[allow(clippy::too_many_arguments)] -pub async fn update_membership( - &self, - room_id: &RoomId, - user_id: &UserId, - membership_event: RoomMemberEventContent, - sender: &UserId, - last_state: Option>>, - invite_via: Option>, - update_joined_count: bool, -) -> Result { - let membership = membership_event.membership; - - // Keep track what remote users exist by adding them as "deactivated" users - // - // TODO: use futures to update remote profiles without blocking the membership - // update - #[allow(clippy::collapsible_if)] - if !self.services.globals.user_is_local(user_id) { - if !self.services.users.exists(user_id).await { - self.services.users.create(user_id, None)?; - } - } - - match &membership { - | MembershipState::Join => { - // Check if the user never joined this room - if !self.once_joined(user_id, room_id).await { - // Add the user ID to the join list then - self.mark_as_once_joined(user_id, room_id); - - // Check if the room has a predecessor - if let Ok(Some(predecessor)) = self - .services - .state_accessor - .room_state_get_content(room_id, &StateEventType::RoomCreate, "") - .await - .map(|content: RoomCreateEventContent| content.predecessor) - { - // Copy old tags to new room - if let Ok(tag_event) = self - .services - .account_data - .get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag) - .await - { - self.services - .account_data - .update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &tag_event, - ) - .await - .ok(); - } - - // Copy direct chat flag - if let Ok(mut direct_event) = self - .services - .account_data - .get_global::(user_id, GlobalAccountDataEventType::Direct) - .await - { - let mut room_ids_updated = false; - for room_ids in direct_event.content.0.values_mut() { - if room_ids.iter().any(|r| r == &predecessor.room_id) { - room_ids.push(room_id.to_owned()); - room_ids_updated = true; - } - } - - if room_ids_updated { - self.services - .account_data - .update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event) - .expect("to json always works"), - ) - .await?; - } - } - } - } - - self.mark_as_joined(user_id, room_id); - }, - | MembershipState::Invite => { - // We want to know if the sender is ignored by the receiver - if self.services.users.user_is_ignored(sender, user_id).await { - return Ok(()); - } - - self.mark_as_invited(user_id, room_id, last_state, invite_via) - .await; - }, - | MembershipState::Leave | MembershipState::Ban => { - self.mark_as_left(user_id, room_id); - - if self.services.globals.user_is_local(user_id) - && (self.services.config.forget_forced_upon_leave - || self.services.metadata.is_banned(room_id).await - || self.services.metadata.is_disabled(room_id).await) - { - self.forget(room_id, user_id); - } - }, - | _ => {}, - } - - if update_joined_count { - self.update_joined_count(room_id).await; - } - - Ok(()) -} - -#[implement(super::Service)] -#[tracing::instrument(level = "debug", skip(self))] -pub async fn update_joined_count(&self, room_id: &RoomId) { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut knockedcount = 0_u64; - let mut joined_servers = HashSet::new(); - - self.room_members(room_id) - .ready_for_each(|joined| { - joined_servers.insert(joined.server_name().to_owned()); - joinedcount = joinedcount.saturating_add(1); - }) - .await; - - invitedcount = invitedcount.saturating_add( - self.room_members_invited(room_id) - .count() - .await - .try_into() - .unwrap_or(0), - ); - - knockedcount = knockedcount.saturating_add( - self.room_members_knocked(room_id) - .count() - .await - .try_into() - .unwrap_or(0), - ); - - self.db.roomid_joinedcount.raw_put(room_id, joinedcount); - self.db.roomid_invitedcount.raw_put(room_id, invitedcount); - self.db - .roomuserid_knockedcount - .raw_put(room_id, knockedcount); - - self.room_servers(room_id) - .ready_for_each(|old_joined_server| { - if joined_servers.remove(old_joined_server) { - return; - } - - // Server not in room anymore - let roomserver_id = (room_id, old_joined_server); - let serverroom_id = (old_joined_server, room_id); - - self.db.roomserverids.del(roomserver_id); - self.db.serverroomids.del(serverroom_id); - }) - .await; - - // Now only new servers are in joined_servers anymore - for server in &joined_servers { - let roomserver_id = (room_id, server); - let serverroom_id = (server, room_id); - - self.db.roomserverids.put_raw(roomserver_id, []); - self.db.serverroomids.put_raw(serverroom_id, []); - } - - self.appservice_in_room_cache - .write() - .expect("locked") - .remove(room_id); -} - -/// Direct DB function to directly mark a user as joined. It is not -/// recommended to use this directly. You most likely should use -/// `update_membership` instead -#[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { - let userroom_id = (user_id, room_id); - let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); - - let roomuser_id = (room_id, user_id); - let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); - - self.db.userroomid_joined.insert(&userroom_id, []); - self.db.roomuserid_joined.insert(&roomuser_id, []); - - self.db.userroomid_invitestate.remove(&userroom_id); - self.db.roomuserid_invitecount.remove(&roomuser_id); - - self.db.userroomid_leftstate.remove(&userroom_id); - self.db.roomuserid_leftcount.remove(&roomuser_id); - - self.db.userroomid_knockedstate.remove(&userroom_id); - self.db.roomuserid_knockedcount.remove(&roomuser_id); - - self.db.roomid_inviteviaservers.remove(room_id); -} - -/// Direct DB function to directly mark a user as left. It is not -/// recommended to use this directly. You most likely should use -/// `update_membership` instead -#[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { - let userroom_id = (user_id, room_id); - let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); - - let roomuser_id = (room_id, user_id); - let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); - - // (timo) TODO - let leftstate = Vec::>::new(); - - self.db - .userroomid_leftstate - .raw_put(&userroom_id, Json(leftstate)); - self.db - .roomuserid_leftcount - .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); - - self.db.userroomid_joined.remove(&userroom_id); - self.db.roomuserid_joined.remove(&roomuser_id); - - self.db.userroomid_invitestate.remove(&userroom_id); - self.db.roomuserid_invitecount.remove(&roomuser_id); - - self.db.userroomid_knockedstate.remove(&userroom_id); - self.db.roomuserid_knockedcount.remove(&roomuser_id); - - self.db.roomid_inviteviaservers.remove(room_id); -} - -/// Direct DB function to directly mark a user as knocked. It is not -/// recommended to use this directly. You most likely should use -/// `update_membership` instead -#[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn mark_as_knocked( - &self, - user_id: &UserId, - room_id: &RoomId, - knocked_state: Option>>, -) { - let userroom_id = (user_id, room_id); - let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); - - let roomuser_id = (room_id, user_id); - let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); - - self.db - .userroomid_knockedstate - .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default())); - self.db - .roomuserid_knockedcount - .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); - - self.db.userroomid_joined.remove(&userroom_id); - self.db.roomuserid_joined.remove(&roomuser_id); - - self.db.userroomid_invitestate.remove(&userroom_id); - self.db.roomuserid_invitecount.remove(&roomuser_id); - - self.db.userroomid_leftstate.remove(&userroom_id); - self.db.roomuserid_leftcount.remove(&roomuser_id); - - self.db.roomid_inviteviaservers.remove(room_id); -} - -/// Makes a user forget a room. -#[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { - let userroom_id = (user_id, room_id); - let roomuser_id = (room_id, user_id); - - self.db.userroomid_leftstate.del(userroom_id); - self.db.roomuserid_leftcount.del(roomuser_id); -} - -#[implement(super::Service)] -#[tracing::instrument(level = "debug", skip(self))] -fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { - let key = (user_id, room_id); - self.db.roomuseroncejoinedids.put_raw(key, []); -} - -#[implement(super::Service)] -#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))] -pub async fn mark_as_invited( - &self, - user_id: &UserId, - room_id: &RoomId, - last_state: Option>>, - invite_via: Option>, -) { - let roomuser_id = (room_id, user_id); - let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); - - let userroom_id = (user_id, room_id); - let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); - - self.db - .userroomid_invitestate - .raw_put(&userroom_id, Json(last_state.unwrap_or_default())); - self.db - .roomuserid_invitecount - .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); - - self.db.userroomid_joined.remove(&userroom_id); - self.db.roomuserid_joined.remove(&roomuser_id); - - self.db.userroomid_leftstate.remove(&userroom_id); - self.db.roomuserid_leftcount.remove(&roomuser_id); - - self.db.userroomid_knockedstate.remove(&userroom_id); - self.db.roomuserid_knockedcount.remove(&roomuser_id); - - if let Some(servers) = invite_via.filter(is_not_empty!()) { - self.add_servers_invite_via(room_id, servers).await; - } -} diff --git a/src/service/rooms/state_cache/via.rs b/src/service/rooms/state_cache/via.rs deleted file mode 100644 index a818cc04..00000000 --- a/src/service/rooms/state_cache/via.rs +++ /dev/null @@ -1,92 +0,0 @@ -use conduwuit::{ - Result, implement, - utils::{StreamTools, stream::TryIgnore}, - warn, -}; -use database::Ignore; -use futures::{Stream, StreamExt, stream::iter}; -use itertools::Itertools; -use ruma::{ - OwnedServerName, RoomId, ServerName, - events::{StateEventType, room::power_levels::RoomPowerLevelsEventContent}, - int, -}; - -#[implement(super::Service)] -#[tracing::instrument(level = "debug", skip(self, servers))] -pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec) { - let mut servers: Vec<_> = self - .servers_invite_via(room_id) - .map(ToOwned::to_owned) - .chain(iter(servers.into_iter())) - .collect() - .await; - - servers.sort_unstable(); - servers.dedup(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.db - .roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers); -} - -/// Gets up to five servers that are likely to be in the room in the -/// distant future. -/// -/// See -#[implement(super::Service)] -#[tracing::instrument(skip(self), level = "trace")] -pub async fn servers_route_via(&self, room_id: &RoomId) -> Result> { - let most_powerful_user_server = self - .services - .state_accessor - .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") - .await - .map(|content: RoomPowerLevelsEventContent| { - content - .users - .iter() - .max_by_key(|(_, power)| *power) - .and_then(|x| (x.1 >= &int!(50)).then_some(x)) - .map(|(user, _power)| user.server_name().to_owned()) - }); - - let mut servers: Vec = self - .room_members(room_id) - .counts_by(|user| user.server_name().to_owned()) - .await - .into_iter() - .sorted_by_key(|(_, users)| *users) - .map(|(server, _)| server) - .rev() - .take(5) - .collect(); - - if let Ok(Some(server)) = most_powerful_user_server { - servers.insert(0, server); - servers.truncate(5); - } - - Ok(servers) -} - -#[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn servers_invite_via<'a>( - &'a self, - room_id: &'a RoomId, -) -> impl Stream + Send + 'a { - type KeyVal<'a> = (Ignore, Vec<&'a ServerName>); - - self.db - .roomid_inviteviaservers - .stream_raw_prefix(room_id) - .ignore_err() - .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) -} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index a33fb342..56a91d0e 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use conduwuit::{ Result, arrayvec::ArrayVec, - at, checked, err, expected, implement, utils, + at, checked, err, expected, utils, utils::{bytes, math::usize_from_f64, stream::IterStream}, }; use database::Map; @@ -115,30 +115,29 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -/// Returns a stack with info on shortstatehash, full state, added diff and -/// removed diff for the selected shortstatehash and each parent layer. -#[implement(Service)] -#[tracing::instrument(name = "load", level = "debug", skip(self))] -pub async fn load_shortstatehash_info( - &self, - shortstatehash: ShortStateHash, -) -> Result { - if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) { - return Ok(r.clone()); +impl Service { + /// Returns a stack with info on shortstatehash, full state, added diff and + /// removed diff for the selected shortstatehash and each parent layer. + #[tracing::instrument(name = "load", level = "debug", skip(self))] + pub async fn load_shortstatehash_info( + &self, + shortstatehash: ShortStateHash, + ) -> Result { + if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) { + return Ok(r.clone()); + } + + let stack = self.new_shortstatehash_info(shortstatehash).await?; + + self.cache_shortstatehash_info(shortstatehash, stack.clone()) + .await?; + + Ok(stack) } - let stack = self.new_shortstatehash_info(shortstatehash).await?; - - self.cache_shortstatehash_info(shortstatehash, stack.clone()) - .await?; - - Ok(stack) -} - -/// Returns a stack with info on shortstatehash, full state, added diff and -/// removed diff for the selected shortstatehash and each parent layer. -#[implement(Service)] -#[tracing::instrument( + /// Returns a stack with info on shortstatehash, full state, added diff and + /// removed diff for the selected shortstatehash and each parent layer. + #[tracing::instrument( name = "cache", level = "debug", skip_all, @@ -147,365 +146,362 @@ pub async fn load_shortstatehash_info( stack = stack.len(), ), )] -async fn cache_shortstatehash_info( - &self, - shortstatehash: ShortStateHash, - stack: ShortStateInfoVec, -) -> Result { - self.stateinfo_cache.lock()?.insert(shortstatehash, stack); + async fn cache_shortstatehash_info( + &self, + shortstatehash: ShortStateHash, + stack: ShortStateInfoVec, + ) -> Result { + self.stateinfo_cache.lock()?.insert(shortstatehash, stack); - Ok(()) -} + Ok(()) + } -#[implement(Service)] -async fn new_shortstatehash_info( - &self, - shortstatehash: ShortStateHash, -) -> Result { - let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?; + async fn new_shortstatehash_info( + &self, + shortstatehash: ShortStateHash, + ) -> Result { + let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?; - let Some(parent) = parent else { - return Ok(vec![ShortStateInfo { + let Some(parent) = parent else { + return Ok(vec![ShortStateInfo { + shortstatehash, + full_state: added.clone(), + added, + removed, + }]); + }; + + let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?; + let top = stack.last().expect("at least one frame"); + + let mut full_state = (*top.full_state).clone(); + full_state.extend(added.iter().copied()); + + let removed = (*removed).clone(); + for r in &removed { + full_state.remove(r); + } + + stack.push(ShortStateInfo { shortstatehash, - full_state: added.clone(), added, - removed, - }]); - }; + removed: Arc::new(removed), + full_state: Arc::new(full_state), + }); - let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?; - let top = stack.last().expect("at least one frame"); - - let mut full_state = (*top.full_state).clone(); - full_state.extend(added.iter().copied()); - - let removed = (*removed).clone(); - for r in &removed { - full_state.remove(r); + Ok(stack) } - stack.push(ShortStateInfo { - shortstatehash, - added, - removed: Arc::new(removed), - full_state: Arc::new(full_state), - }); + pub fn compress_state_events<'a, I>( + &'a self, + state: I, + ) -> impl Stream + Send + 'a + where + I: Iterator + Clone + Debug + Send + 'a, + { + let event_ids = state.clone().map(at!(1)); - Ok(stack) -} + let short_event_ids = self + .services + .short + .multi_get_or_create_shorteventid(event_ids); -#[implement(Service)] -pub fn compress_state_events<'a, I>( - &'a self, - state: I, -) -> impl Stream + Send + 'a -where - I: Iterator + Clone + Debug + Send + 'a, -{ - let event_ids = state.clone().map(at!(1)); + state + .stream() + .map(at!(0)) + .zip(short_event_ids) + .map(|(shortstatekey, shorteventid)| { + compress_state_event(*shortstatekey, shorteventid) + }) + } - let short_event_ids = self - .services - .short - .multi_get_or_create_shorteventid(event_ids); + pub async fn compress_state_event( + &self, + shortstatekey: ShortStateKey, + event_id: &EventId, + ) -> CompressedStateEvent { + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - state - .stream() - .map(at!(0)) - .zip(short_event_ids) - .map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid)) -} + compress_state_event(shortstatekey, shorteventid) + } -#[implement(Service)] -pub async fn compress_state_event( - &self, - shortstatekey: ShortStateKey, - event_id: &EventId, -) -> CompressedStateEvent { - let shorteventid = self - .services - .short - .get_or_create_shorteventid(event_id) - .await; + /// Creates a new shortstatehash that often is just a diff to an already + /// existing shortstatehash and therefore very efficient. + /// + /// There are multiple layers of diffs. The bottom layer 0 always contains + /// the full state. Layer 1 contains diffs to states of layer 0, layer 2 + /// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be + /// combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively + /// fix above layers too. + /// + /// * `shortstatehash` - Shortstatehash of this state + /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid + /// * `statediffremoved` - Removed from base. Each vec is + /// shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time + /// for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, + /// added diff and removed diff for each parent layer + pub fn save_state_from_diff( + &self, + shortstatehash: ShortStateHash, + statediffnew: Arc, + statediffremoved: Arc, + diff_to_sibling: usize, + mut parent_states: ParentStatesVec, + ) -> Result { + let statediffnew_len = statediffnew.len(); + let statediffremoved_len = statediffremoved.len(); + let diffsum = checked!(statediffnew_len + statediffremoved_len)?; - compress_state_event(shortstatekey, shorteventid) -} + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().expect("parent must have a state"); -/// Creates a new shortstatehash that often is just a diff to an already -/// existing shortstatehash and therefore very efficient. -/// -/// There are multiple layers of diffs. The bottom layer 0 always contains -/// the full state. Layer 1 contains diffs to states of layer 0, layer 2 -/// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be -/// combined with layer n-1 to create a new diff on layer n-1 that's -/// based on layer n-2. If that layer is also too big, it will recursively -/// fix above layers too. -/// -/// * `shortstatehash` - Shortstatehash of this state -/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid -/// * `statediffremoved` - Removed from base. Each vec is -/// shortstatekey+shorteventid -/// * `diff_to_sibling` - Approximately how much the diff grows each time for -/// this layer -/// * `parent_states` - A stack with info on shortstatehash, full state, added -/// diff and removed diff for each parent layer -#[implement(Service)] -pub fn save_state_from_diff( - &self, - shortstatehash: ShortStateHash, - statediffnew: Arc, - statediffremoved: Arc, - diff_to_sibling: usize, - mut parent_states: ParentStatesVec, -) -> Result { - let statediffnew_len = statediffnew.len(); - let statediffremoved_len = statediffremoved.len(); - let diffsum = checked!(statediffnew_len + statediffremoved_len)?; + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); + + for removed in statediffremoved.iter() { + if !parent_new.remove(removed) { + // It was not added in the parent and we removed it + parent_removed.insert(*removed); + } + // Else it was added in the parent and we removed it again. We + // can forget this change + } + + for new in statediffnew.iter() { + if !parent_removed.remove(new) { + // It was not touched in the parent and we added it + parent_new.insert(*new); + } + // Else it was removed in the parent and we added it again. We + // can forget this change + } + + self.save_state_from_diff( + shortstatehash, + Arc::new(parent_new), + Arc::new(parent_removed), + diffsum, + parent_states, + )?; + + return Ok(()); + } + + if parent_states.is_empty() { + // There is no parent layer, create a new state + self.save_statediff(shortstatehash, &StateDiff { + parent: None, + added: statediffnew, + removed: statediffremoved, + }); + + return Ok(()); + } + + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above - if parent_states.len() > 3 { - // Number of layers - // To many layers, we have to go deeper let parent = parent_states.pop().expect("parent must have a state"); + let parent_added_len = parent.added.len(); + let parent_removed_len = parent.removed.len(); + let parent_diff = checked!(parent_added_len + parent_removed_len)?; - let mut parent_new = (*parent.added).clone(); - let mut parent_removed = (*parent.removed).clone(); + if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { + // Diff too big, we replace above layer(s) + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); - for removed in statediffremoved.iter() { - if !parent_new.remove(removed) { - // It was not added in the parent and we removed it - parent_removed.insert(*removed); + for removed in statediffremoved.iter() { + if !parent_new.remove(removed) { + // It was not added in the parent and we removed it + parent_removed.insert(*removed); + } + // Else it was added in the parent and we removed it again. We + // can forget this change } - // Else it was added in the parent and we removed it again. We - // can forget this change - } - for new in statediffnew.iter() { - if !parent_removed.remove(new) { - // It was not touched in the parent and we added it - parent_new.insert(*new); + for new in statediffnew.iter() { + if !parent_removed.remove(new) { + // It was not touched in the parent and we added it + parent_new.insert(*new); + } + // Else it was removed in the parent and we added it again. We + // can forget this change } - // Else it was removed in the parent and we added it again. We - // can forget this change - } - self.save_state_from_diff( - shortstatehash, - Arc::new(parent_new), - Arc::new(parent_removed), - diffsum, - parent_states, - )?; - - return Ok(()); - } - - if parent_states.is_empty() { - // There is no parent layer, create a new state - self.save_statediff(shortstatehash, &StateDiff { - parent: None, - added: statediffnew, - removed: statediffremoved, - }); - - return Ok(()); - } - - // Else we have two options. - // 1. We add the current diff on top of the parent layer. - // 2. We replace a layer above - - let parent = parent_states.pop().expect("parent must have a state"); - let parent_added_len = parent.added.len(); - let parent_removed_len = parent.removed.len(); - let parent_diff = checked!(parent_added_len + parent_removed_len)?; - - if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { - // Diff too big, we replace above layer(s) - let mut parent_new = (*parent.added).clone(); - let mut parent_removed = (*parent.removed).clone(); - - for removed in statediffremoved.iter() { - if !parent_new.remove(removed) { - // It was not added in the parent and we removed it - parent_removed.insert(*removed); - } - // Else it was added in the parent and we removed it again. We - // can forget this change - } - - for new in statediffnew.iter() { - if !parent_removed.remove(new) { - // It was not touched in the parent and we added it - parent_new.insert(*new); - } - // Else it was removed in the parent and we added it again. We - // can forget this change - } - - self.save_state_from_diff( - shortstatehash, - Arc::new(parent_new), - Arc::new(parent_removed), - diffsum, - parent_states, - )?; - } else { - // Diff small enough, we add diff as layer on top of parent - self.save_statediff(shortstatehash, &StateDiff { - parent: Some(parent.shortstatehash), - added: statediffnew, - removed: statediffremoved, - }); - } - - Ok(()) -} - -/// Returns the new shortstatehash, and the state diff from the previous -/// room state -#[implement(Service)] -#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")] -pub async fn save_state( - &self, - room_id: &RoomId, - new_state_ids_compressed: Arc, -) -> Result { - let previous_shortstatehash = self - .services - .state - .get_room_shortstatehash(room_id) - .await - .ok(); - - let state_hash = - utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..])); - - let (new_shortstatehash, already_existed) = self - .services - .short - .get_or_create_shortstatehash(&state_hash) - .await; - - if Some(new_shortstatehash) == previous_shortstatehash { - return Ok(HashSetCompressStateEvent { - shortstatehash: new_shortstatehash, - ..Default::default() - }); - } - - let states_parents = if let Some(p) = previous_shortstatehash { - self.load_shortstatehash_info(p).await.unwrap_or_default() - } else { - ShortStateInfoVec::new() - }; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: CompressedState = new_state_ids_compressed - .difference(&parent_stateinfo.full_state) - .copied() - .collect(); - - let statediffremoved: CompressedState = parent_stateinfo - .full_state - .difference(&new_state_ids_compressed) - .copied() - .collect(); - - (Arc::new(statediffnew), Arc::new(statediffremoved)) - } else { - (new_state_ids_compressed, Arc::new(CompressedState::new())) - }; - - if !already_existed { - self.save_state_from_diff( - new_shortstatehash, - statediffnew.clone(), - statediffremoved.clone(), - 2, // every state change is 2 event changes on average - states_parents, - )?; - } - - Ok(HashSetCompressStateEvent { - shortstatehash: new_shortstatehash, - added: statediffnew, - removed: statediffremoved, - }) -} - -#[implement(Service)] -#[tracing::instrument(skip(self), level = "debug", name = "get")] -async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result { - const BUFSIZE: usize = size_of::(); - const STRIDE: usize = size_of::(); - - let value = self - .db - .shortstatehash_statediff - .aqry::(&shortstatehash) - .await - .map_err(|e| { - err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")) - })?; - - let parent = utils::u64_from_bytes(&value[0..size_of::()]) - .ok() - .take_if(|parent| *parent != 0); - - debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); - let _num_values = value.len() / STRIDE; - - let mut add_mode = true; - let mut added = CompressedState::new(); - let mut removed = CompressedState::new(); - - let mut i = STRIDE; - while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i = expected!(i + STRIDE); - continue; - } - if add_mode { - added.insert(v.try_into()?); + self.save_state_from_diff( + shortstatehash, + Arc::new(parent_new), + Arc::new(parent_removed), + diffsum, + parent_states, + )?; } else { - removed.insert(v.try_into()?); + // Diff small enough, we add diff as layer on top of parent + self.save_statediff(shortstatehash, &StateDiff { + parent: Some(parent.shortstatehash), + added: statediffnew, + removed: statediffremoved, + }); } - i = expected!(i + 2 * STRIDE); + + Ok(()) } - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) -} + /// Returns the new shortstatehash, and the state diff from the previous + /// room state + #[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")] + pub async fn save_state( + &self, + room_id: &RoomId, + new_state_ids_compressed: Arc, + ) -> Result { + let previous_shortstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .ok(); -#[implement(Service)] -fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { - let mut value = Vec::::with_capacity( - 2_usize - .saturating_add(diff.added.len()) - .saturating_add(diff.removed.len()), - ); + let state_hash = + utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..])); - let parent = diff.parent.unwrap_or(0_u64); - value.extend_from_slice(&parent.to_be_bytes()); + let (new_shortstatehash, already_existed) = self + .services + .short + .get_or_create_shortstatehash(&state_hash) + .await; - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } - - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + ..Default::default() + }); } + + let states_parents = if let Some(p) = previous_shortstatehash { + self.load_shortstatehash_info(p).await.unwrap_or_default() + } else { + ShortStateInfoVec::new() + }; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: CompressedState = new_state_ids_compressed + .difference(&parent_stateinfo.full_state) + .copied() + .collect(); + + let statediffremoved: CompressedState = parent_stateinfo + .full_state + .difference(&new_state_ids_compressed) + .copied() + .collect(); + + (Arc::new(statediffnew), Arc::new(statediffremoved)) + } else { + (new_state_ids_compressed, Arc::new(CompressedState::new())) + }; + + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + } + + Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + added: statediffnew, + removed: statediffremoved, + }) } - self.db - .shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value); + #[tracing::instrument(skip(self), level = "debug", name = "get")] + async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result { + const BUFSIZE: usize = size_of::(); + const STRIDE: usize = size_of::(); + + let value = self + .db + .shortstatehash_statediff + .aqry::(&shortstatehash) + .await + .map_err(|e| { + err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")) + })?; + + let parent = utils::u64_from_bytes(&value[0..size_of::()]) + .ok() + .take_if(|parent| *parent != 0); + + debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); + let _num_values = value.len() / STRIDE; + + let mut add_mode = true; + let mut added = CompressedState::new(); + let mut removed = CompressedState::new(); + + let mut i = STRIDE; + while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i = expected!(i + STRIDE); + continue; + } + if add_mode { + added.insert(v.try_into()?); + } else { + removed.insert(v.try_into()?); + } + i = expected!(i + 2 * STRIDE); + } + + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) + } + + fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { + let mut value = Vec::::with_capacity( + 2_usize + .saturating_add(diff.added.len()) + .saturating_add(diff.removed.len()), + ); + + let parent = diff.parent.unwrap_or(0_u64); + value.extend_from_slice(&parent.to_be_bytes()); + + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } + + self.db + .shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value); + } } #[inline] diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 59319ba6..a24183e6 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduwuit_core::{ - Event, Result, err, +use conduwuit::{ + Result, err, matrix::pdu::{PduCount, PduEvent, PduId, RawPduId}, utils::{ ReadyExt, @@ -49,10 +49,7 @@ impl crate::Service for Service { } impl Service { - pub async fn add_to_thread(&self, root_event_id: &EventId, event: &E) -> Result - where - E: Event + Send + Sync, - { + pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = self .services .timeline @@ -89,7 +86,7 @@ impl Service { }) { // Thread already existed relations.count = relations.count.saturating_add(uint!(1)); - relations.latest_event = event.to_format(); + relations.latest_event = pdu.to_message_like_event(); let content = serde_json::to_value(relations).expect("to_value always works"); @@ -102,7 +99,7 @@ impl Service { } else { // New thread let relations = BundledThread { - latest_event: event.to_format(), + latest_event: pdu.to_message_like_event(), count: uint!(1), current_user_participated: true, }; @@ -119,7 +116,7 @@ impl Service { self.services .timeline - .replace_pdu(&root_id, &root_pdu_json) + .replace_pdu(&root_id, &root_pdu_json, &root_pdu) .await?; } @@ -129,10 +126,10 @@ impl Service { users.extend_from_slice(&userids); }, | _ => { - users.push(root_pdu.sender().to_owned()); + users.push(root_pdu.sender); }, } - users.push(event.sender().to_owned()); + users.push(pdu.sender.clone()); self.update_participants(&root_id, &users) } @@ -161,11 +158,9 @@ impl Service { .ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes()) .wide_filter_map(move |pdu_id| async move { let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; - let pdu_id: PduId = pdu_id.into(); - if pdu.sender() != user_id { - pdu.as_mut_pdu().remove_transaction_id().ok(); - } + + pdu.set_unsigned(Some(user_id)); Some((pdu_id.shorteventid, pdu)) }); diff --git a/src/service/rooms/timeline/append.rs b/src/service/rooms/timeline/append.rs deleted file mode 100644 index 1d404e8a..00000000 --- a/src/service/rooms/timeline/append.rs +++ /dev/null @@ -1,448 +0,0 @@ -use std::{ - collections::{BTreeMap, HashSet}, - sync::Arc, -}; - -use conduwuit_core::{ - Result, err, error, implement, - matrix::{ - event::Event, - pdu::{PduCount, PduEvent, PduId, RawPduId}, - }, - utils::{self, ReadyExt}, -}; -use futures::StreamExt; -use ruma::{ - CanonicalJsonObject, CanonicalJsonValue, EventId, RoomVersionId, UserId, - events::{ - GlobalAccountDataEventType, StateEventType, TimelineEventType, - push_rules::PushRulesEvent, - room::{ - encrypted::Relation, - member::{MembershipState, RoomMemberEventContent}, - power_levels::RoomPowerLevelsEventContent, - redaction::RoomRedactionEventContent, - }, - }, - push::{Action, Ruleset, Tweak}, -}; - -use super::{ExtractBody, ExtractRelatesTo, ExtractRelatesToEventId, RoomMutexGuard}; -use crate::{appservice::NamespaceRegex, rooms::state_compressor::CompressedState}; - -/// Append the incoming event setting the state snapshot to the state from -/// the server that sent the event. -#[implement(super::Service)] -#[tracing::instrument(level = "debug", skip_all)] -pub async fn append_incoming_pdu<'a, Leaves>( - &'a self, - pdu: &'a PduEvent, - pdu_json: CanonicalJsonObject, - new_room_leaves: Leaves, - state_ids_compressed: Arc, - soft_fail: bool, - state_lock: &'a RoomMutexGuard, -) -> Result> -where - Leaves: Iterator + Send + 'a, -{ - // We append to state before appending the pdu, so we don't have a moment in - // time with the pdu without it's state. This is okay because append_pdu can't - // fail. - self.services - .state - .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) - .await?; - - if soft_fail { - self.services - .pdu_metadata - .mark_as_referenced(&pdu.room_id, pdu.prev_events.iter().map(AsRef::as_ref)); - - self.services - .state - .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock) - .await; - - return Ok(None); - } - - let pdu_id = self - .append_pdu(pdu, pdu_json, new_room_leaves, state_lock) - .await?; - - Ok(Some(pdu_id)) -} - -/// Creates a new persisted data unit and adds it to a room. -/// -/// By this point the incoming event should be fully authenticated, no auth -/// happens in `append_pdu`. -/// -/// Returns pdu id -#[implement(super::Service)] -#[tracing::instrument(level = "debug", skip_all)] -pub async fn append_pdu<'a, Leaves>( - &'a self, - pdu: &'a PduEvent, - mut pdu_json: CanonicalJsonObject, - leaves: Leaves, - state_lock: &'a RoomMutexGuard, -) -> Result -where - Leaves: Iterator + Send + 'a, -{ - // Coalesce database writes for the remainder of this scope. - let _cork = self.db.db.cork_and_flush(); - - let shortroomid = self - .services - .short - .get_shortroomid(pdu.room_id()) - .await - .map_err(|_| err!(Database("Room does not exist")))?; - - // Make unsigned fields correct. This is not properly documented in the spec, - // but state events need to have previous content in the unsigned field, so - // clients can easily interpret things like membership changes - if let Some(state_key) = pdu.state_key() { - if let CanonicalJsonValue::Object(unsigned) = pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) - { - if let Ok(shortstatehash) = self - .services - .state_accessor - .pdu_shortstatehash(pdu.event_id()) - .await - { - if let Ok(prev_state) = self - .services - .state_accessor - .state_get(shortstatehash, &pdu.kind().to_string().into(), state_key) - .await - { - unsigned.insert( - "prev_content".to_owned(), - CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.get_content_as_value()) - .map_err(|e| { - err!(Database(error!( - "Failed to convert prev_state to canonical JSON: {e}", - ))) - })?, - ), - ); - unsigned.insert( - String::from("prev_sender"), - CanonicalJsonValue::String(prev_state.sender().to_string()), - ); - unsigned.insert( - String::from("replaces_state"), - CanonicalJsonValue::String(prev_state.event_id().to_string()), - ); - } - } - } else { - error!("Invalid unsigned type in pdu."); - } - } - - // We must keep track of all events that have been referenced. - self.services - .pdu_metadata - .mark_as_referenced(pdu.room_id(), pdu.prev_events().map(AsRef::as_ref)); - - self.services - .state - .set_forward_extremities(pdu.room_id(), leaves, state_lock) - .await; - - let insert_lock = self.mutex_insert.lock(pdu.room_id()).await; - - let count1 = self.services.globals.next_count().unwrap(); - - // Mark as read first so the sending client doesn't get a notification even if - // appending fails - self.services - .read_receipt - .private_read_set(pdu.room_id(), pdu.sender(), count1); - - self.services - .user - .reset_notification_counts(pdu.sender(), pdu.room_id()); - - let count2 = PduCount::Normal(self.services.globals.next_count().unwrap()); - let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into(); - - // Insert pdu - self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; - - drop(insert_lock); - - // See if the event matches any known pushers via power level - let power_levels: RoomPowerLevelsEventContent = self - .services - .state_accessor - .room_state_get_content(pdu.room_id(), &StateEventType::RoomPowerLevels, "") - .await - .unwrap_or_default(); - - let mut push_target: HashSet<_> = self - .services - .state_cache - .active_local_users_in_room(pdu.room_id()) - .map(ToOwned::to_owned) - // Don't notify the sender of their own events, and dont send from ignored users - .ready_filter(|user| *user != pdu.sender()) - .filter_map(|recipient_user| async move { (!self.services.users.user_is_ignored(pdu.sender(), &recipient_user).await).then_some(recipient_user) }) - .collect() - .await; - - let mut notifies = Vec::with_capacity(push_target.len().saturating_add(1)); - let mut highlights = Vec::with_capacity(push_target.len().saturating_add(1)); - - if *pdu.kind() == TimelineEventType::RoomMember { - if let Some(state_key) = pdu.state_key() { - let target_user_id = UserId::parse(state_key)?; - - if self.services.users.is_active_local(target_user_id).await { - push_target.insert(target_user_id.to_owned()); - } - } - } - - let serialized = pdu.to_format(); - for user in &push_target { - let rules_for_user = self - .services - .account_data - .get_global(user, GlobalAccountDataEventType::PushRules) - .await - .map_or_else( - |_| Ruleset::server_default(user), - |ev: PushRulesEvent| ev.content.global, - ); - - let mut highlight = false; - let mut notify = false; - - for action in self - .services - .pusher - .get_actions(user, &rules_for_user, &power_levels, &serialized, pdu.room_id()) - .await - { - match action { - | Action::Notify => notify = true, - | Action::SetTweak(Tweak::Highlight(true)) => { - highlight = true; - }, - | _ => {}, - } - - // Break early if both conditions are true - if notify && highlight { - break; - } - } - - if notify { - notifies.push(user.clone()); - } - - if highlight { - highlights.push(user.clone()); - } - - self.services - .pusher - .get_pushkeys(user) - .ready_for_each(|push_key| { - self.services - .sending - .send_pdu_push(&pdu_id, user, push_key.to_owned()) - .expect("TODO: replace with future"); - }) - .await; - } - - self.db - .increment_notification_counts(pdu.room_id(), notifies, highlights); - - match *pdu.kind() { - | TimelineEventType::RoomRedaction => { - use RoomVersionId::*; - - let room_version_id = self.services.state.get_room_version(pdu.room_id()).await?; - match room_version_id { - | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { - if let Some(redact_id) = pdu.redacts() { - if self - .services - .state_accessor - .user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false) - .await? - { - self.redact_pdu(redact_id, pdu, shortroomid).await?; - } - } - }, - | _ => { - let content: RoomRedactionEventContent = pdu.get_content()?; - if let Some(redact_id) = &content.redacts { - if self - .services - .state_accessor - .user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false) - .await? - { - self.redact_pdu(redact_id, pdu, shortroomid).await?; - } - } - }, - } - }, - | TimelineEventType::SpaceChild => - if let Some(_state_key) = pdu.state_key() { - self.services - .spaces - .roomid_spacehierarchy_cache - .lock() - .await - .remove(pdu.room_id()); - }, - | TimelineEventType::RoomMember => { - if let Some(state_key) = pdu.state_key() { - // if the state_key fails - let target_user_id = - UserId::parse(state_key).expect("This state_key was previously validated"); - - let content: RoomMemberEventContent = pdu.get_content()?; - let stripped_state = match content.membership { - | MembershipState::Invite | MembershipState::Knock => - self.services.state.summary_stripped(pdu).await.into(), - | _ => None, - }; - - // Update our membership info, we do this here incase a user is invited or - // knocked and immediately leaves we need the DB to record the invite or - // knock event for auth - self.services - .state_cache - .update_membership( - pdu.room_id(), - target_user_id, - content, - pdu.sender(), - stripped_state, - None, - true, - ) - .await?; - } - }, - | TimelineEventType::RoomMessage => { - let content: ExtractBody = pdu.get_content()?; - if let Some(body) = content.body { - self.services.search.index_pdu(shortroomid, &pdu_id, &body); - - if self.services.admin.is_admin_command(pdu, &body).await { - self.services.admin.command_with_sender( - body, - Some((pdu.event_id()).into()), - pdu.sender.clone().into(), - )?; - } - } - }, - | _ => {}, - } - - if let Ok(content) = pdu.get_content::() { - if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { - self.services - .pdu_metadata - .add_relation(count2, related_pducount); - } - } - - if let Ok(content) = pdu.get_content::() { - match content.relates_to { - | Relation::Reply { in_reply_to } => { - // We need to do it again here, because replies don't have - // event_id as a top level field - if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { - self.services - .pdu_metadata - .add_relation(count2, related_pducount); - } - }, - | Relation::Thread(thread) => { - self.services - .threads - .add_to_thread(&thread.event_id, pdu) - .await?; - }, - | _ => {}, // TODO: Aggregate other types - } - } - - for appservice in self.services.appservice.read().await.values() { - if self - .services - .state_cache - .appservice_in_room(pdu.room_id(), appservice) - .await - { - self.services - .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; - continue; - } - - // If the RoomMember event has a non-empty state_key, it is targeted at someone. - // If it is our appservice user, we send this PDU to it. - if *pdu.kind() == TimelineEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - let appservice_uid = appservice.registration.sender_localpart.as_str(); - if state_key_uid == &appservice_uid { - self.services - .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; - continue; - } - } - } - - let matching_users = |users: &NamespaceRegex| { - appservice.users.is_match(pdu.sender().as_str()) - || *pdu.kind() == TimelineEventType::RoomMember - && pdu - .state_key - .as_ref() - .is_some_and(|state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: NamespaceRegex| { - self.services - .alias - .local_aliases_for_room(pdu.room_id()) - .ready_any(move |room_alias| aliases.is_match(room_alias.as_str())) - }; - - if matching_aliases(appservice.aliases.clone()).await - || appservice.rooms.is_match(pdu.room_id().as_str()) - || matching_users(&appservice.users) - { - self.services - .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; - } - } - - Ok(pdu_id) -} diff --git a/src/service/rooms/timeline/backfill.rs b/src/service/rooms/timeline/backfill.rs deleted file mode 100644 index e976981e..00000000 --- a/src/service/rooms/timeline/backfill.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::iter::once; - -use conduwuit_core::{ - Result, debug, debug_warn, implement, info, - matrix::{ - event::Event, - pdu::{PduCount, PduId, RawPduId}, - }, - utils::{IterStream, ReadyExt}, - validated, warn, -}; -use futures::{FutureExt, StreamExt}; -use ruma::{ - RoomId, ServerName, - api::federation, - events::{ - StateEventType, TimelineEventType, room::power_levels::RoomPowerLevelsEventContent, - }, - uint, -}; -use serde_json::value::RawValue as RawJsonValue; - -use super::ExtractBody; - -#[implement(super::Service)] -#[tracing::instrument(name = "backfill", level = "debug", skip(self))] -pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { - if self - .services - .state_cache - .room_joined_count(room_id) - .await - .is_ok_and(|count| count <= 1) - && !self - .services - .state_accessor - .is_world_readable(room_id) - .await - { - // Room is empty (1 user or none), there is no one that can backfill - return Ok(()); - } - - let first_pdu = self - .first_item_in_room(room_id) - .await - .expect("Room is not empty"); - - if first_pdu.0 < from { - // No backfill required, there are still events between them - return Ok(()); - } - - let power_levels: RoomPowerLevelsEventContent = self - .services - .state_accessor - .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") - .await - .unwrap_or_default(); - - let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { - if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) { - Some(user_id.server_name()) - } else { - None - } - }); - - let canonical_room_alias_server = once( - self.services - .state_accessor - .get_canonical_alias(room_id) - .await, - ) - .filter_map(Result::ok) - .map(|alias| alias.server_name().to_owned()) - .stream(); - - let mut servers = room_mods - .stream() - .map(ToOwned::to_owned) - .chain(canonical_room_alias_server) - .chain( - self.services - .server - .config - .trusted_servers - .iter() - .map(ToOwned::to_owned) - .stream(), - ) - .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)) - .filter_map(|server_name| async move { - self.services - .state_cache - .server_in_room(&server_name, room_id) - .await - .then_some(server_name) - }) - .boxed(); - - while let Some(ref backfill_server) = servers.next().await { - info!("Asking {backfill_server} for backfill"); - let response = self - .services - .sending - .send_federation_request( - backfill_server, - federation::backfill::get_backfill::v1::Request { - room_id: room_id.to_owned(), - v: vec![first_pdu.1.event_id().to_owned()], - limit: uint!(100), - }, - ) - .await; - match response { - | Ok(response) => { - for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await { - debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}"); - } - } - return Ok(()); - }, - | Err(e) => { - warn!("{backfill_server} failed to provide backfill for room {room_id}: {e}"); - }, - } - } - - info!("No servers could backfill, but backfill was needed in room {room_id}"); - Ok(()) -} - -#[implement(super::Service)] -#[tracing::instrument(skip(self, pdu), level = "debug")] -pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box) -> Result<()> { - let (room_id, event_id, value) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; - - // Lock so we cannot backfill the same pdu twice at the same time - let mutex_lock = self - .services - .event_handler - .mutex_federation - .lock(&room_id) - .await; - - // Skip the PDU if we already have it as a timeline event - if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { - debug!("We already know {event_id} at {pdu_id:?}"); - return Ok(()); - } - - self.services - .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, false) - .boxed() - .await?; - - let value = self.get_pdu_json(&event_id).await?; - - let pdu = self.get_pdu(&event_id).await?; - - let shortroomid = self.services.short.get_shortroomid(&room_id).await?; - - let insert_lock = self.mutex_insert.lock(&room_id).await; - - let count: i64 = self.services.globals.next_count().unwrap().try_into()?; - - let pdu_id: RawPduId = PduId { - shortroomid, - shorteventid: PduCount::Backfilled(validated!(0 - count)), - } - .into(); - - // Insert pdu - self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); - - drop(insert_lock); - - if pdu.kind == TimelineEventType::RoomMessage { - let content: ExtractBody = pdu.get_content()?; - if let Some(body) = content.body { - self.services.search.index_pdu(shortroomid, &pdu_id, &body); - } - } - drop(mutex_lock); - - debug!("Prepended backfill pdu"); - Ok(()) -} diff --git a/src/service/rooms/timeline/build.rs b/src/service/rooms/timeline/build.rs deleted file mode 100644 index a522c531..00000000 --- a/src/service/rooms/timeline/build.rs +++ /dev/null @@ -1,226 +0,0 @@ -use std::{collections::HashSet, iter::once}; - -use conduwuit_core::{ - Err, Result, implement, - matrix::{event::Event, pdu::PduBuilder}, - utils::{IterStream, ReadyExt}, -}; -use futures::{FutureExt, StreamExt}; -use ruma::{ - OwnedEventId, OwnedServerName, RoomId, RoomVersionId, UserId, - events::{ - TimelineEventType, - room::{ - member::{MembershipState, RoomMemberEventContent}, - redaction::RoomRedactionEventContent, - }, - }, -}; - -use super::RoomMutexGuard; - -/// Creates a new persisted data unit and adds it to a room. This function -/// takes a roomid_mutex_state, meaning that only this function is able to -/// mutate the room state. -#[implement(super::Service)] -#[tracing::instrument(skip(self, state_lock), level = "debug")] -pub async fn build_and_append_pdu( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - state_lock: &RoomMutexGuard, -) -> Result { - let (pdu, pdu_json) = self - .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) - .await?; - - if self.services.admin.is_admin_room(pdu.room_id()).await { - self.check_pdu_for_admin_room(&pdu, sender).boxed().await?; - } - - // If redaction event is not authorized, do not append it to the timeline - if *pdu.kind() == TimelineEventType::RoomRedaction { - use RoomVersionId::*; - match self.services.state.get_room_version(pdu.room_id()).await? { - | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { - if let Some(redact_id) = pdu.redacts() { - if !self - .services - .state_accessor - .user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false) - .await? - { - return Err!(Request(Forbidden("User cannot redact this event."))); - } - } - }, - | _ => { - let content: RoomRedactionEventContent = pdu.get_content()?; - if let Some(redact_id) = &content.redacts { - if !self - .services - .state_accessor - .user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false) - .await? - { - return Err!(Request(Forbidden("User cannot redact this event."))); - } - } - }, - } - } - - if *pdu.kind() == TimelineEventType::RoomMember { - let content: RoomMemberEventContent = pdu.get_content()?; - - if content.join_authorized_via_users_server.is_some() - && content.membership != MembershipState::Join - { - return Err!(Request(BadJson( - "join_authorised_via_users_server is only for member joins" - ))); - } - - if content - .join_authorized_via_users_server - .as_ref() - .is_some_and(|authorising_user| { - !self.services.globals.user_is_local(authorising_user) - }) { - return Err!(Request(InvalidParam( - "Authorising user does not belong to this homeserver" - ))); - } - } - - // We append to state before appending the pdu, so we don't have a moment in - // time with the pdu without it's state. This is okay because append_pdu can't - // fail. - let statehashid = self.services.state.append_to_state(&pdu).await?; - - let pdu_id = self - .append_pdu( - &pdu, - pdu_json, - // Since this PDU references all pdu_leaves we can update the leaves - // of the room - once(pdu.event_id()), - state_lock, - ) - .boxed() - .await?; - - // We set the room state after inserting the pdu, so that we never have a moment - // in time where events in the current room state do not exist - self.services - .state - .set_room_state(pdu.room_id(), statehashid, state_lock); - - let mut servers: HashSet = self - .services - .state_cache - .room_servers(pdu.room_id()) - .map(ToOwned::to_owned) - .collect() - .await; - - // In case we are kicking or banning a user, we need to inform their server of - // the change - if *pdu.kind() == TimelineEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - servers.insert(state_key_uid.server_name().to_owned()); - } - } - - // Remove our server from the server list since it will be added to it by - // room_servers() and/or the if statement above - servers.remove(self.services.globals.server_name()); - - self.services - .sending - .send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id) - .await?; - - Ok(pdu.event_id().to_owned()) -} - -#[implement(super::Service)] -#[tracing::instrument(skip_all, level = "debug")] -async fn check_pdu_for_admin_room(&self, pdu: &Pdu, sender: &UserId) -> Result -where - Pdu: Event + Send + Sync, -{ - match pdu.kind() { - | TimelineEventType::RoomEncryption => { - return Err!(Request(Forbidden(error!("Encryption not supported in admins room.")))); - }, - | TimelineEventType::RoomMember => { - let target = pdu - .state_key() - .filter(|v| v.starts_with('@')) - .unwrap_or(sender.as_str()); - - let server_user = &self.services.globals.server_user.to_string(); - - let content: RoomMemberEventContent = pdu.get_content()?; - match content.membership { - | MembershipState::Leave => { - if target == server_user { - return Err!(Request(Forbidden(error!( - "Server user cannot leave the admins room." - )))); - } - - let count = self - .services - .state_cache - .room_members(pdu.room_id()) - .ready_filter(|user| self.services.globals.user_is_local(user)) - .ready_filter(|user| *user != target) - .boxed() - .count() - .await; - - if count < 2 { - return Err!(Request(Forbidden(error!( - "Last admin cannot leave the admins room." - )))); - } - }, - - | MembershipState::Ban if pdu.state_key().is_some() => { - if target == server_user { - return Err!(Request(Forbidden(error!( - "Server cannot be banned from admins room." - )))); - } - - let count = self - .services - .state_cache - .room_members(pdu.room_id()) - .ready_filter(|user| self.services.globals.user_is_local(user)) - .ready_filter(|user| *user != target) - .boxed() - .count() - .await; - - if count < 2 { - return Err!(Request(Forbidden(error!( - "Last admin cannot be banned from admins room." - )))); - } - }, - | _ => {}, - } - }, - | _ => {}, - } - - Ok(()) -} diff --git a/src/service/rooms/timeline/create.rs b/src/service/rooms/timeline/create.rs deleted file mode 100644 index 20ccaf56..00000000 --- a/src/service/rooms/timeline/create.rs +++ /dev/null @@ -1,214 +0,0 @@ -use std::cmp; - -use conduwuit_core::{ - Err, Error, Result, err, implement, - matrix::{ - event::{Event, gen_event_id}, - pdu::{EventHash, PduBuilder, PduEvent}, - state_res::{self, RoomVersion}, - }, - utils::{self, IterStream, ReadyExt, stream::TryIgnore}, -}; -use futures::{StreamExt, TryStreamExt, future, future::ready}; -use ruma::{ - CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomId, RoomVersionId, UserId, - canonical_json::to_canonical_value, - events::{StateEventType, TimelineEventType, room::create::RoomCreateEventContent}, - uint, -}; -use serde_json::value::to_raw_value; -use tracing::warn; - -use super::RoomMutexGuard; - -#[implement(super::Service)] -pub async fn create_hash_and_sign_event( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - _mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room - * state mutex */ -) -> Result<(PduEvent, CanonicalJsonObject)> { - let PduBuilder { - event_type, - content, - unsigned, - state_key, - redacts, - timestamp, - } = pdu_builder; - - let prev_events: Vec = self - .services - .state - .get_forward_extremities(room_id) - .take(20) - .map(Into::into) - .collect() - .await; - - // If there was no create event yet, assume we are creating a room - let room_version_id = self - .services - .state - .get_room_version(room_id) - .await - .or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content: RoomCreateEventContent = serde_json::from_str(content.get())?; - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; - - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - - let auth_events = self - .services - .state - .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content) - .await?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .stream() - .map(Ok) - .and_then(|event_id| self.get_pdu(event_id)) - .and_then(|pdu| future::ok(pdu.depth)) - .ignore_err() - .ready_fold(uint!(0), cmp::max) - .await - .saturating_add(uint!(1)); - - let mut unsigned = unsigned.unwrap_or_default(); - - if let Some(state_key) = &state_key { - if let Ok(prev_pdu) = self - .services - .state_accessor - .room_state_get(room_id, &event_type.to_string().into(), state_key) - .await - { - unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value()); - unsigned.insert("prev_sender".to_owned(), serde_json::to_value(prev_pdu.sender())?); - unsigned - .insert("replaces_state".to_owned(), serde_json::to_value(prev_pdu.event_id())?); - } - } - - if event_type != TimelineEventType::RoomCreate && prev_events.is_empty() { - return Err!(Request(Unknown("Event incorrectly had zero prev_events."))); - } - if state_key.is_none() && depth.lt(&uint!(2)) { - // The first two events in a room are always m.room.create and m.room.member, - // so any other events with that same depth are illegal. - warn!( - "Had unsafe depth {depth} when creating non-state event in {room_id}. Cowardly \ - aborting" - ); - return Err!(Request(Unknown("Unsafe depth for non-state event."))); - } - - let mut pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), - sender: sender.to_owned(), - origin: None, - origin_server_ts: timestamp.map_or_else( - || { - utils::millis_since_unix_epoch() - .try_into() - .expect("u64 fits into UInt") - }, - |ts| ts.get(), - ), - kind: event_type, - content, - state_key, - prev_events, - depth, - auth_events: auth_events - .values() - .map(|pdu| pdu.event_id.clone()) - .collect(), - redacts, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned)?) - }, - hashes: EventHash { sha256: "aaa".to_owned() }, - signatures: None, - }; - - let auth_fetch = |k: &StateEventType, s: &str| { - let key = (k.clone(), s.into()); - ready(auth_events.get(&key).map(ToOwned::to_owned)) - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None, // TODO: third_party_invite - auth_fetch, - ) - .await - .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; - - if !auth_check { - return Err!(Request(Forbidden("Event is not authorized."))); - } - - // Hash and sign - let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| { - err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}")))) - })?; - - // room v3 and above removed the "event_id" field from remote PDU format - match room_version_id { - | RoomVersionId::V1 | RoomVersionId::V2 => {}, - | _ => { - pdu_json.remove("event_id"); - }, - } - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(self.services.globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - if let Err(e) = self - .services - .server_keys - .hash_and_sign_event(&mut pdu_json, &room_version_id) - { - return match e { - | Error::Signatures(ruma::signatures::Error::PduSize) => { - Err!(Request(TooLarge("Message/PDU is too long (exceeds 65535 bytes)"))) - }, - | _ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))), - }; - } - - // Generate event id - pdu.event_id = gen_event_id(&pdu_json, &room_version_id)?; - - pdu_json.insert("event_id".into(), CanonicalJsonValue::String(pdu.event_id.clone().into())); - - // Generate short event id - let _shorteventid = self - .services - .short - .get_or_create_shorteventid(&pdu.event_id) - .await; - - Ok((pdu, pdu_json)) -} diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index fa10a5c0..5f7b8c81 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,14 +1,11 @@ -use std::{borrow::Borrow, sync::Arc}; +use std::sync::Arc; use conduwuit::{ - Err, PduCount, PduEvent, Result, at, err, - result::{LogErr, NotFound}, - utils, - utils::stream::TryReadyExt, + Err, PduCount, PduEvent, Result, at, err, result::NotFound, utils, utils::stream::TryReadyExt, }; use database::{Database, Deserialized, Json, KeyVal, Map}; use futures::{FutureExt, Stream, TryFutureExt, TryStreamExt, future::select_ok, pin_mut}; -use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, api::Direction}; +use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, api::Direction}; use super::{PduId, RawPduId}; use crate::{Dep, rooms, rooms::short::ShortRoomId}; @@ -46,12 +43,8 @@ impl Data { } #[inline] - pub(super) async fn last_timeline_count( - &self, - sender_user: Option<&UserId>, - room_id: &RoomId, - ) -> Result { - let pdus_rev = self.pdus_rev(sender_user, room_id, PduCount::max()); + pub(super) async fn last_timeline_count(&self, room_id: &RoomId) -> Result { + let pdus_rev = self.pdus_rev(room_id, PduCount::max()); pin_mut!(pdus_rev); let last_count = pdus_rev @@ -65,12 +58,8 @@ impl Data { } #[inline] - pub(super) async fn latest_pdu_in_room( - &self, - sender_user: Option<&UserId>, - room_id: &RoomId, - ) -> Result { - let pdus_rev = self.pdus_rev(sender_user, room_id, PduCount::max()); + pub(super) async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result { + let pdus_rev = self.pdus_rev(room_id, PduCount::max()); pin_mut!(pdus_rev); pdus_rev @@ -207,6 +196,7 @@ impl Data { &self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, + _pdu: &PduEvent, ) -> Result { if self.pduid_pdu.get(pdu_id).await.is_not_found() { return Err!(Request(NotFound("PDU does not exist."))); @@ -222,7 +212,6 @@ impl Data { /// order. pub(super) fn pdus_rev<'a>( &'a self, - user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount, ) -> impl Stream> + Send + 'a { @@ -232,14 +221,13 @@ impl Data { self.pduid_pdu .rev_raw_stream_from(¤t) .ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix))) - .ready_and_then(move |item| Self::each_pdu(item, user_id)) + .ready_and_then(Self::from_json_slice) }) .try_flatten_stream() } pub(super) fn pdus<'a>( &'a self, - user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount, ) -> impl Stream> + Send + 'a { @@ -249,21 +237,15 @@ impl Data { self.pduid_pdu .raw_stream_from(¤t) .ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix))) - .ready_and_then(move |item| Self::each_pdu(item, user_id)) + .ready_and_then(Self::from_json_slice) }) .try_flatten_stream() } - fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> Result { + fn from_json_slice((pdu_id, pdu): KeyVal<'_>) -> Result { let pdu_id: RawPduId = pdu_id.into(); - let mut pdu = serde_json::from_slice::(pdu)?; - - if Some(pdu.sender.borrow()) != user_id { - pdu.remove_transaction_id().log_err().ok(); - } - - pdu.add_age().log_err().ok(); + let pdu = serde_json::from_slice::(pdu)?; Ok((pdu_id.pdu_count(), pdu)) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 2a4418d8..a3709533 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -165,7 +165,7 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub async fn first_item_in_room(&self, room_id: &RoomId) -> Result<(PduCount, PduEvent)> { - let pdus = self.pdus(None, room_id, None); + let pdus = self.pdus(room_id, None); pin_mut!(pdus); pdus.try_next() @@ -175,16 +175,12 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result { - self.db.latest_pdu_in_room(None, room_id).await + self.db.latest_pdu_in_room(room_id).await } #[tracing::instrument(skip(self), level = "debug")] - pub async fn last_timeline_count( - &self, - sender_user: Option<&UserId>, - room_id: &RoomId, - ) -> Result { - self.db.last_timeline_count(sender_user, room_id).await + pub async fn last_timeline_count(&self, room_id: &RoomId) -> Result { + self.db.last_timeline_count(room_id).await } /// Returns the `count` of this pdu's id. @@ -547,6 +543,10 @@ impl Service { | _ => {}, } + // CONCERN: If we receive events with a relation out-of-order, we never write + // their relation / thread. We need some kind of way to trigger when we receive + // this event, and potentially a way to rebuild the table entirely. + if let Ok(content) = pdu.get_content::() { if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services @@ -1025,34 +1025,30 @@ impl Service { #[inline] pub fn all_pdus<'a>( &'a self, - user_id: &'a UserId, room_id: &'a RoomId, ) -> impl Stream + Send + 'a { - self.pdus(Some(user_id), room_id, None).ignore_err() + self.pdus(room_id, None).ignore_err() } /// Reverse iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] pub fn pdus_rev<'a>( &'a self, - user_id: Option<&'a UserId>, room_id: &'a RoomId, until: Option, ) -> impl Stream> + Send + 'a { self.db - .pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max)) + .pdus_rev(room_id, until.unwrap_or_else(PduCount::max)) } /// Forward iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] pub fn pdus<'a>( &'a self, - user_id: Option<&'a UserId>, room_id: &'a RoomId, from: Option, ) -> impl Stream> + Send + 'a { - self.db - .pdus(user_id, room_id, from.unwrap_or_else(PduCount::min)) + self.db.pdus(room_id, from.unwrap_or_else(PduCount::min)) } /// Replace a PDU with the redacted form. diff --git a/src/service/rooms/timeline/redact.rs b/src/service/rooms/timeline/redact.rs deleted file mode 100644 index d51a8462..00000000 --- a/src/service/rooms/timeline/redact.rs +++ /dev/null @@ -1,51 +0,0 @@ -use conduwuit_core::{ - Result, err, implement, - matrix::event::Event, - utils::{self}, -}; -use ruma::EventId; - -use super::ExtractBody; -use crate::rooms::short::ShortRoomId; - -/// Replace a PDU with the redacted form. -#[implement(super::Service)] -#[tracing::instrument(name = "redact", level = "debug", skip(self))] -pub async fn redact_pdu( - &self, - event_id: &EventId, - reason: &Pdu, - shortroomid: ShortRoomId, -) -> Result { - // TODO: Don't reserialize, keep original json - let Ok(pdu_id) = self.get_pdu_id(event_id).await else { - // If event does not exist, just noop - return Ok(()); - }; - - let mut pdu = self - .get_pdu_from_id(&pdu_id) - .await - .map(Event::into_pdu) - .map_err(|e| { - err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))) - })?; - - if let Ok(content) = pdu.get_content::() { - if let Some(body) = content.body { - self.services - .search - .deindex_pdu(shortroomid, &pdu_id, &body); - } - } - - let room_version_id = self.services.state.get_room_version(pdu.room_id()).await?; - - pdu.redact(&room_version_id, reason.to_value())?; - - let obj = utils::to_canonical_object(&pdu).map_err(|e| { - err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))) - })?; - - self.replace_pdu(&pdu_id, &obj).await -} diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index a708f746..cd84f7e7 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -9,8 +9,8 @@ use std::{ }; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; -use conduwuit_core::{ - Error, Event, Result, debug, err, error, +use conduwuit::{ + Error, Result, debug, err, error, result::LogErr, trace, utils::{ @@ -697,7 +697,7 @@ impl Service { match event { | SendingEvent::Pdu(pdu_id) => { if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { - pdu_jsons.push(pdu.to_format()); + pdu_jsons.push(pdu.into_room_event()); } }, | SendingEvent::Edu(edu) => @@ -781,7 +781,7 @@ impl Service { for pdu in pdus { // Redacted events are not notification targets (we don't send push for them) - if pdu.contains_unsigned_property("redacted_because", serde_json::Value::is_string) { + if pdu.is_redacted() { continue; } @@ -798,7 +798,7 @@ impl Service { let unread: UInt = self .services .user - .notification_count(&user_id, pdu.room_id()) + .notification_count(&user_id, &pdu.room_id) .await .try_into() .expect("notification count can't go that high"); diff --git a/src/service/server_keys/verify.rs b/src/service/server_keys/verify.rs index 9cc3655a..84433628 100644 --- a/src/service/server_keys/verify.rs +++ b/src/service/server_keys/verify.rs @@ -1,4 +1,4 @@ -use conduwuit::{Err, Result, implement, matrix::event::gen_event_id_canonical_json}; +use conduwuit::{Err, Result, implement, pdu::gen_event_id_canonical_json}; use ruma::{ CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId, signatures::Verified, };