diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 8c9ba64c..ee3a458c 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -93,7 +93,14 @@ pub(crate) async fn get_context_route( .ignore_err() .then(async |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .ready_filter_map(|item| event_filter(item, filter)) @@ -109,7 +116,14 @@ pub(crate) async fn get_context_route( .ignore_err() .then(async |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .ready_filter_map(|item| event_filter(item, filter)) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index cd6b0a60..6087478c 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -2,7 +2,7 @@ use core::panic; use axum::extract::State; use conduwuit::{ - Err, Result, at, + Err, Result, at, debug_warn, matrix::{ Event, pdu::{PduCount, PduEvent}, @@ -134,7 +134,14 @@ pub(crate) async fn get_message_events_route( .take(limit) .then(async |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } pdu }) .collect() diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 5e323756..377f0c71 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - Result, at, + Result, at, debug_warn, matrix::pdu::PduCount, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; @@ -149,6 +149,17 @@ async fn paginate_relations_with_filter( .ready_take_while(|(count, _)| Some(*count) != to) .wide_filter_map(|item| visibility_filter(services, sender_user, item)) .take(limit) + .then(async |mut pdu| { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations to relation: {e}"); + } + pdu + }) .collect() .await; diff --git a/src/api/client/room/event.rs b/src/api/client/room/event.rs index ec673d75..e7b0bb37 100644 --- a/src/api/client/room/event.rs +++ b/src/api/client/room/event.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use conduwuit::{Err, Event, Result, err}; +use conduwuit::{Err, Event, Result, debug_warn, err}; use futures::{FutureExt, TryFutureExt, future::try_join}; use ruma::api::client::room::get_room_event; @@ -38,6 +38,15 @@ pub(crate) async fn get_room_event_route( "Fetched PDU must match requested" ); + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(body.sender_user(), &mut event) + .await + { + debug_warn!("Failed to add bundled aggregations to event: {e}"); + } + event.set_unsigned(body.sender_user.as_deref()); Ok(get_room_event::v3::Response { event: event.into_room_event() }) diff --git a/src/api/client/room/initial_sync.rs b/src/api/client/room/initial_sync.rs index 765d2a39..2f965245 100644 --- a/src/api/client/room/initial_sync.rs +++ b/src/api/client/room/initial_sync.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - Err, PduEvent, Result, at, + Err, PduEvent, Result, at, debug_warn, utils::{BoolExt, stream::TryTools}, }; use futures::TryStreamExt; @@ -35,7 +35,16 @@ pub(crate) async fn room_initial_sync_route( .try_take(limit) .and_then(async |mut pdu| { pdu.1.set_unsigned(body.sender_user.as_deref()); - // TODO: bundled aggregations + if let Some(sender_user) = body.sender_user.as_deref() { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + } Ok(pdu) }) .try_collect() diff --git a/src/api/client/search.rs b/src/api/client/search.rs index d4dcde57..af5fccec 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use axum::extract::State; use conduwuit::{ - Err, Result, at, is_true, + Err, Result, at, debug_warn, is_true, matrix::pdu::PduEvent, result::FlatOk, utils::{IterStream, stream::ReadyExt}, @@ -144,6 +144,17 @@ async fn category_room_events( .map(at!(2)) .flatten() .stream() + .then(|mut pdu| async { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu) + .await + { + debug_warn!("Failed to add bundled aggregations to search result: {e}"); + } + pdu + }) .map(PduEvent::into_room_event) .map(|result| SearchResult { rank: None, diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index db581bd6..1ea62883 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -3,7 +3,7 @@ mod v4; mod v5; use conduwuit::{ - Error, PduCount, Result, + Error, PduCount, Result, debug_warn, matrix::pdu::PduEvent, utils::stream::{BroadbandExt, ReadyExt, TryIgnore}, }; @@ -42,13 +42,23 @@ async fn load_timeline( .timeline .pdus_rev(room_id, None) .ignore_err() + .ready_skip_while(|&(pducount, _)| pducount > next_batch.unwrap_or_else(PduCount::max)) + .ready_take_while(|&(pducount, _)| pducount > roomsincecount) .map(move |mut pdu| { pdu.1.set_unsigned(Some(sender_user)); - // TODO: bundled aggregations pdu }) - .ready_skip_while(|&(pducount, _)| pducount > next_batch.unwrap_or_else(PduCount::max)) - .ready_take_while(|&(pducount, _)| pducount > roomsincecount); + .then(async move |mut pdu| { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1) + .await + { + debug_warn!("Failed to add bundled aggregations: {e}"); + } + pdu + }); // Take the last events for the timeline pin_mut!(non_timeline_pdus); diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 404da4f8..09fb75d6 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - Result, at, + Result, at, debug_warn, matrix::pdu::{PduCount, PduEvent}, }; use futures::StreamExt; @@ -28,7 +28,6 @@ pub(crate) async fn get_threads_route( .transpose()? .unwrap_or_else(PduCount::max); - // TODO: bundled aggregation // TODO: user_can_see_event and set_unsigned should be at the same level / // function, so unsigned is only set for seen events. let threads: Vec<(PduCount, PduEvent)> = services @@ -45,6 +44,17 @@ pub(crate) async fn get_threads_route( .await .then_some((count, pdu)) }) + .then(|(count, mut pdu)| async move { + if let Err(e) = services + .rooms + .pdu_metadata + .add_bundled_aggregations_to_pdu(body.sender_user(), &mut pdu) + .await + { + debug_warn!("Failed to add bundled aggregations to thread: {e}"); + } + (count, pdu) + }) .collect() .await; diff --git a/src/service/rooms/pdu_metadata/bundled_aggregations.rs b/src/service/rooms/pdu_metadata/bundled_aggregations.rs new file mode 100644 index 00000000..4ef8efc1 --- /dev/null +++ b/src/service/rooms/pdu_metadata/bundled_aggregations.rs @@ -0,0 +1,394 @@ +use conduwuit::{Event, PduEvent, Result, err}; +use ruma::{ + EventId, RoomId, UserId, + api::Direction, + events::relation::{BundledMessageLikeRelations, BundledReference, ReferenceChunk}, +}; + +use super::PdusIterItem; + +const MAX_BUNDLED_RELATIONS: usize = 50; + +impl super::Service { + /// Gets bundled aggregations for an event according to the Matrix + /// specification. + /// - m.replace relations are bundled to include the most recent replacement + /// event. + /// - m.reference relations are bundled to include a chunk of event IDs. + #[tracing::instrument(skip(self), level = "debug")] + pub async fn get_bundled_aggregations( + &self, + user_id: &UserId, + room_id: &RoomId, + event_id: &EventId, + ) -> Result>>> { + let relations = self + .get_relations( + user_id, + room_id, + event_id, + conduwuit::PduCount::max(), + MAX_BUNDLED_RELATIONS, + 0, + Direction::Backward, + ) + .await; + // The relations database code still handles the basic unsigned data + // We don't want to recursively fetch relations + + // TODO: Event visibility check + // TODO: ignored users? + + if relations.is_empty() { + return Ok(None); + } + + let mut replace_events = Vec::with_capacity(relations.len().min(10)); // Most events have few replacements + let mut reference_events = Vec::with_capacity(relations.len()); + + for relation in &relations { + let pdu = &relation.1; + + let content = pdu.get_content_as_value(); + if let Some(relates_to) = content.get("m.relates_to") { + // We don't check that the event relates back, because we assume the database is + // good. + if let Some(rel_type) = relates_to.get("rel_type") { + match rel_type.as_str() { + | Some("m.replace") => { + replace_events.push(relation); + }, + | Some("m.reference") => { + reference_events.push(relation); + }, + | _ => { + // Ignore other relation types for now + // Threads are in the database but not handled here + // Other types are not specified AFAICT. + }, + } + } + } + } + + // If no relations to bundle, return None + if replace_events.is_empty() && reference_events.is_empty() { + return Ok(None); + } + + let mut bundled = BundledMessageLikeRelations::new(); + + // Handle m.replace relations - find the most recent one + if !replace_events.is_empty() { + let most_recent_replacement = Self::find_most_recent_replacement(&replace_events)?; + + // Convert the replacement event to the bundled format + if let Some(replacement_pdu) = most_recent_replacement { + // According to the Matrix spec, we should include the full event as raw JSON + let replacement_json = serde_json::to_string(replacement_pdu) + .map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?; + let raw_value = serde_json::value::RawValue::from_string(replacement_json) + .map_err(|e| err!(Database("Failed to create RawValue: {e}")))?; + bundled.replace = Some(Box::new(raw_value)); + } + } + + // Handle m.reference relations - collect event IDs + if !reference_events.is_empty() { + let reference_chunk = Self::build_reference_chunk(&reference_events)?; + if !reference_chunk.is_empty() { + bundled.reference = Some(Box::new(ReferenceChunk::new(reference_chunk))); + } + } + + // TODO: Handle other relation types (m.annotation, etc.) when specified + + Ok(Some(bundled)) + } + + /// Build reference chunk for m.reference bundled aggregations + fn build_reference_chunk( + reference_events: &[&PdusIterItem], + ) -> Result> { + let mut chunk = Vec::with_capacity(reference_events.len()); + + for relation in reference_events { + let pdu = &relation.1; + + let reference_entry = BundledReference::new(pdu.event_id().to_owned()); + chunk.push(reference_entry); + } + + // Don't sort, order is unspecified + + Ok(chunk) + } + + /// Find the most recent replacement event based on origin_server_ts and + /// lexicographic event_id ordering + fn find_most_recent_replacement<'a>( + replacement_events: &'a [&'a PdusIterItem], + ) -> Result> { + if replacement_events.is_empty() { + return Ok(None); + } + + let mut most_recent: Option<&PduEvent> = None; + + // Jank, is there a better way to do this? + for relation in replacement_events { + let pdu = &relation.1; + + match most_recent { + | None => { + most_recent = Some(pdu); + }, + | Some(current_most_recent) => { + // Compare by origin_server_ts first + match pdu + .origin_server_ts() + .cmp(¤t_most_recent.origin_server_ts()) + { + | std::cmp::Ordering::Greater => { + most_recent = Some(pdu); + }, + | std::cmp::Ordering::Equal => { + // If timestamps are equal, use lexicographic ordering of event_id + if pdu.event_id() > current_most_recent.event_id() { + most_recent = Some(pdu); + } + }, + | std::cmp::Ordering::Less => { + // Keep current most recent + }, + } + }, + } + } + + Ok(most_recent) + } + + /// Adds bundled aggregations to a PDU's unsigned field + #[tracing::instrument(skip(self, pdu), level = "debug")] + pub async fn add_bundled_aggregations_to_pdu( + &self, + user_id: &UserId, + pdu: &mut PduEvent, + ) -> Result<()> { + if pdu.is_redacted() { + return Ok(()); + } + + let bundled_aggregations = self + .get_bundled_aggregations(user_id, pdu.room_id(), pdu.event_id()) + .await?; + + if let Some(aggregations) = bundled_aggregations { + let aggregations_json = serde_json::to_value(aggregations) + .map_err(|e| err!(Database("Failed to serialize bundled aggregations: {e}")))?; + + Self::add_bundled_aggregations_to_unsigned(pdu, aggregations_json)?; + } + + Ok(()) + } + + /// Helper method to add bundled aggregations to a PDU's unsigned + /// field + fn add_bundled_aggregations_to_unsigned( + pdu: &mut PduEvent, + aggregations_json: serde_json::Value, + ) -> Result<()> { + use serde_json::{ + Map, Value as JsonValue, + value::{RawValue as RawJsonValue, to_raw_value}, + }; + + let mut unsigned: Map = pdu + .unsigned + .as_deref() + .map(RawJsonValue::get) + .map_or_else(|| Ok(Map::new()), serde_json::from_str) + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + let relations = unsigned + .entry("m.relations") + .or_insert_with(|| JsonValue::Object(Map::new())) + .as_object_mut() + .ok_or_else(|| err!(Database("m.relations is not an object")))?; + + if let JsonValue::Object(aggregations_map) = aggregations_json { + for (rel_type, aggregation) in aggregations_map { + relations.insert(rel_type, aggregation); + } + } + + pdu.unsigned = Some(to_raw_value(&unsigned)?); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use conduwuit_core::pdu::{EventHash, PduEvent}; + use ruma::{UInt, events::TimelineEventType, owned_event_id, owned_room_id, owned_user_id}; + use serde_json::{Value as JsonValue, json, value::to_raw_value}; + + fn create_test_pdu(unsigned_content: Option) -> PduEvent { + PduEvent { + event_id: owned_event_id!("$test:example.com"), + room_id: owned_room_id!("!test:example.com"), + sender: owned_user_id!("@test:example.com"), + origin_server_ts: UInt::try_from(1_234_567_890_u64).unwrap(), + kind: TimelineEventType::RoomMessage, + content: to_raw_value(&json!({"msgtype": "m.text", "body": "test"})).unwrap(), + state_key: None, + prev_events: vec![], + depth: UInt::from(1_u32), + auth_events: vec![], + redacts: None, + unsigned: unsigned_content.map(|content| to_raw_value(&content).unwrap()), + hashes: EventHash { sha256: "test_hash".to_owned() }, + signatures: None, + origin: None, + } + } + + fn create_bundled_aggregations() -> JsonValue { + json!({ + "m.replace": { + "event_id": "$replace:example.com", + "origin_server_ts": 1_234_567_890, + "sender": "@replacer:example.com" + }, + "m.reference": { + "count": 5, + "chunk": [ + "$ref1:example.com", + "$ref2:example.com" + ] + } + }) + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_no_existing_unsigned() { + let mut pdu = create_test_pdu(None); + let aggregations = create_bundled_aggregations(); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + aggregations.clone(), + ); + assert!(result.is_ok(), "Should succeed when no unsigned field exists"); + + assert!(pdu.unsigned.is_some(), "Unsigned field should be created"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + assert!(unsigned.get("m.relations").is_some(), "m.relations should exist"); + assert_eq!( + unsigned["m.relations"], aggregations, + "Relations should match the aggregations" + ); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_overwrite_same_relation_type() { + let existing_unsigned = json!({ + "m.relations": { + "m.replace": { + "event_id": "$old_replace:example.com", + "origin_server_ts": 1_111_111_111, + "sender": "@old_replacer:example.com" + } + } + }); + + let mut pdu = create_test_pdu(Some(existing_unsigned)); + let new_aggregations = create_bundled_aggregations(); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + new_aggregations.clone(), + ); + assert!(result.is_ok(), "Should succeed when overwriting same relation type"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + let relations = &unsigned["m.relations"]; + + assert_eq!( + relations["m.replace"], new_aggregations["m.replace"], + "m.replace should be updated" + ); + assert_eq!( + relations["m.replace"]["event_id"], "$replace:example.com", + "Should have new event_id" + ); + + assert!(relations.get("m.reference").is_some(), "New m.reference should be added"); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_preserve_other_unsigned_fields() { + // Test case: Other unsigned fields should be preserved + let existing_unsigned = json!({ + "age": 98765, + "prev_content": {"msgtype": "m.text", "body": "old message"}, + "redacted_because": {"event_id": "$redaction:example.com"}, + "m.relations": { + "m.annotation": {"count": 1} + } + }); + + let mut pdu = create_test_pdu(Some(existing_unsigned)); + let new_aggregations = json!({ + "m.replace": {"event_id": "$new:example.com"} + }); + + let result = super::super::Service::add_bundled_aggregations_to_unsigned( + &mut pdu, + new_aggregations, + ); + assert!(result.is_ok(), "Should succeed while preserving other fields"); + + let unsigned_str = pdu.unsigned.as_ref().unwrap().get(); + let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap(); + + // Verify all existing fields are preserved + assert_eq!(unsigned["age"], 98765, "age should be preserved"); + assert!(unsigned.get("prev_content").is_some(), "prev_content should be preserved"); + assert!( + unsigned.get("redacted_because").is_some(), + "redacted_because should be preserved" + ); + + // Verify relations were merged correctly + let relations = &unsigned["m.relations"]; + assert!( + relations.get("m.annotation").is_some(), + "Existing m.annotation should be preserved" + ); + assert!(relations.get("m.replace").is_some(), "New m.replace should be added"); + } + + #[test] + fn test_add_bundled_aggregations_to_unsigned_invalid_existing_unsigned() { + // Test case: Invalid JSON in existing unsigned should result in error + let mut pdu = create_test_pdu(None); + // Manually set invalid unsigned data + pdu.unsigned = Some(to_raw_value(&"invalid json").unwrap()); + + let aggregations = create_bundled_aggregations(); + let result = + super::super::Service::add_bundled_aggregations_to_unsigned(&mut pdu, aggregations); + + assert!(result.is_err(), "fails when existing unsigned is invalid"); + // Should we ignore the error and overwrite anyway? + } +} diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 18221c2d..2dff54d8 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,3 +1,4 @@ +mod bundled_aggregations; mod data; use std::sync::Arc;