chore: Merge branch 'main' into nex/serving-policy

# Conflicts:
#	src/service/rooms/event_handler/upgrade_outlier_pdu.rs
#	src/service/rooms/timeline/mod.rs
This commit is contained in:
nexy7574 2025-07-19 20:26:08 +01:00
commit 19a8eaca39
No known key found for this signature in database
GPG key ID: 0FA334385D0B689F
214 changed files with 12377 additions and 8549 deletions

View file

@ -3,7 +3,7 @@ mod remote;
use std::sync::Arc;
use conduwuit::{
Err, Result, Server, err,
Err, Event, Result, Server, err,
utils::{ReadyExt, stream::TryIgnore},
};
use database::{Deserialized, Ignore, Interfix, Map};
@ -241,7 +241,7 @@ impl Service {
.room_state_get(&room_id, &StateEventType::RoomCreate, "")
.await
{
return Ok(event.sender == user_id);
return Ok(event.sender() == user_id);
}
Err!(Database("Room has no m.room.create event"))

View file

@ -4,11 +4,13 @@ use std::{
};
use conduwuit::{
PduEvent, debug, debug_error, debug_warn, implement, pdu, trace,
utils::continue_exponential_backoff_secs, warn,
Event, PduEvent, debug, debug_error, debug_warn, implement,
matrix::event::gen_event_id_canonical_json, trace, utils::continue_exponential_backoff_secs,
warn,
};
use ruma::{
CanonicalJsonValue, OwnedEventId, RoomId, ServerName, api::federation::event::get_event,
CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName,
api::federation::event::get_event,
};
use super::get_room_version_id;
@ -23,13 +25,17 @@ use super::get_room_version_id;
/// c. Ask origin server over federation
/// d. TODO: Ask other servers over federation?
#[implement(super::Service)]
pub(super) async fn fetch_and_handle_outliers<'a>(
pub(super) async fn fetch_and_handle_outliers<'a, Pdu, Events>(
&self,
origin: &'a ServerName,
events: &'a [OwnedEventId],
create_event: &'a PduEvent,
events: Events,
create_event: &'a Pdu,
room_id: &'a RoomId,
) -> Vec<(PduEvent, Option<BTreeMap<String, CanonicalJsonValue>>)> {
) -> Vec<(PduEvent, Option<BTreeMap<String, CanonicalJsonValue>>)>
where
Pdu: Event + Send + Sync,
Events: Iterator<Item = &'a EventId> + Clone + Send,
{
let back_off = |id| match self
.services
.globals
@ -46,22 +52,23 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
},
};
let mut events_with_auth_events = Vec::with_capacity(events.len());
let mut events_with_auth_events = Vec::with_capacity(events.clone().count());
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await {
trace!("Found {id} in db");
events_with_auth_events.push((id, Some(local_pdu), vec![]));
events_with_auth_events.push((id.to_owned(), Some(local_pdu), vec![]));
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events: VecDeque<_> = [id.clone()].into();
let mut todo_auth_events: VecDeque<_> = [id.to_owned()].into();
let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len());
let mut events_all = HashSet::with_capacity(todo_auth_events.len());
while let Some(next_id) = todo_auth_events.pop_front() {
if let Some((time, tries)) = self
@ -117,7 +124,7 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
};
let Ok((calculated_event_id, value)) =
pdu::gen_event_id_canonical_json(&res.pdu, &room_version_id)
gen_event_id_canonical_json(&res.pdu, &room_version_id)
else {
back_off((*next_id).to_owned());
continue;
@ -160,7 +167,8 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
},
}
}
events_with_auth_events.push((id, None, events_in_reverse_order));
events_with_auth_events.push((id.to_owned(), None, events_in_reverse_order));
}
let mut pdus = Vec::with_capacity(events_with_auth_events.len());
@ -217,5 +225,6 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
}
}
}
pdus
}

View file

@ -1,13 +1,16 @@
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::{
collections::{BTreeMap, HashMap, HashSet, VecDeque},
iter::once,
};
use conduwuit::{
PduEvent, Result, debug_warn, err, implement,
Event, PduEvent, Result, debug_warn, err, implement,
state_res::{self},
};
use futures::{FutureExt, future};
use ruma::{
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName, UInt, int,
uint,
CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName,
int, uint,
};
use super::check_room_id;
@ -19,20 +22,26 @@ use super::check_room_id;
fields(%origin),
)]
#[allow(clippy::type_complexity)]
pub(super) async fn fetch_prev(
pub(super) async fn fetch_prev<'a, Pdu, Events>(
&self,
origin: &ServerName,
create_event: &PduEvent,
create_event: &Pdu,
room_id: &RoomId,
first_ts_in_room: UInt,
initial_set: Vec<OwnedEventId>,
first_ts_in_room: MilliSecondsSinceUnixEpoch,
initial_set: Events,
) -> Result<(
Vec<OwnedEventId>,
HashMap<OwnedEventId, (PduEvent, BTreeMap<String, CanonicalJsonValue>)>,
)> {
let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(initial_set.len());
)>
where
Pdu: Event + Send + Sync,
Events: Iterator<Item = &'a EventId> + Clone + Send,
{
let num_ids = initial_set.clone().count();
let mut eventid_info = HashMap::new();
let mut todo_outlier_stack: VecDeque<OwnedEventId> = initial_set.into();
let mut graph: HashMap<OwnedEventId, _> = HashMap::with_capacity(num_ids);
let mut todo_outlier_stack: VecDeque<OwnedEventId> =
initial_set.map(ToOwned::to_owned).collect();
let mut amount = 0;
@ -40,7 +49,12 @@ pub(super) async fn fetch_prev(
self.services.server.check_running()?;
match self
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id)
.fetch_and_handle_outliers(
origin,
once(prev_event_id.as_ref()),
create_event,
room_id,
)
.boxed()
.await
.pop()
@ -65,17 +79,17 @@ pub(super) async fn fetch_prev(
}
if let Some(json) = json_opt {
if pdu.origin_server_ts > first_ts_in_room {
if pdu.origin_server_ts() > first_ts_in_room {
amount = amount.saturating_add(1);
for prev_prev in &pdu.prev_events {
for prev_prev in pdu.prev_events() {
if !graph.contains_key(prev_prev) {
todo_outlier_stack.push_back(prev_prev.clone());
todo_outlier_stack.push_back(prev_prev.to_owned());
}
}
graph.insert(
prev_event_id.clone(),
pdu.prev_events.iter().cloned().collect(),
pdu.prev_events().map(ToOwned::to_owned).collect(),
);
} else {
// Time based check failed
@ -98,8 +112,7 @@ pub(super) async fn fetch_prev(
let event_fetch = |event_id| {
let origin_server_ts = eventid_info
.get(&event_id)
.cloned()
.map_or_else(|| uint!(0), |info| info.0.origin_server_ts);
.map_or_else(|| uint!(0), |info| info.0.origin_server_ts().get());
// This return value is the key used for sorting events,
// events are then sorted by power level, time,

View file

@ -1,6 +1,6 @@
use std::collections::{HashMap, hash_map};
use conduwuit::{Err, Error, PduEvent, Result, debug, debug_warn, implement};
use conduwuit::{Err, Event, Result, debug, debug_warn, err, implement};
use futures::FutureExt;
use ruma::{
EventId, OwnedEventId, RoomId, ServerName, api::federation::event::get_room_state_ids,
@ -18,13 +18,16 @@ use crate::rooms::short::ShortStateKey;
skip_all,
fields(%origin),
)]
pub(super) async fn fetch_state(
pub(super) async fn fetch_state<Pdu>(
&self,
origin: &ServerName,
create_event: &PduEvent,
create_event: &Pdu,
room_id: &RoomId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
) -> Result<Option<HashMap<u64, OwnedEventId>>>
where
Pdu: Event + Send + Sync,
{
let res = self
.services
.sending
@ -36,27 +39,27 @@ pub(super) async fn fetch_state(
.inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?;
debug!("Fetching state events");
let state_ids = res.pdu_ids.iter().map(AsRef::as_ref);
let state_vec = self
.fetch_and_handle_outliers(origin, &res.pdu_ids, create_event, room_id)
.fetch_and_handle_outliers(origin, state_ids, create_event, room_id)
.boxed()
.await;
let mut state: HashMap<ShortStateKey, OwnedEventId> = HashMap::with_capacity(state_vec.len());
for (pdu, _) in state_vec {
let state_key = pdu
.state_key
.clone()
.ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?;
.state_key()
.ok_or_else(|| err!(Database("Found non-state pdu in state events.")))?;
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)
.get_or_create_shortstatekey(&pdu.kind().to_string().into(), state_key)
.await;
match state.entry(shortstatekey) {
| hash_map::Entry::Vacant(v) => {
v.insert(pdu.event_id.clone());
v.insert(pdu.event_id().to_owned());
},
| hash_map::Entry::Occupied(_) => {
return Err!(Database(
@ -73,7 +76,7 @@ pub(super) async fn fetch_state(
.get_shortstatekey(&StateEventType::RoomCreate, "")
.await?;
if state.get(&create_shortstatekey) != Some(&create_event.event_id) {
if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(create_event.event_id()) {
return Err!(Database("Incoming event refers to wrong create event."));
}

View file

@ -4,7 +4,7 @@ use std::{
};
use conduwuit::{
Err, Result, debug, debug::INFO_SPAN_LEVEL, defer, err, implement, utils::stream::IterStream,
Err, Event, Result, debug::INFO_SPAN_LEVEL, defer, err, implement, utils::stream::IterStream,
warn,
};
use futures::{
@ -12,6 +12,7 @@ use futures::{
future::{OptionFuture, try_join5},
};
use ruma::{CanonicalJsonValue, EventId, RoomId, ServerName, UserId, events::StateEventType};
use tracing::debug;
use crate::rooms::timeline::RawPduId;
@ -121,22 +122,16 @@ pub async fn handle_incoming_pdu<'a>(
.timeline
.first_pdu_in_room(room_id)
.await?
.origin_server_ts;
.origin_server_ts();
if incoming_pdu.origin_server_ts < first_ts_in_room {
if incoming_pdu.origin_server_ts() < first_ts_in_room {
return Ok(None);
}
// 9. Fetch any missing prev events doing all checks listed here starting at 1.
// These are timeline events
let (sorted_prev_events, mut eventid_info) = self
.fetch_prev(
origin,
create_event,
room_id,
first_ts_in_room,
incoming_pdu.prev_events.clone(),
)
.fetch_prev(origin, create_event, room_id, first_ts_in_room, incoming_pdu.prev_events())
.await?;
debug!(

View file

@ -1,27 +1,29 @@
use std::collections::{BTreeMap, HashMap, hash_map};
use conduwuit::{
Err, Error, PduEvent, Result, debug, debug_info, err, implement, state_res, trace, warn,
Err, Event, PduEvent, Result, debug, debug_info, err, implement, state_res, trace, warn,
};
use futures::future::ready;
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName,
api::client::error::ErrorKind, events::StateEventType,
CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName, events::StateEventType,
};
use super::{check_room_id, get_room_version_id, to_room_version};
#[implement(super::Service)]
#[allow(clippy::too_many_arguments)]
pub(super) async fn handle_outlier_pdu<'a>(
pub(super) async fn handle_outlier_pdu<'a, Pdu>(
&self,
origin: &'a ServerName,
create_event: &'a PduEvent,
create_event: &'a Pdu,
event_id: &'a EventId,
room_id: &'a RoomId,
mut value: CanonicalJsonObject,
auth_events_known: bool,
) -> Result<(PduEvent, BTreeMap<String, CanonicalJsonValue>)> {
) -> Result<(PduEvent, BTreeMap<String, CanonicalJsonValue>)>
where
Pdu: Event + Send + Sync,
{
// 1. Remove unsigned field
value.remove("unsigned");
@ -30,7 +32,7 @@ pub(super) async fn handle_outlier_pdu<'a>(
// 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match
let room_version_id = get_room_version_id(create_event)?;
let mut val = match self
let mut incoming_pdu = match self
.services
.server_keys
.verify_event(&value, Some(&room_version_id))
@ -62,13 +64,15 @@ pub(super) async fn handle_outlier_pdu<'a>(
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
let incoming_pdu = serde_json::from_value::<PduEvent>(
serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"),
incoming_pdu
.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
let pdu_event = serde_json::from_value::<PduEvent>(
serde_json::to_value(&incoming_pdu).expect("CanonicalJsonObj is a valid JsonValue"),
)
.map_err(|e| err!(Request(BadJson(debug_warn!("Event is not a valid PDU: {e}")))))?;
check_room_id(room_id, &incoming_pdu)?;
check_room_id(room_id, &pdu_event)?;
if !auth_events_known {
// 4. fetch any missing auth events doing all checks listed here starting at 1.
@ -79,7 +83,7 @@ pub(super) async fn handle_outlier_pdu<'a>(
debug!("Fetching auth events");
Box::pin(self.fetch_and_handle_outliers(
origin,
&incoming_pdu.auth_events,
pdu_event.auth_events(),
create_event,
room_id,
))
@ -90,8 +94,8 @@ pub(super) async fn handle_outlier_pdu<'a>(
// auth events
debug!("Checking based on auth events");
// Build map of auth events
let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len());
for id in &incoming_pdu.auth_events {
let mut auth_events = HashMap::with_capacity(pdu_event.auth_events().count());
for id in pdu_event.auth_events() {
let Ok(auth_event) = self.services.timeline.get_pdu(id).await else {
warn!("Could not find auth event {id}");
continue;
@ -110,10 +114,9 @@ pub(super) async fn handle_outlier_pdu<'a>(
v.insert(auth_event);
},
| hash_map::Entry::Occupied(_) => {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
return Err!(Request(InvalidParam(
"Auth event's type and state_key combination exists multiple times.",
));
)));
},
}
}
@ -128,12 +131,12 @@ pub(super) async fn handle_outlier_pdu<'a>(
let state_fetch = |ty: &StateEventType, sk: &str| {
let key = (ty.to_owned(), sk.into());
ready(auth_events.get(&key))
ready(auth_events.get(&key).map(ToOwned::to_owned))
};
let auth_check = state_res::event_auth::auth_check(
&to_room_version(&room_version_id),
&incoming_pdu,
&pdu_event,
None, // TODO: third party invite
state_fetch,
)
@ -149,9 +152,9 @@ pub(super) async fn handle_outlier_pdu<'a>(
// 7. Persist the event as an outlier.
self.services
.outlier
.add_pdu_outlier(&incoming_pdu.event_id, &val);
.add_pdu_outlier(pdu_event.event_id(), &incoming_pdu);
trace!("Added pdu as outlier.");
Ok((incoming_pdu, val))
Ok((pdu_event, incoming_pdu))
}

View file

@ -1,10 +1,11 @@
use std::{collections::BTreeMap, time::Instant};
use conduwuit::{
Err, PduEvent, Result, debug, debug::INFO_SPAN_LEVEL, defer, implement,
Err, Event, PduEvent, Result, debug::INFO_SPAN_LEVEL, defer, implement,
utils::continue_exponential_backoff_secs,
};
use ruma::{CanonicalJsonValue, EventId, RoomId, ServerName, UInt};
use ruma::{CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName};
use tracing::debug;
#[implement(super::Service)]
#[allow(clippy::type_complexity)]
@ -15,16 +16,19 @@ use ruma::{CanonicalJsonValue, EventId, RoomId, ServerName, UInt};
skip_all,
fields(%prev_id),
)]
pub(super) async fn handle_prev_pdu<'a>(
pub(super) async fn handle_prev_pdu<'a, Pdu>(
&self,
origin: &'a ServerName,
event_id: &'a EventId,
room_id: &'a RoomId,
eventid_info: Option<(PduEvent, BTreeMap<String, CanonicalJsonValue>)>,
create_event: &'a PduEvent,
first_ts_in_room: UInt,
create_event: &'a Pdu,
first_ts_in_room: MilliSecondsSinceUnixEpoch,
prev_id: &'a EventId,
) -> Result {
) -> Result
where
Pdu: Event + Send + Sync,
{
// Check for disabled again because it might have changed
if self.services.metadata.is_disabled(room_id).await {
return Err!(Request(Forbidden(debug_warn!(
@ -59,7 +63,7 @@ pub(super) async fn handle_prev_pdu<'a>(
};
// Skip old events
if pdu.origin_server_ts < first_ts_in_room {
if pdu.origin_server_ts() < first_ts_in_room {
return Ok(());
}

View file

@ -19,7 +19,7 @@ use std::{
};
use async_trait::async_trait;
use conduwuit::{Err, PduEvent, Result, RoomVersion, Server, utils::MutexMap};
use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, utils::MutexMap};
use ruma::{
OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
events::room::create::RoomCreateEventContent,
@ -105,11 +105,11 @@ impl Service {
}
}
fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result {
if pdu.room_id != room_id {
fn check_room_id<Pdu: Event>(room_id: &RoomId, pdu: &Pdu) -> Result {
if pdu.room_id() != room_id {
return Err!(Request(InvalidParam(error!(
pdu_event_id = ?pdu.event_id,
pdu_room_id = ?pdu.room_id,
pdu_event_id = ?pdu.event_id(),
pdu_room_id = ?pdu.room_id(),
?room_id,
"Found event from room in room",
))));
@ -118,7 +118,7 @@ fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result {
Ok(())
}
fn get_room_version_id(create_event: &PduEvent) -> Result<RoomVersionId> {
fn get_room_version_id<Pdu: Event>(create_event: &Pdu) -> Result<RoomVersionId> {
let content: RoomCreateEventContent = create_event.get_content()?;
let room_version = content.room_version;

View file

@ -1,4 +1,6 @@
use conduwuit::{Result, err, implement, pdu::gen_event_id_canonical_json, result::FlatOk};
use conduwuit::{
Result, err, implement, matrix::event::gen_event_id_canonical_json, result::FlatOk,
};
use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId};
use serde_json::value::RawValue as RawJsonValue;

View file

@ -6,7 +6,7 @@ use std::{
use conduwuit::{
Result, debug, err, implement,
matrix::{PduEvent, StateMap},
matrix::{Event, StateMap},
trace,
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
};
@ -19,11 +19,18 @@ use crate::rooms::short::ShortStateHash;
#[implement(super::Service)]
// request and build the state from a known point and resolve if > 1 prev_event
#[tracing::instrument(name = "state", level = "debug", skip_all)]
pub(super) async fn state_at_incoming_degree_one(
pub(super) async fn state_at_incoming_degree_one<Pdu>(
&self,
incoming_pdu: &PduEvent,
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
let prev_event = &incoming_pdu.prev_events[0];
incoming_pdu: &Pdu,
) -> Result<Option<HashMap<u64, OwnedEventId>>>
where
Pdu: Event + Send + Sync,
{
let prev_event = incoming_pdu
.prev_events()
.next()
.expect("at least one prev_event");
let Ok(prev_event_sstatehash) = self
.services
.state_accessor
@ -55,7 +62,7 @@ pub(super) async fn state_at_incoming_degree_one(
.get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)
.await;
state.insert(shortstatekey, prev_event.clone());
state.insert(shortstatekey, prev_event.to_owned());
// Now it's the state after the pdu
}
@ -66,16 +73,18 @@ pub(super) async fn state_at_incoming_degree_one(
#[implement(super::Service)]
#[tracing::instrument(name = "state", level = "debug", skip_all)]
pub(super) async fn state_at_incoming_resolved(
pub(super) async fn state_at_incoming_resolved<Pdu>(
&self,
incoming_pdu: &PduEvent,
incoming_pdu: &Pdu,
room_id: &RoomId,
room_version_id: &RoomVersionId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
) -> Result<Option<HashMap<u64, OwnedEventId>>>
where
Pdu: Event + Send + Sync,
{
trace!("Calculating extremity statehashes...");
let Ok(extremity_sstatehashes) = incoming_pdu
.prev_events
.iter()
.prev_events()
.try_stream()
.broad_and_then(|prev_eventid| {
self.services
@ -133,12 +142,15 @@ pub(super) async fn state_at_incoming_resolved(
}
#[implement(super::Service)]
async fn state_at_incoming_fork(
async fn state_at_incoming_fork<Pdu>(
&self,
room_id: &RoomId,
sstatehash: ShortStateHash,
prev_event: PduEvent,
) -> Result<(StateMap<OwnedEventId>, HashSet<OwnedEventId>)> {
prev_event: Pdu,
) -> Result<(StateMap<OwnedEventId>, HashSet<OwnedEventId>)>
where
Pdu: Event,
{
let mut leaf_state: HashMap<_, _> = self
.services
.state_accessor
@ -146,15 +158,15 @@ async fn state_at_incoming_fork(
.collect()
.await;
if let Some(state_key) = &prev_event.state_key {
if let Some(state_key) = prev_event.state_key() {
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)
.get_or_create_shortstatekey(&prev_event.kind().to_string().into(), state_key)
.await;
let event_id = &prev_event.event_id;
leaf_state.insert(shortstatekey, event_id.clone());
let event_id = prev_event.event_id();
leaf_state.insert(shortstatekey, event_id.to_owned());
// Now it's the state after the pdu
}

View file

@ -1,8 +1,8 @@
use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant};
use conduwuit::{
Err, Event, Result, debug, debug_info, err, implement, info,
matrix::{EventTypeExt, PduEvent, StateKey, state_res},
Err, Result, debug, debug_info, err, implement, info, is_equal_to,
matrix::{Event, EventTypeExt, PduEvent, StateKey, state_res},
trace,
utils::stream::{BroadbandExt, ReadyExt},
warn,
@ -17,19 +17,22 @@ use crate::rooms::{
};
#[implement(super::Service)]
pub(super) async fn upgrade_outlier_to_timeline_pdu(
pub(super) async fn upgrade_outlier_to_timeline_pdu<Pdu>(
&self,
incoming_pdu: PduEvent,
val: BTreeMap<String, CanonicalJsonValue>,
create_event: &PduEvent,
create_event: &Pdu,
origin: &ServerName,
room_id: &RoomId,
) -> Result<Option<RawPduId>> {
) -> Result<Option<RawPduId>>
where
Pdu: Event + Send + Sync,
{
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
.services
.timeline
.get_pdu_id(&incoming_pdu.event_id)
.get_pdu_id(incoming_pdu.event_id())
.await
{
return Ok(Some(pduid));
@ -38,7 +41,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
if self
.services
.pdu_metadata
.is_event_soft_failed(&incoming_pdu.event_id)
.is_event_soft_failed(incoming_pdu.event_id())
.await
{
return Err!(Request(InvalidParam("Event has been soft failed")));
@ -52,8 +55,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
// backwards extremities doing all the checks in this list starting at 1.
// These are not timeline events.
debug!("Resolving state at event {}", incoming_pdu.event_id);
let mut state_at_incoming_event = if incoming_pdu.prev_events.len() == 1 {
debug!("Resolving state at event");
let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
self.state_at_incoming_degree_one(&incoming_pdu).await?
} else {
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id)
@ -62,12 +65,13 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
if state_at_incoming_event.is_none() {
state_at_incoming_event = self
.fetch_state(origin, create_event, room_id, &incoming_pdu.event_id)
.fetch_state(origin, create_event, room_id, incoming_pdu.event_id())
.await?;
}
let state_at_incoming_event =
state_at_incoming_event.expect("we always set this to some above");
let room_version = to_room_version(&room_version_id);
debug!("Performing auth check to upgrade {}", incoming_pdu.event_id);
@ -100,16 +104,16 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
.state
.get_auth_events(
room_id,
&incoming_pdu.kind,
&incoming_pdu.sender,
incoming_pdu.state_key.as_deref(),
&incoming_pdu.content,
incoming_pdu.kind(),
incoming_pdu.sender(),
incoming_pdu.state_key(),
incoming_pdu.content(),
)
.await?;
let state_fetch = |k: &StateEventType, s: &str| {
let key = k.with_state_key(s);
ready(auth_events.get(&key).cloned())
ready(auth_events.get(&key).map(ToOwned::to_owned))
};
debug!("running auth check on {} with claimed state auth", incoming_pdu.event_id);
@ -131,7 +135,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
!self
.services
.state_accessor
.user_can_redact(&redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
.user_can_redact(&redact_id, incoming_pdu.sender(), incoming_pdu.room_id(), true)
.await?,
};
@ -151,7 +155,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
.map(ToOwned::to_owned)
.ready_filter(|event_id| {
// Remove any that are referenced by this incoming event's prev_events
!incoming_pdu.prev_events.contains(event_id)
!incoming_pdu.prev_events().any(is_equal_to!(event_id))
})
.broad_filter_map(|event_id| async move {
// Only keep those extremities were not referenced yet
@ -168,7 +172,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
debug!(
"Retained {} extremities checked against {} prev_events",
extremities.len(),
incoming_pdu.prev_events.len()
incoming_pdu.prev_events().count()
);
let state_ids_compressed: Arc<CompressedState> = self
@ -183,20 +187,20 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
.map(Arc::new)
.await;
if incoming_pdu.state_key.is_some() {
if incoming_pdu.state_key().is_some() {
debug!("Event is a state-event. Deriving new room state");
// We also add state after incoming event to the fork states
let mut state_after = state_at_incoming_event.clone();
if let Some(state_key) = &incoming_pdu.state_key {
if let Some(state_key) = incoming_pdu.state_key() {
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)
.get_or_create_shortstatekey(&incoming_pdu.kind().to_string().into(), state_key)
.await;
let event_id = &incoming_pdu.event_id;
state_after.insert(shortstatekey, event_id.clone());
let event_id = incoming_pdu.event_id();
state_after.insert(shortstatekey, event_id.to_owned());
}
let new_room_state = self
@ -254,9 +258,9 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
// Soft fail, we keep the event as an outlier but don't add it to the timeline
self.services
.pdu_metadata
.mark_event_soft_failed(&incoming_pdu.event_id);
.mark_event_soft_failed(incoming_pdu.event_id());
warn!("Event was soft failed: {incoming_pdu:?}");
warn!("Event was soft failed: {:?}", incoming_pdu.event_id());
return Err!(Request(InvalidParam("Event has been soft failed")));
}
@ -267,7 +271,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
let extremities = extremities
.iter()
.map(Borrow::borrow)
.chain(once(incoming_pdu.event_id.borrow()));
.chain(once(incoming_pdu.event_id()));
let pdu_id = self
.services

View file

@ -27,7 +27,7 @@ pub trait Options: Send + Sync {
#[derive(Clone, Debug)]
pub struct Context<'a> {
pub user_id: &'a UserId,
pub device_id: &'a DeviceId,
pub device_id: Option<&'a DeviceId>,
pub room_id: &'a RoomId,
pub token: Option<u64>,
pub options: Option<&'a LazyLoadOptions>,
@ -40,7 +40,7 @@ pub enum Status {
}
pub type Witness = HashSet<OwnedUserId>;
type Key<'a> = (&'a UserId, &'a DeviceId, &'a RoomId, &'a UserId);
type Key<'a> = (&'a UserId, Option<&'a DeviceId>, &'a RoomId, &'a UserId);
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use conduwuit::{Result, implement, matrix::pdu::PduEvent};
use conduwuit_database::{Deserialized, Json, Map};
use conduwuit::{Result, implement, matrix::PduEvent};
use database::{Deserialized, Json, Map};
use ruma::{CanonicalJsonObject, EventId};
pub struct Service {

View file

@ -1,8 +1,8 @@
use std::{mem::size_of, sync::Arc};
use conduwuit::{
PduCount, PduEvent,
arrayvec::ArrayVec,
matrix::{Event, PduCount},
result::LogErr,
utils::{
ReadyExt,
@ -33,8 +33,6 @@ struct Services {
timeline: Dep<rooms::timeline::Service>,
}
pub(super) type PdusIterItem = (PduCount, PduEvent);
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
@ -62,7 +60,7 @@ impl Data {
target: ShortEventId,
from: PduCount,
dir: Direction,
) -> impl Stream<Item = PdusIterItem> + Send + '_ {
) -> impl Stream<Item = (PduCount, impl Event)> + Send + '_ {
let mut current = ArrayVec::<u8, 16>::new();
current.extend(target.to_be_bytes());
current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes());
@ -80,8 +78,8 @@ impl Data {
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
if pdu.sender() != user_id {
pdu.as_mut_pdu().remove_transaction_id().log_err().ok();
}
Some((shorteventid, pdu))

View file

@ -1,11 +1,14 @@
mod data;
use std::sync::Arc;
use conduwuit::{PduCount, Result};
use conduwuit::{
Result,
matrix::{Event, PduCount},
};
use futures::{StreamExt, future::try_join};
use ruma::{EventId, RoomId, UserId, api::Direction};
use self::data::{Data, PdusIterItem};
use self::data::Data;
use crate::{Dep, rooms};
pub struct Service {
@ -44,16 +47,16 @@ impl Service {
}
#[allow(clippy::too_many_arguments)]
pub async fn get_relations(
&self,
user_id: &UserId,
room_id: &RoomId,
target: &EventId,
pub async fn get_relations<'a>(
&'a self,
user_id: &'a UserId,
room_id: &'a RoomId,
target: &'a EventId,
from: PduCount,
limit: usize,
max_depth: u8,
dir: Direction,
) -> Vec<PdusIterItem> {
) -> Vec<(PduCount, impl Event)> {
let room_id = self.services.short.get_shortroomid(room_id);
let target = self.services.timeline.get_pdu_count(target);

View file

@ -4,7 +4,10 @@ use std::{collections::BTreeMap, sync::Arc};
use conduwuit::{
Result, debug, err,
matrix::pdu::{PduCount, PduId, RawPduId},
matrix::{
Event,
pdu::{PduCount, PduId, RawPduId},
},
warn,
};
use futures::{Stream, TryFutureExt, try_join};
@ -74,14 +77,13 @@ impl Service {
let shortroomid = self.services.short.get_shortroomid(room_id).map_err(|e| {
err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}")))
});
let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?;
let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?;
let shorteventid = PduCount::Normal(pdu_count);
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?;
let event_id: OwnedEventId = pdu.event_id;
let event_id: OwnedEventId = pdu.event_id().to_owned();
let user_id: OwnedUserId = user_id.to_owned();
let content: BTreeMap<OwnedEventId, Receipts> = BTreeMap::from_iter([(
event_id,

View file

@ -1,9 +1,10 @@
use std::sync::Arc;
use conduwuit::{
PduCount, PduEvent, Result,
PduCount, Result,
arrayvec::ArrayVec,
implement,
matrix::event::{Event, Matches},
utils::{
ArrayVecExt, IterStream, ReadyExt, set,
stream::{TryIgnore, WidebandExt},
@ -103,9 +104,10 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b
pub async fn search_pdus<'a>(
&'a self,
query: &'a RoomQuery<'a>,
) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> {
) -> Result<(usize, impl Stream<Item = impl Event + use<>> + Send + '_)> {
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
let filter = &query.criteria.filter;
let count = pdu_ids.len();
let pdus = pdu_ids
.into_iter()
@ -118,11 +120,11 @@ pub async fn search_pdus<'a>(
.ok()
})
.ready_filter(|pdu| !pdu.is_redacted())
.ready_filter(|pdu| pdu.matches(&query.criteria.filter))
.ready_filter(move |pdu| filter.matches(pdu))
.wide_filter_map(move |pdu| async move {
self.services
.state_accessor
.user_can_see_event(query.user_id?, &pdu.room_id, &pdu.event_id)
.user_can_see_event(query.user_id?, pdu.room_id(), pdu.event_id())
.await
.then_some(pdu)
})

View file

@ -5,8 +5,8 @@ mod tests;
use std::{fmt::Write, sync::Arc};
use async_trait::async_trait;
use conduwuit::{
Err, Error, PduEvent, Result, implement,
use conduwuit_core::{
Err, Error, Event, PduEvent, Result, implement,
utils::{
IterStream,
future::{BoolExt, TryExtExt},
@ -142,7 +142,7 @@ pub async fn get_summary_and_children_local(
let children_pdus: Vec<_> = self
.get_space_child_events(current_room)
.map(PduEvent::into_stripped_spacechild_state_event)
.map(Event::into_format)
.collect()
.await;
@ -511,7 +511,7 @@ async fn cache_insert(
room_id: room_id.clone(),
children_state: self
.get_space_child_events(&room_id)
.map(PduEvent::into_stripped_spacechild_state_event)
.map(Event::into_format)
.collect()
.await,
encryption,

View file

@ -1,8 +1,8 @@
use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
use async_trait::async_trait;
use conduwuit::{
PduEvent, Result, err,
use conduwuit_core::{
Event, PduEvent, Result, err,
result::FlatOk,
state_res::{self, StateMap},
utils::{
@ -11,7 +11,7 @@ use conduwuit::{
},
warn,
};
use database::{Deserialized, Ignore, Interfix, Map};
use conduwuit_database::{Deserialized, Ignore, Interfix, Map};
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all, pin_mut,
};
@ -319,30 +319,34 @@ impl Service {
}
#[tracing::instrument(skip_all, level = "debug")]
pub async fn summary_stripped(&self, event: &PduEvent) -> Vec<Raw<AnyStrippedStateEvent>> {
pub async fn summary_stripped<'a, E>(&self, event: &'a E) -> Vec<Raw<AnyStrippedStateEvent>>
where
E: Event + Send + Sync,
&'a E: Event + Send,
{
let cells = [
(&StateEventType::RoomCreate, ""),
(&StateEventType::RoomJoinRules, ""),
(&StateEventType::RoomCanonicalAlias, ""),
(&StateEventType::RoomName, ""),
(&StateEventType::RoomAvatar, ""),
(&StateEventType::RoomMember, event.sender.as_str()), // Add recommended events
(&StateEventType::RoomMember, event.sender().as_str()), // Add recommended events
(&StateEventType::RoomEncryption, ""),
(&StateEventType::RoomTopic, ""),
];
let fetches = cells.iter().map(|(event_type, state_key)| {
let fetches = cells.into_iter().map(|(event_type, state_key)| {
self.services
.state_accessor
.room_state_get(&event.room_id, event_type, state_key)
.room_state_get(event.room_id(), event_type, state_key)
});
join_all(fetches)
.await
.into_iter()
.filter_map(Result::ok)
.map(PduEvent::into_stripped_state_event)
.chain(once(event.to_stripped_state_event()))
.map(Event::into_format)
.chain(once(event.to_format()))
.collect()
}
@ -352,8 +356,8 @@ impl Service {
&self,
room_id: &RoomId,
shortstatehash: u64,
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &RoomMutexGuard,
) {
const BUFSIZE: usize = size_of::<u64>();

View file

@ -2,7 +2,7 @@ use std::borrow::Borrow;
use conduwuit::{
Result, err, implement,
matrix::{PduEvent, StateKey},
matrix::{Event, StateKey},
};
use futures::{Stream, StreamExt, TryFutureExt};
use ruma::{EventId, RoomId, events::StateEventType};
@ -30,7 +30,7 @@ where
pub fn room_state_full<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = Result<((StateEventType, StateKey), PduEvent)>> + Send + 'a {
) -> impl Stream<Item = Result<((StateEventType, StateKey), impl Event)>> + Send + 'a {
self.services
.state
.get_room_shortstatehash(room_id)
@ -45,7 +45,7 @@ pub fn room_state_full<'a>(
pub fn room_state_full_pdus<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = Result<PduEvent>> + Send + 'a {
) -> impl Stream<Item = Result<impl Event>> + Send + 'a {
self.services
.state
.get_room_shortstatehash(room_id)
@ -84,10 +84,29 @@ pub async fn room_state_get(
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
) -> Result<impl Event> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
.await
}
/// Returns all state keys for the given `room_id` and `event_type`.
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_keys(
&self,
room_id: &RoomId,
event_type: &StateEventType,
) -> Result<Vec<String>> {
let shortstatehash = self.services.state.get_room_shortstatehash(room_id).await?;
let state_keys: Vec<String> = self
.state_keys(shortstatehash, event_type)
.map(|state_key| state_key.to_string())
.collect()
.await;
Ok(state_keys)
}

View file

@ -2,14 +2,14 @@ use std::{borrow::Borrow, ops::Deref, sync::Arc};
use conduwuit::{
Result, at, err, implement,
matrix::{PduEvent, StateKey},
matrix::{Event, StateKey},
pair_of,
utils::{
result::FlatOk,
stream::{BroadbandExt, IterStream, ReadyExt, TryIgnore},
},
};
use conduwuit_database::Deserialized;
use database::Deserialized;
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, future::try_join, pin_mut};
use ruma::{
EventId, OwnedEventId, UserId,
@ -125,11 +125,9 @@ pub async fn state_get(
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
) -> Result<impl Event> {
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await
})
.and_then(async |event_id: OwnedEventId| self.services.timeline.get_pdu(&event_id).await)
.await
}
@ -316,18 +314,16 @@ pub fn state_added(
pub fn state_full(
&self,
shortstatehash: ShortStateHash,
) -> impl Stream<Item = ((StateEventType, StateKey), PduEvent)> + Send + '_ {
) -> impl Stream<Item = ((StateEventType, StateKey), impl Event)> + Send + '_ {
self.state_full_pdus(shortstatehash)
.ready_filter_map(|pdu| {
Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu))
})
.ready_filter_map(|pdu| Some(((pdu.kind().clone().into(), pdu.state_key()?.into()), pdu)))
}
#[implement(super::Service)]
pub fn state_full_pdus(
&self,
shortstatehash: ShortStateHash,
) -> impl Stream<Item = PduEvent> + Send + '_ {
) -> impl Stream<Item = impl Event> + Send + '_ {
let short_ids = self
.state_full_shortids(shortstatehash)
.ignore_err()

View file

@ -1,4 +1,4 @@
use conduwuit::{Err, Result, implement, pdu::PduBuilder};
use conduwuit::{Err, Result, implement, matrix::Event, pdu::PduBuilder};
use ruma::{
EventId, RoomId, UserId,
events::{
@ -29,14 +29,14 @@ pub async fn user_can_redact(
if redacting_event
.as_ref()
.is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomCreate)
.is_ok_and(|pdu| *pdu.kind() == TimelineEventType::RoomCreate)
{
return Err!(Request(Forbidden("Redacting m.room.create is not safe, forbidding.")));
}
if redacting_event
.as_ref()
.is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomServerAcl)
.is_ok_and(|pdu| *pdu.kind() == TimelineEventType::RoomServerAcl)
{
return Err!(Request(Forbidden(
"Redacting m.room.server_acl will result in the room being inaccessible for \
@ -59,9 +59,9 @@ pub async fn user_can_redact(
&& match redacting_event {
| Ok(redacting_event) =>
if federation {
redacting_event.sender.server_name() == sender.server_name()
redacting_event.sender().server_name() == sender.server_name()
} else {
redacting_event.sender == sender
redacting_event.sender() == sender
},
| _ => false,
})
@ -72,10 +72,10 @@ pub async fn user_can_redact(
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await
{
| Ok(room_create) => Ok(room_create.sender == sender
| Ok(room_create) => Ok(room_create.sender() == sender
|| redacting_event
.as_ref()
.is_ok_and(|redacting_event| redacting_event.sender == sender)),
.is_ok_and(|redacting_event| redacting_event.sender() == sender)),
| _ => Err!(Database(
"No m.room.power_levels or m.room.create events in database for room"
)),

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,369 @@
use std::collections::HashSet;
use conduwuit::{Result, implement, is_not_empty, utils::ReadyExt, warn};
use database::{Json, serialize_key};
use futures::StreamExt;
use ruma::{
OwnedServerName, RoomId, UserId,
events::{
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType, StateEventType,
direct::DirectEvent,
room::{
create::RoomCreateEventContent,
member::{MembershipState, RoomMemberEventContent},
},
},
serde::Raw,
};
/// Update current membership data.
#[implement(super::Service)]
#[tracing::instrument(
level = "debug",
skip_all,
fields(
%room_id,
%user_id,
%sender,
?membership_event,
),
)]
#[allow(clippy::too_many_arguments)]
pub async fn update_membership(
&self,
room_id: &RoomId,
user_id: &UserId,
membership_event: RoomMemberEventContent,
sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool,
) -> Result {
let membership = membership_event.membership;
// Keep track what remote users exist by adding them as "deactivated" users
//
// TODO: use futures to update remote profiles without blocking the membership
// update
#[allow(clippy::collapsible_if)]
if !self.services.globals.user_is_local(user_id) {
if !self.services.users.exists(user_id).await {
self.services.users.create(user_id, None)?;
}
}
match &membership {
| MembershipState::Join => {
// Check if the user never joined this room
if !self.once_joined(user_id, room_id).await {
// Add the user ID to the join list then
self.mark_as_once_joined(user_id, room_id);
// Check if the room has a predecessor
if let Ok(Some(predecessor)) = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.map(|content: RoomCreateEventContent| content.predecessor)
{
// Copy old tags to new room
if let Ok(tag_event) = self
.services
.account_data
.get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
.await
{
self.services
.account_data
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&tag_event,
)
.await
.ok();
}
// Copy direct chat flag
if let Ok(mut direct_event) = self
.services
.account_data
.get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
.await
{
let mut room_ids_updated = false;
for room_ids in direct_event.content.0.values_mut() {
if room_ids.iter().any(|r| r == &predecessor.room_id) {
room_ids.push(room_id.to_owned());
room_ids_updated = true;
}
}
if room_ids_updated {
self.services
.account_data
.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event)
.expect("to json always works"),
)
.await?;
}
}
}
}
self.mark_as_joined(user_id, room_id);
},
| MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
if self.services.users.user_is_ignored(sender, user_id).await {
return Ok(());
}
self.mark_as_invited(user_id, room_id, last_state, invite_via)
.await;
},
| MembershipState::Leave | MembershipState::Ban => {
self.mark_as_left(user_id, room_id);
if self.services.globals.user_is_local(user_id)
&& (self.services.config.forget_forced_upon_leave
|| self.services.metadata.is_banned(room_id).await
|| self.services.metadata.is_disabled(room_id).await)
{
self.forget(room_id, user_id);
}
},
| _ => {},
}
if update_joined_count {
self.update_joined_count(room_id).await;
}
Ok(())
}
#[implement(super::Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn update_joined_count(&self, room_id: &RoomId) {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut knockedcount = 0_u64;
let mut joined_servers = HashSet::new();
self.room_members(room_id)
.ready_for_each(|joined| {
joined_servers.insert(joined.server_name().to_owned());
joinedcount = joinedcount.saturating_add(1);
})
.await;
invitedcount = invitedcount.saturating_add(
self.room_members_invited(room_id)
.count()
.await
.try_into()
.unwrap_or(0),
);
knockedcount = knockedcount.saturating_add(
self.room_members_knocked(room_id)
.count()
.await
.try_into()
.unwrap_or(0),
);
self.db.roomid_joinedcount.raw_put(room_id, joinedcount);
self.db.roomid_invitedcount.raw_put(room_id, invitedcount);
self.db
.roomuserid_knockedcount
.raw_put(room_id, knockedcount);
self.room_servers(room_id)
.ready_for_each(|old_joined_server| {
if joined_servers.remove(old_joined_server) {
return;
}
// Server not in room anymore
let roomserver_id = (room_id, old_joined_server);
let serverroom_id = (old_joined_server, room_id);
self.db.roomserverids.del(roomserver_id);
self.db.serverroomids.del(serverroom_id);
})
.await;
// Now only new servers are in joined_servers anymore
for server in &joined_servers {
let roomserver_id = (room_id, server);
let serverroom_id = (server, room_id);
self.db.roomserverids.put_raw(roomserver_id, []);
self.db.serverroomids.put_raw(serverroom_id, []);
}
self.appservice_in_room_cache
.write()
.expect("locked")
.remove(room_id);
}
/// Direct DB function to directly mark a user as joined. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
self.db.userroomid_joined.insert(&userroom_id, []);
self.db.roomuserid_joined.insert(&roomuser_id, []);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
/// Direct DB function to directly mark a user as left. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
// (timo) TODO
let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
self.db
.userroomid_leftstate
.raw_put(&userroom_id, Json(leftstate));
self.db
.roomuserid_leftcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
/// Direct DB function to directly mark a user as knocked. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_knocked(
&self,
user_id: &UserId,
room_id: &RoomId,
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
self.db
.userroomid_knockedstate
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
self.db
.roomuserid_knockedcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
/// Makes a user forget a room.
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn forget(&self, room_id: &RoomId, user_id: &UserId) {
let userroom_id = (user_id, room_id);
let roomuser_id = (room_id, user_id);
self.db.userroomid_leftstate.del(userroom_id);
self.db.roomuserid_leftcount.del(roomuser_id);
}
#[implement(super::Service)]
#[tracing::instrument(level = "debug", skip(self))]
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
let key = (user_id, room_id);
self.db.roomuseroncejoinedids.put_raw(key, []);
}
#[implement(super::Service)]
#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))]
pub async fn mark_as_invited(
&self,
user_id: &UserId,
room_id: &RoomId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) {
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
self.db
.userroomid_invitestate
.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
self.db
.roomuserid_invitecount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
if let Some(servers) = invite_via.filter(is_not_empty!()) {
self.add_servers_invite_via(room_id, servers).await;
}
}

View file

@ -0,0 +1,92 @@
use conduwuit::{
Result, implement,
utils::{StreamTools, stream::TryIgnore},
warn,
};
use database::Ignore;
use futures::{Stream, StreamExt, stream::iter};
use itertools::Itertools;
use ruma::{
OwnedServerName, RoomId, ServerName,
events::{StateEventType, room::power_levels::RoomPowerLevelsEventContent},
int,
};
#[implement(super::Service)]
#[tracing::instrument(level = "debug", skip(self, servers))]
pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec<OwnedServerName>) {
let mut servers: Vec<_> = self
.servers_invite_via(room_id)
.map(ToOwned::to_owned)
.chain(iter(servers.into_iter()))
.collect()
.await;
servers.sort_unstable();
servers.dedup();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.db
.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers);
}
/// Gets up to five servers that are likely to be in the room in the
/// distant future.
///
/// See <https://spec.matrix.org/latest/appendices/#routing>
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.map(|content: RoomPowerLevelsEventContent| {
content
.users
.iter()
.max_by_key(|(_, power)| *power)
.and_then(|x| (x.1 >= &int!(50)).then_some(x))
.map(|(user, _power)| user.server_name().to_owned())
});
let mut servers: Vec<OwnedServerName> = self
.room_members(room_id)
.counts_by(|user| user.server_name().to_owned())
.await
.into_iter()
.sorted_by_key(|(_, users)| *users)
.map(|(server, _)| server)
.rev()
.take(5)
.collect();
if let Ok(Some(server)) = most_powerful_user_server {
servers.insert(0, server);
servers.truncate(5);
}
Ok(servers)
}
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &ServerName> + Send + 'a {
type KeyVal<'a> = (Ignore, Vec<&'a ServerName>);
self.db
.roomid_inviteviaservers
.stream_raw_prefix(room_id)
.ignore_err()
.map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server"))
}

View file

@ -9,7 +9,7 @@ use async_trait::async_trait;
use conduwuit::{
Result,
arrayvec::ArrayVec,
at, checked, err, expected, utils,
at, checked, err, expected, implement, utils,
utils::{bytes, math::usize_from_f64, stream::IterStream},
};
use database::Map;
@ -115,29 +115,30 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(name = "load", level = "debug", skip(self))]
pub async fn load_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) {
return Ok(r.clone());
}
let stack = self.new_shortstatehash_info(shortstatehash).await?;
self.cache_shortstatehash_info(shortstatehash, stack.clone())
.await?;
Ok(stack)
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
#[implement(Service)]
#[tracing::instrument(name = "load", level = "debug", skip(self))]
pub async fn load_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) {
return Ok(r.clone());
}
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(
let stack = self.new_shortstatehash_info(shortstatehash).await?;
self.cache_shortstatehash_info(shortstatehash, stack.clone())
.await?;
Ok(stack)
}
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
#[implement(Service)]
#[tracing::instrument(
name = "cache",
level = "debug",
skip_all,
@ -146,362 +147,365 @@ impl Service {
stack = stack.len(),
),
)]
async fn cache_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
stack: ShortStateInfoVec,
) -> Result {
self.stateinfo_cache.lock()?.insert(shortstatehash, stack);
async fn cache_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
stack: ShortStateInfoVec,
) -> Result {
self.stateinfo_cache.lock()?.insert(shortstatehash, stack);
Ok(())
}
Ok(())
}
async fn new_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
#[implement(Service)]
async fn new_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
let Some(parent) = parent else {
return Ok(vec![ShortStateInfo {
shortstatehash,
full_state: added.clone(),
added,
removed,
}]);
};
let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?;
let top = stack.last().expect("at least one frame");
let mut full_state = (*top.full_state).clone();
full_state.extend(added.iter().copied());
let removed = (*removed).clone();
for r in &removed {
full_state.remove(r);
}
stack.push(ShortStateInfo {
let Some(parent) = parent else {
return Ok(vec![ShortStateInfo {
shortstatehash,
full_state: added.clone(),
added,
removed: Arc::new(removed),
full_state: Arc::new(full_state),
});
removed,
}]);
};
Ok(stack)
let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?;
let top = stack.last().expect("at least one frame");
let mut full_state = (*top.full_state).clone();
full_state.extend(added.iter().copied());
let removed = (*removed).clone();
for r in &removed {
full_state.remove(r);
}
pub fn compress_state_events<'a, I>(
&'a self,
state: I,
) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
where
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + Send + 'a,
{
let event_ids = state.clone().map(at!(1));
stack.push(ShortStateInfo {
shortstatehash,
added,
removed: Arc::new(removed),
full_state: Arc::new(full_state),
});
let short_event_ids = self
.services
.short
.multi_get_or_create_shorteventid(event_ids);
Ok(stack)
}
state
.stream()
.map(at!(0))
.zip(short_event_ids)
.map(|(shortstatekey, shorteventid)| {
compress_state_event(*shortstatekey, shorteventid)
})
}
#[implement(Service)]
pub fn compress_state_events<'a, I>(
&'a self,
state: I,
) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
where
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + Send + 'a,
{
let event_ids = state.clone().map(at!(1));
pub async fn compress_state_event(
&self,
shortstatekey: ShortStateKey,
event_id: &EventId,
) -> CompressedStateEvent {
let shorteventid = self
.services
.short
.get_or_create_shorteventid(event_id)
.await;
let short_event_ids = self
.services
.short
.multi_get_or_create_shorteventid(event_ids);
compress_state_event(shortstatekey, shorteventid)
}
state
.stream()
.map(at!(0))
.zip(short_event_ids)
.map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid))
}
/// Creates a new shortstatehash that often is just a diff to an already
/// existing shortstatehash and therefore very efficient.
///
/// There are multiple layers of diffs. The bottom layer 0 always contains
/// the full state. Layer 1 contains diffs to states of layer 0, layer 2
/// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be
/// combined with layer n-1 to create a new diff on layer n-1 that's
/// based on layer n-2. If that layer is also too big, it will recursively
/// fix above layers too.
///
/// * `shortstatehash` - Shortstatehash of this state
/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid
/// * `statediffremoved` - Removed from base. Each vec is
/// shortstatekey+shorteventid
/// * `diff_to_sibling` - Approximately how much the diff grows each time
/// for this layer
/// * `parent_states` - A stack with info on shortstatehash, full state,
/// added diff and removed diff for each parent layer
pub fn save_state_from_diff(
&self,
shortstatehash: ShortStateHash,
statediffnew: Arc<CompressedState>,
statediffremoved: Arc<CompressedState>,
diff_to_sibling: usize,
mut parent_states: ParentStatesVec,
) -> Result {
let statediffnew_len = statediffnew.len();
let statediffremoved_len = statediffremoved.len();
let diffsum = checked!(statediffnew_len + statediffremoved_len)?;
#[implement(Service)]
pub async fn compress_state_event(
&self,
shortstatekey: ShortStateKey,
event_id: &EventId,
) -> CompressedStateEvent {
let shorteventid = self
.services
.short
.get_or_create_shorteventid(event_id)
.await;
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().expect("parent must have a state");
compress_state_event(shortstatekey, shorteventid)
}
let mut parent_new = (*parent.added).clone();
let mut parent_removed = (*parent.removed).clone();
for removed in statediffremoved.iter() {
if !parent_new.remove(removed) {
// It was not added in the parent and we removed it
parent_removed.insert(*removed);
}
// Else it was added in the parent and we removed it again. We
// can forget this change
}
for new in statediffnew.iter() {
if !parent_removed.remove(new) {
// It was not touched in the parent and we added it
parent_new.insert(*new);
}
// Else it was removed in the parent and we added it again. We
// can forget this change
}
self.save_state_from_diff(
shortstatehash,
Arc::new(parent_new),
Arc::new(parent_removed),
diffsum,
parent_states,
)?;
return Ok(());
}
if parent_states.is_empty() {
// There is no parent layer, create a new state
self.save_statediff(shortstatehash, &StateDiff {
parent: None,
added: statediffnew,
removed: statediffremoved,
});
return Ok(());
}
// Else we have two options.
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
/// Creates a new shortstatehash that often is just a diff to an already
/// existing shortstatehash and therefore very efficient.
///
/// There are multiple layers of diffs. The bottom layer 0 always contains
/// the full state. Layer 1 contains diffs to states of layer 0, layer 2
/// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be
/// combined with layer n-1 to create a new diff on layer n-1 that's
/// based on layer n-2. If that layer is also too big, it will recursively
/// fix above layers too.
///
/// * `shortstatehash` - Shortstatehash of this state
/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid
/// * `statediffremoved` - Removed from base. Each vec is
/// shortstatekey+shorteventid
/// * `diff_to_sibling` - Approximately how much the diff grows each time for
/// this layer
/// * `parent_states` - A stack with info on shortstatehash, full state, added
/// diff and removed diff for each parent layer
#[implement(Service)]
pub fn save_state_from_diff(
&self,
shortstatehash: ShortStateHash,
statediffnew: Arc<CompressedState>,
statediffremoved: Arc<CompressedState>,
diff_to_sibling: usize,
mut parent_states: ParentStatesVec,
) -> Result {
let statediffnew_len = statediffnew.len();
let statediffremoved_len = statediffremoved.len();
let diffsum = checked!(statediffnew_len + statediffremoved_len)?;
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().expect("parent must have a state");
let parent_added_len = parent.added.len();
let parent_removed_len = parent.removed.len();
let parent_diff = checked!(parent_added_len + parent_removed_len)?;
if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? {
// Diff too big, we replace above layer(s)
let mut parent_new = (*parent.added).clone();
let mut parent_removed = (*parent.removed).clone();
let mut parent_new = (*parent.added).clone();
let mut parent_removed = (*parent.removed).clone();
for removed in statediffremoved.iter() {
if !parent_new.remove(removed) {
// It was not added in the parent and we removed it
parent_removed.insert(*removed);
}
// Else it was added in the parent and we removed it again. We
// can forget this change
for removed in statediffremoved.iter() {
if !parent_new.remove(removed) {
// It was not added in the parent and we removed it
parent_removed.insert(*removed);
}
for new in statediffnew.iter() {
if !parent_removed.remove(new) {
// It was not touched in the parent and we added it
parent_new.insert(*new);
}
// Else it was removed in the parent and we added it again. We
// can forget this change
}
self.save_state_from_diff(
shortstatehash,
Arc::new(parent_new),
Arc::new(parent_removed),
diffsum,
parent_states,
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
self.save_statediff(shortstatehash, &StateDiff {
parent: Some(parent.shortstatehash),
added: statediffnew,
removed: statediffremoved,
});
// Else it was added in the parent and we removed it again. We
// can forget this change
}
Ok(())
for new in statediffnew.iter() {
if !parent_removed.remove(new) {
// It was not touched in the parent and we added it
parent_new.insert(*new);
}
// Else it was removed in the parent and we added it again. We
// can forget this change
}
self.save_state_from_diff(
shortstatehash,
Arc::new(parent_new),
Arc::new(parent_removed),
diffsum,
parent_states,
)?;
return Ok(());
}
/// Returns the new shortstatehash, and the state diff from the previous
/// room state
#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
pub async fn save_state(
&self,
room_id: &RoomId,
new_state_ids_compressed: Arc<CompressedState>,
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
.state
.get_room_shortstatehash(room_id)
.await
.ok();
let state_hash =
utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
let (new_shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)
.await;
if Some(new_shortstatehash) == previous_shortstatehash {
return Ok(HashSetCompressStateEvent {
shortstatehash: new_shortstatehash,
..Default::default()
});
}
let states_parents = if let Some(p) = previous_shortstatehash {
self.load_shortstatehash_info(p).await.unwrap_or_default()
} else {
ShortStateInfoVec::new()
};
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: CompressedState = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: CompressedState = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(CompressedState::new()))
};
if !already_existed {
self.save_state_from_diff(
new_shortstatehash,
statediffnew.clone(),
statediffremoved.clone(),
2, // every state change is 2 event changes on average
states_parents,
)?;
}
Ok(HashSetCompressStateEvent {
shortstatehash: new_shortstatehash,
if parent_states.is_empty() {
// There is no parent layer, create a new state
self.save_statediff(shortstatehash, &StateDiff {
parent: None,
added: statediffnew,
removed: statediffremoved,
})
});
return Ok(());
}
#[tracing::instrument(skip(self), level = "debug", name = "get")]
async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result<StateDiff> {
const BUFSIZE: usize = size_of::<ShortStateHash>();
const STRIDE: usize = size_of::<ShortStateHash>();
// Else we have two options.
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
let value = self
.db
.shortstatehash_statediff
.aqry::<BUFSIZE, _>(&shortstatehash)
.await
.map_err(|e| {
err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
})?;
let parent = parent_states.pop().expect("parent must have a state");
let parent_added_len = parent.added.len();
let parent_removed_len = parent.removed.len();
let parent_diff = checked!(parent_added_len + parent_removed_len)?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
.ok()
.take_if(|parent| *parent != 0);
if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? {
// Diff too big, we replace above layer(s)
let mut parent_new = (*parent.added).clone();
let mut parent_removed = (*parent.removed).clone();
debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride");
let _num_values = value.len() / STRIDE;
let mut add_mode = true;
let mut added = CompressedState::new();
let mut removed = CompressedState::new();
let mut i = STRIDE;
while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i = expected!(i + STRIDE);
continue;
for removed in statediffremoved.iter() {
if !parent_new.remove(removed) {
// It was not added in the parent and we removed it
parent_removed.insert(*removed);
}
if add_mode {
added.insert(v.try_into()?);
} else {
removed.insert(v.try_into()?);
}
i = expected!(i + 2 * STRIDE);
// Else it was added in the parent and we removed it again. We
// can forget this change
}
Ok(StateDiff {
parent,
added: Arc::new(added),
removed: Arc::new(removed),
})
for new in statediffnew.iter() {
if !parent_removed.remove(new) {
// It was not touched in the parent and we added it
parent_new.insert(*new);
}
// Else it was removed in the parent and we added it again. We
// can forget this change
}
self.save_state_from_diff(
shortstatehash,
Arc::new(parent_new),
Arc::new(parent_removed),
diffsum,
parent_states,
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
self.save_statediff(shortstatehash, &StateDiff {
parent: Some(parent.shortstatehash),
added: statediffnew,
removed: statediffremoved,
});
}
fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) {
let mut value = Vec::<u8>::with_capacity(
2_usize
.saturating_add(diff.added.len())
.saturating_add(diff.removed.len()),
);
Ok(())
}
let parent = diff.parent.unwrap_or(0_u64);
value.extend_from_slice(&parent.to_be_bytes());
/// Returns the new shortstatehash, and the state diff from the previous
/// room state
#[implement(Service)]
#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
pub async fn save_state(
&self,
room_id: &RoomId,
new_state_ids_compressed: Arc<CompressedState>,
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
.state
.get_room_shortstatehash(room_id)
.await
.ok();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
}
let state_hash =
utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]);
}
}
let (new_shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)
.await;
self.db
.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value);
if Some(new_shortstatehash) == previous_shortstatehash {
return Ok(HashSetCompressStateEvent {
shortstatehash: new_shortstatehash,
..Default::default()
});
}
let states_parents = if let Some(p) = previous_shortstatehash {
self.load_shortstatehash_info(p).await.unwrap_or_default()
} else {
ShortStateInfoVec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: CompressedState = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: CompressedState = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(CompressedState::new()))
};
if !already_existed {
self.save_state_from_diff(
new_shortstatehash,
statediffnew.clone(),
statediffremoved.clone(),
2, // every state change is 2 event changes on average
states_parents,
)?;
}
Ok(HashSetCompressStateEvent {
shortstatehash: new_shortstatehash,
added: statediffnew,
removed: statediffremoved,
})
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug", name = "get")]
async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result<StateDiff> {
const BUFSIZE: usize = size_of::<ShortStateHash>();
const STRIDE: usize = size_of::<ShortStateHash>();
let value = self
.db
.shortstatehash_statediff
.aqry::<BUFSIZE, _>(&shortstatehash)
.await
.map_err(|e| {
err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
})?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
.ok()
.take_if(|parent| *parent != 0);
debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride");
let _num_values = value.len() / STRIDE;
let mut add_mode = true;
let mut added = CompressedState::new();
let mut removed = CompressedState::new();
let mut i = STRIDE;
while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i = expected!(i + STRIDE);
continue;
}
if add_mode {
added.insert(v.try_into()?);
} else {
removed.insert(v.try_into()?);
}
i = expected!(i + 2 * STRIDE);
}
Ok(StateDiff {
parent,
added: Arc::new(added),
removed: Arc::new(removed),
})
}
#[implement(Service)]
fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) {
let mut value = Vec::<u8>::with_capacity(
2_usize
.saturating_add(diff.added.len())
.saturating_add(diff.removed.len()),
);
let parent = diff.parent.unwrap_or(0_u64);
value.extend_from_slice(&parent.to_be_bytes());
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
}
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]);
}
}
self.db
.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value);
}
#[inline]

View file

@ -1,7 +1,7 @@
use std::{collections::BTreeMap, sync::Arc};
use conduwuit::{
Result, err,
use conduwuit_core::{
Event, Result, err,
matrix::pdu::{PduCount, PduEvent, PduId, RawPduId},
utils::{
ReadyExt,
@ -49,7 +49,10 @@ impl crate::Service for Service {
}
impl Service {
pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
pub async fn add_to_thread<E>(&self, root_event_id: &EventId, event: &E) -> Result
where
E: Event + Send + Sync,
{
let root_id = self
.services
.timeline
@ -86,7 +89,7 @@ impl Service {
}) {
// Thread already existed
relations.count = relations.count.saturating_add(uint!(1));
relations.latest_event = pdu.to_message_like_event();
relations.latest_event = event.to_format();
let content = serde_json::to_value(relations).expect("to_value always works");
@ -99,7 +102,7 @@ impl Service {
} else {
// New thread
let relations = BundledThread {
latest_event: pdu.to_message_like_event(),
latest_event: event.to_format(),
count: uint!(1),
current_user_participated: true,
};
@ -116,7 +119,7 @@ impl Service {
self.services
.timeline
.replace_pdu(&root_id, &root_pdu_json, &root_pdu)
.replace_pdu(&root_id, &root_pdu_json)
.await?;
}
@ -126,10 +129,10 @@ impl Service {
users.extend_from_slice(&userids);
},
| _ => {
users.push(root_pdu.sender);
users.push(root_pdu.sender().to_owned());
},
}
users.push(pdu.sender.clone());
users.push(event.sender().to_owned());
self.update_participants(&root_id, &users)
}
@ -158,10 +161,10 @@ impl Service {
.ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes())
.wide_filter_map(move |pdu_id| async move {
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
let pdu_id: PduId = pdu_id.into();
if pdu.sender != user_id {
pdu.remove_transaction_id().ok();
let pdu_id: PduId = pdu_id.into();
if pdu.sender() != user_id {
pdu.as_mut_pdu().remove_transaction_id().ok();
}
Some((pdu_id.shorteventid, pdu))

View file

@ -0,0 +1,448 @@
use std::{
collections::{BTreeMap, HashSet},
sync::Arc,
};
use conduwuit_core::{
Result, err, error, implement,
matrix::{
event::Event,
pdu::{PduCount, PduEvent, PduId, RawPduId},
},
utils::{self, ReadyExt},
};
use futures::StreamExt;
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, EventId, RoomVersionId, UserId,
events::{
GlobalAccountDataEventType, StateEventType, TimelineEventType,
push_rules::PushRulesEvent,
room::{
encrypted::Relation,
member::{MembershipState, RoomMemberEventContent},
power_levels::RoomPowerLevelsEventContent,
redaction::RoomRedactionEventContent,
},
},
push::{Action, Ruleset, Tweak},
};
use super::{ExtractBody, ExtractRelatesTo, ExtractRelatesToEventId, RoomMutexGuard};
use crate::{appservice::NamespaceRegex, rooms::state_compressor::CompressedState};
/// Append the incoming event setting the state snapshot to the state from
/// the server that sent the event.
#[implement(super::Service)]
#[tracing::instrument(level = "debug", skip_all)]
pub async fn append_incoming_pdu<'a, Leaves>(
&'a self,
pdu: &'a PduEvent,
pdu_json: CanonicalJsonObject,
new_room_leaves: Leaves,
state_ids_compressed: Arc<CompressedState>,
soft_fail: bool,
state_lock: &'a RoomMutexGuard,
) -> Result<Option<RawPduId>>
where
Leaves: Iterator<Item = &'a EventId> + Send + 'a,
{
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
// fail.
self.services
.state
.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)
.await?;
if soft_fail {
self.services
.pdu_metadata
.mark_as_referenced(&pdu.room_id, pdu.prev_events.iter().map(AsRef::as_ref));
self.services
.state
.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)
.await;
return Ok(None);
}
let pdu_id = self
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
.await?;
Ok(Some(pdu_id))
}
/// Creates a new persisted data unit and adds it to a room.
///
/// By this point the incoming event should be fully authenticated, no auth
/// happens in `append_pdu`.
///
/// Returns pdu id
#[implement(super::Service)]
#[tracing::instrument(level = "debug", skip_all)]
pub async fn append_pdu<'a, Leaves>(
&'a self,
pdu: &'a PduEvent,
mut pdu_json: CanonicalJsonObject,
leaves: Leaves,
state_lock: &'a RoomMutexGuard,
) -> Result<RawPduId>
where
Leaves: Iterator<Item = &'a EventId> + Send + 'a,
{
// Coalesce database writes for the remainder of this scope.
let _cork = self.db.db.cork_and_flush();
let shortroomid = self
.services
.short
.get_shortroomid(pdu.room_id())
.await
.map_err(|_| err!(Database("Room does not exist")))?;
// Make unsigned fields correct. This is not properly documented in the spec,
// but state events need to have previous content in the unsigned field, so
// clients can easily interpret things like membership changes
if let Some(state_key) = pdu.state_key() {
if let CanonicalJsonValue::Object(unsigned) = pdu_json
.entry("unsigned".to_owned())
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
{
if let Ok(shortstatehash) = self
.services
.state_accessor
.pdu_shortstatehash(pdu.event_id())
.await
{
if let Ok(prev_state) = self
.services
.state_accessor
.state_get(shortstatehash, &pdu.kind().to_string().into(), state_key)
.await
{
unsigned.insert(
"prev_content".to_owned(),
CanonicalJsonValue::Object(
utils::to_canonical_object(prev_state.get_content_as_value())
.map_err(|e| {
err!(Database(error!(
"Failed to convert prev_state to canonical JSON: {e}",
)))
})?,
),
);
unsigned.insert(
String::from("prev_sender"),
CanonicalJsonValue::String(prev_state.sender().to_string()),
);
unsigned.insert(
String::from("replaces_state"),
CanonicalJsonValue::String(prev_state.event_id().to_string()),
);
}
}
} else {
error!("Invalid unsigned type in pdu.");
}
}
// We must keep track of all events that have been referenced.
self.services
.pdu_metadata
.mark_as_referenced(pdu.room_id(), pdu.prev_events().map(AsRef::as_ref));
self.services
.state
.set_forward_extremities(pdu.room_id(), leaves, state_lock)
.await;
let insert_lock = self.mutex_insert.lock(pdu.room_id()).await;
let count1 = self.services.globals.next_count().unwrap();
// Mark as read first so the sending client doesn't get a notification even if
// appending fails
self.services
.read_receipt
.private_read_set(pdu.room_id(), pdu.sender(), count1);
self.services
.user
.reset_notification_counts(pdu.sender(), pdu.room_id());
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap());
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
// Insert pdu
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await;
drop(insert_lock);
// See if the event matches any known pushers via power level
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get_content(pdu.room_id(), &StateEventType::RoomPowerLevels, "")
.await
.unwrap_or_default();
let mut push_target: HashSet<_> = self
.services
.state_cache
.active_local_users_in_room(pdu.room_id())
.map(ToOwned::to_owned)
// Don't notify the sender of their own events, and dont send from ignored users
.ready_filter(|user| *user != pdu.sender())
.filter_map(|recipient_user| async move { (!self.services.users.user_is_ignored(pdu.sender(), &recipient_user).await).then_some(recipient_user) })
.collect()
.await;
let mut notifies = Vec::with_capacity(push_target.len().saturating_add(1));
let mut highlights = Vec::with_capacity(push_target.len().saturating_add(1));
if *pdu.kind() == TimelineEventType::RoomMember {
if let Some(state_key) = pdu.state_key() {
let target_user_id = UserId::parse(state_key)?;
if self.services.users.is_active_local(target_user_id).await {
push_target.insert(target_user_id.to_owned());
}
}
}
let serialized = pdu.to_format();
for user in &push_target {
let rules_for_user = self
.services
.account_data
.get_global(user, GlobalAccountDataEventType::PushRules)
.await
.map_or_else(
|_| Ruleset::server_default(user),
|ev: PushRulesEvent| ev.content.global,
);
let mut highlight = false;
let mut notify = false;
for action in self
.services
.pusher
.get_actions(user, &rules_for_user, &power_levels, &serialized, pdu.room_id())
.await
{
match action {
| Action::Notify => notify = true,
| Action::SetTweak(Tweak::Highlight(true)) => {
highlight = true;
},
| _ => {},
}
// Break early if both conditions are true
if notify && highlight {
break;
}
}
if notify {
notifies.push(user.clone());
}
if highlight {
highlights.push(user.clone());
}
self.services
.pusher
.get_pushkeys(user)
.ready_for_each(|push_key| {
self.services
.sending
.send_pdu_push(&pdu_id, user, push_key.to_owned())
.expect("TODO: replace with future");
})
.await;
}
self.db
.increment_notification_counts(pdu.room_id(), notifies, highlights);
match *pdu.kind() {
| TimelineEventType::RoomRedaction => {
use RoomVersionId::*;
let room_version_id = self.services.state.get_room_version(pdu.room_id()).await?;
match room_version_id {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = pdu.redacts() {
if self
.services
.state_accessor
.user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false)
.await?
{
self.redact_pdu(redact_id, pdu, shortroomid).await?;
}
}
},
| _ => {
let content: RoomRedactionEventContent = pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
if self
.services
.state_accessor
.user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false)
.await?
{
self.redact_pdu(redact_id, pdu, shortroomid).await?;
}
}
},
}
},
| TimelineEventType::SpaceChild =>
if let Some(_state_key) = pdu.state_key() {
self.services
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.remove(pdu.room_id());
},
| TimelineEventType::RoomMember => {
if let Some(state_key) = pdu.state_key() {
// if the state_key fails
let target_user_id =
UserId::parse(state_key).expect("This state_key was previously validated");
let content: RoomMemberEventContent = pdu.get_content()?;
let stripped_state = match content.membership {
| MembershipState::Invite | MembershipState::Knock =>
self.services.state.summary_stripped(pdu).await.into(),
| _ => None,
};
// Update our membership info, we do this here incase a user is invited or
// knocked and immediately leaves we need the DB to record the invite or
// knock event for auth
self.services
.state_cache
.update_membership(
pdu.room_id(),
target_user_id,
content,
pdu.sender(),
stripped_state,
None,
true,
)
.await?;
}
},
| TimelineEventType::RoomMessage => {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
if self.services.admin.is_admin_command(pdu, &body).await {
self.services.admin.command_with_sender(
body,
Some((pdu.event_id()).into()),
pdu.sender.clone().into(),
)?;
}
}
},
| _ => {},
}
if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() {
if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await {
self.services
.pdu_metadata
.add_relation(count2, related_pducount);
}
}
if let Ok(content) = pdu.get_content::<ExtractRelatesTo>() {
match content.relates_to {
| Relation::Reply { in_reply_to } => {
// We need to do it again here, because replies don't have
// event_id as a top level field
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await {
self.services
.pdu_metadata
.add_relation(count2, related_pducount);
}
},
| Relation::Thread(thread) => {
self.services
.threads
.add_to_thread(&thread.event_id, pdu)
.await?;
},
| _ => {}, // TODO: Aggregate other types
}
}
for appservice in self.services.appservice.read().await.values() {
if self
.services
.state_cache
.appservice_in_room(pdu.room_id(), appservice)
.await
{
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?;
continue;
}
// If the RoomMember event has a non-empty state_key, it is targeted at someone.
// If it is our appservice user, we send this PDU to it.
if *pdu.kind() == TimelineEventType::RoomMember {
if let Some(state_key_uid) = &pdu
.state_key
.as_ref()
.and_then(|state_key| UserId::parse(state_key.as_str()).ok())
{
let appservice_uid = appservice.registration.sender_localpart.as_str();
if state_key_uid == &appservice_uid {
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?;
continue;
}
}
}
let matching_users = |users: &NamespaceRegex| {
appservice.users.is_match(pdu.sender().as_str())
|| *pdu.kind() == TimelineEventType::RoomMember
&& pdu
.state_key
.as_ref()
.is_some_and(|state_key| users.is_match(state_key))
};
let matching_aliases = |aliases: NamespaceRegex| {
self.services
.alias
.local_aliases_for_room(pdu.room_id())
.ready_any(move |room_alias| aliases.is_match(room_alias.as_str()))
};
if matching_aliases(appservice.aliases.clone()).await
|| appservice.rooms.is_match(pdu.room_id().as_str())
|| matching_users(&appservice.users)
{
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?;
}
}
Ok(pdu_id)
}

View file

@ -0,0 +1,191 @@
use std::iter::once;
use conduwuit_core::{
Result, debug, debug_warn, implement, info,
matrix::{
event::Event,
pdu::{PduCount, PduId, RawPduId},
},
utils::{IterStream, ReadyExt},
validated, warn,
};
use futures::{FutureExt, StreamExt};
use ruma::{
RoomId, ServerName,
api::federation,
events::{
StateEventType, TimelineEventType, room::power_levels::RoomPowerLevelsEventContent,
},
uint,
};
use serde_json::value::RawValue as RawJsonValue;
use super::ExtractBody;
#[implement(super::Service)]
#[tracing::instrument(name = "backfill", level = "debug", skip(self))]
pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> {
if self
.services
.state_cache
.room_joined_count(room_id)
.await
.is_ok_and(|count| count <= 1)
&& !self
.services
.state_accessor
.is_world_readable(room_id)
.await
{
// Room is empty (1 user or none), there is no one that can backfill
return Ok(());
}
let first_pdu = self
.first_item_in_room(room_id)
.await
.expect("Room is not empty");
if first_pdu.0 < from {
// No backfill required, there are still events between them
return Ok(());
}
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.unwrap_or_default();
let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| {
if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) {
Some(user_id.server_name())
} else {
None
}
});
let canonical_room_alias_server = once(
self.services
.state_accessor
.get_canonical_alias(room_id)
.await,
)
.filter_map(Result::ok)
.map(|alias| alias.server_name().to_owned())
.stream();
let mut servers = room_mods
.stream()
.map(ToOwned::to_owned)
.chain(canonical_room_alias_server)
.chain(
self.services
.server
.config
.trusted_servers
.iter()
.map(ToOwned::to_owned)
.stream(),
)
.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name))
.filter_map(|server_name| async move {
self.services
.state_cache
.server_in_room(&server_name, room_id)
.await
.then_some(server_name)
})
.boxed();
while let Some(ref backfill_server) = servers.next().await {
info!("Asking {backfill_server} for backfill");
let response = self
.services
.sending
.send_federation_request(
backfill_server,
federation::backfill::get_backfill::v1::Request {
room_id: room_id.to_owned(),
v: vec![first_pdu.1.event_id().to_owned()],
limit: uint!(100),
},
)
.await;
match response {
| Ok(response) => {
for pdu in response.pdus {
if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await {
debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}");
}
}
return Ok(());
},
| Err(e) => {
warn!("{backfill_server} failed to provide backfill for room {room_id}: {e}");
},
}
}
info!("No servers could backfill, but backfill was needed in room {room_id}");
Ok(())
}
#[implement(super::Service)]
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
let (room_id, event_id, value) = self.services.event_handler.parse_incoming_pdu(&pdu).await?;
// Lock so we cannot backfill the same pdu twice at the same time
let mutex_lock = self
.services
.event_handler
.mutex_federation
.lock(&room_id)
.await;
// Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.get_pdu_id(&event_id).await {
debug!("We already know {event_id} at {pdu_id:?}");
return Ok(());
}
self.services
.event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, false)
.boxed()
.await?;
let value = self.get_pdu_json(&event_id).await?;
let pdu = self.get_pdu(&event_id).await?;
let shortroomid = self.services.short.get_shortroomid(&room_id).await?;
let insert_lock = self.mutex_insert.lock(&room_id).await;
let count: i64 = self.services.globals.next_count().unwrap().try_into()?;
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid: PduCount::Backfilled(validated!(0 - count)),
}
.into();
// Insert pdu
self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value);
drop(insert_lock);
if pdu.kind == TimelineEventType::RoomMessage {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
}
}
drop(mutex_lock);
debug!("Prepended backfill pdu");
Ok(())
}

View file

@ -0,0 +1,226 @@
use std::{collections::HashSet, iter::once};
use conduwuit_core::{
Err, Result, implement,
matrix::{event::Event, pdu::PduBuilder},
utils::{IterStream, ReadyExt},
};
use futures::{FutureExt, StreamExt};
use ruma::{
OwnedEventId, OwnedServerName, RoomId, RoomVersionId, UserId,
events::{
TimelineEventType,
room::{
member::{MembershipState, RoomMemberEventContent},
redaction::RoomRedactionEventContent,
},
},
};
use super::RoomMutexGuard;
/// Creates a new persisted data unit and adds it to a room. This function
/// takes a roomid_mutex_state, meaning that only this function is able to
/// mutate the room state.
#[implement(super::Service)]
#[tracing::instrument(skip(self, state_lock), level = "debug")]
pub async fn build_and_append_pdu(
&self,
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
state_lock: &RoomMutexGuard,
) -> Result<OwnedEventId> {
let (pdu, pdu_json) = self
.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
.await?;
if self.services.admin.is_admin_room(pdu.room_id()).await {
self.check_pdu_for_admin_room(&pdu, sender).boxed().await?;
}
// If redaction event is not authorized, do not append it to the timeline
if *pdu.kind() == TimelineEventType::RoomRedaction {
use RoomVersionId::*;
match self.services.state.get_room_version(pdu.room_id()).await? {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = pdu.redacts() {
if !self
.services
.state_accessor
.user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false)
.await?
{
return Err!(Request(Forbidden("User cannot redact this event.")));
}
}
},
| _ => {
let content: RoomRedactionEventContent = pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
if !self
.services
.state_accessor
.user_can_redact(redact_id, pdu.sender(), pdu.room_id(), false)
.await?
{
return Err!(Request(Forbidden("User cannot redact this event.")));
}
}
},
}
}
if *pdu.kind() == TimelineEventType::RoomMember {
let content: RoomMemberEventContent = pdu.get_content()?;
if content.join_authorized_via_users_server.is_some()
&& content.membership != MembershipState::Join
{
return Err!(Request(BadJson(
"join_authorised_via_users_server is only for member joins"
)));
}
if content
.join_authorized_via_users_server
.as_ref()
.is_some_and(|authorising_user| {
!self.services.globals.user_is_local(authorising_user)
}) {
return Err!(Request(InvalidParam(
"Authorising user does not belong to this homeserver"
)));
}
}
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
// fail.
let statehashid = self.services.state.append_to_state(&pdu).await?;
let pdu_id = self
.append_pdu(
&pdu,
pdu_json,
// Since this PDU references all pdu_leaves we can update the leaves
// of the room
once(pdu.event_id()),
state_lock,
)
.boxed()
.await?;
// We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist
self.services
.state
.set_room_state(pdu.room_id(), statehashid, state_lock);
let mut servers: HashSet<OwnedServerName> = self
.services
.state_cache
.room_servers(pdu.room_id())
.map(ToOwned::to_owned)
.collect()
.await;
// In case we are kicking or banning a user, we need to inform their server of
// the change
if *pdu.kind() == TimelineEventType::RoomMember {
if let Some(state_key_uid) = &pdu
.state_key
.as_ref()
.and_then(|state_key| UserId::parse(state_key.as_str()).ok())
{
servers.insert(state_key_uid.server_name().to_owned());
}
}
// Remove our server from the server list since it will be added to it by
// room_servers() and/or the if statement above
servers.remove(self.services.globals.server_name());
self.services
.sending
.send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id)
.await?;
Ok(pdu.event_id().to_owned())
}
#[implement(super::Service)]
#[tracing::instrument(skip_all, level = "debug")]
async fn check_pdu_for_admin_room<Pdu>(&self, pdu: &Pdu, sender: &UserId) -> Result
where
Pdu: Event + Send + Sync,
{
match pdu.kind() {
| TimelineEventType::RoomEncryption => {
return Err!(Request(Forbidden(error!("Encryption not supported in admins room."))));
},
| TimelineEventType::RoomMember => {
let target = pdu
.state_key()
.filter(|v| v.starts_with('@'))
.unwrap_or(sender.as_str());
let server_user = &self.services.globals.server_user.to_string();
let content: RoomMemberEventContent = pdu.get_content()?;
match content.membership {
| MembershipState::Leave => {
if target == server_user {
return Err!(Request(Forbidden(error!(
"Server user cannot leave the admins room."
))));
}
let count = self
.services
.state_cache
.room_members(pdu.room_id())
.ready_filter(|user| self.services.globals.user_is_local(user))
.ready_filter(|user| *user != target)
.boxed()
.count()
.await;
if count < 2 {
return Err!(Request(Forbidden(error!(
"Last admin cannot leave the admins room."
))));
}
},
| MembershipState::Ban if pdu.state_key().is_some() => {
if target == server_user {
return Err!(Request(Forbidden(error!(
"Server cannot be banned from admins room."
))));
}
let count = self
.services
.state_cache
.room_members(pdu.room_id())
.ready_filter(|user| self.services.globals.user_is_local(user))
.ready_filter(|user| *user != target)
.boxed()
.count()
.await;
if count < 2 {
return Err!(Request(Forbidden(error!(
"Last admin cannot be banned from admins room."
))));
}
},
| _ => {},
}
},
| _ => {},
}
Ok(())
}

View file

@ -0,0 +1,214 @@
use std::cmp;
use conduwuit_core::{
Err, Error, Result, err, implement,
matrix::{
event::{Event, gen_event_id},
pdu::{EventHash, PduBuilder, PduEvent},
state_res::{self, RoomVersion},
},
utils::{self, IterStream, ReadyExt, stream::TryIgnore},
};
use futures::{StreamExt, TryStreamExt, future, future::ready};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomId, RoomVersionId, UserId,
canonical_json::to_canonical_value,
events::{StateEventType, TimelineEventType, room::create::RoomCreateEventContent},
uint,
};
use serde_json::value::to_raw_value;
use tracing::warn;
use super::RoomMutexGuard;
#[implement(super::Service)]
pub async fn create_hash_and_sign_event(
&self,
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) -> Result<(PduEvent, CanonicalJsonObject)> {
let PduBuilder {
event_type,
content,
unsigned,
state_key,
redacts,
timestamp,
} = pdu_builder;
let prev_events: Vec<OwnedEventId> = self
.services
.state
.get_forward_extremities(room_id)
.take(20)
.map(Into::into)
.collect()
.await;
// If there was no create event yet, assume we are creating a room
let room_version_id = self
.services
.state
.get_room_version(room_id)
.await
.or_else(|_| {
if event_type == TimelineEventType::RoomCreate {
let content: RoomCreateEventContent = serde_json::from_str(content.get())?;
Ok(content.room_version)
} else {
Err(Error::InconsistentRoomState(
"non-create event for room of unknown version",
room_id.to_owned(),
))
}
})?;
let room_version = RoomVersion::new(&room_version_id).expect("room version is supported");
let auth_events = self
.services
.state
.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)
.await?;
// Our depth is the maximum depth of prev_events + 1
let depth = prev_events
.iter()
.stream()
.map(Ok)
.and_then(|event_id| self.get_pdu(event_id))
.and_then(|pdu| future::ok(pdu.depth))
.ignore_err()
.ready_fold(uint!(0), cmp::max)
.await
.saturating_add(uint!(1));
let mut unsigned = unsigned.unwrap_or_default();
if let Some(state_key) = &state_key {
if let Ok(prev_pdu) = self
.services
.state_accessor
.room_state_get(room_id, &event_type.to_string().into(), state_key)
.await
{
unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value());
unsigned.insert("prev_sender".to_owned(), serde_json::to_value(prev_pdu.sender())?);
unsigned
.insert("replaces_state".to_owned(), serde_json::to_value(prev_pdu.event_id())?);
}
}
if event_type != TimelineEventType::RoomCreate && prev_events.is_empty() {
return Err!(Request(Unknown("Event incorrectly had zero prev_events.")));
}
if state_key.is_none() && depth.lt(&uint!(2)) {
// The first two events in a room are always m.room.create and m.room.member,
// so any other events with that same depth are illegal.
warn!(
"Had unsafe depth {depth} when creating non-state event in {room_id}. Cowardly \
aborting"
);
return Err!(Request(Unknown("Unsafe depth for non-state event.")));
}
let mut pdu = PduEvent {
event_id: ruma::event_id!("$thiswillbefilledinlater").into(),
room_id: room_id.to_owned(),
sender: sender.to_owned(),
origin: None,
origin_server_ts: timestamp.map_or_else(
|| {
utils::millis_since_unix_epoch()
.try_into()
.expect("u64 fits into UInt")
},
|ts| ts.get(),
),
kind: event_type,
content,
state_key,
prev_events,
depth,
auth_events: auth_events
.values()
.map(|pdu| pdu.event_id.clone())
.collect(),
redacts,
unsigned: if unsigned.is_empty() {
None
} else {
Some(to_raw_value(&unsigned)?)
},
hashes: EventHash { sha256: "aaa".to_owned() },
signatures: None,
};
let auth_fetch = |k: &StateEventType, s: &str| {
let key = (k.clone(), s.into());
ready(auth_events.get(&key).map(ToOwned::to_owned))
};
let auth_check = state_res::auth_check(
&room_version,
&pdu,
None, // TODO: third_party_invite
auth_fetch,
)
.await
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
if !auth_check {
return Err!(Request(Forbidden("Event is not authorized.")));
}
// Hash and sign
let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| {
err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}"))))
})?;
// room v3 and above removed the "event_id" field from remote PDU format
match room_version_id {
| RoomVersionId::V1 | RoomVersionId::V2 => {},
| _ => {
pdu_json.remove("event_id");
},
}
// Add origin because synapse likes that (and it's required in the spec)
pdu_json.insert(
"origin".to_owned(),
to_canonical_value(self.services.globals.server_name())
.expect("server name is a valid CanonicalJsonValue"),
);
if let Err(e) = self
.services
.server_keys
.hash_and_sign_event(&mut pdu_json, &room_version_id)
{
return match e {
| Error::Signatures(ruma::signatures::Error::PduSize) => {
Err!(Request(TooLarge("Message/PDU is too long (exceeds 65535 bytes)")))
},
| _ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
};
}
// Generate event id
pdu.event_id = gen_event_id(&pdu_json, &room_version_id)?;
pdu_json.insert("event_id".into(), CanonicalJsonValue::String(pdu.event_id.clone().into()));
// Generate short event id
let _shorteventid = self
.services
.short
.get_or_create_shorteventid(&pdu.event_id)
.await;
Ok((pdu, pdu_json))
}

View file

@ -207,7 +207,6 @@ impl Data {
&self,
pdu_id: &RawPduId,
pdu_json: &CanonicalJsonObject,
_pdu: &PduEvent,
) -> Result {
if self.pduid_pdu.get(pdu_id).await.is_not_found() {
return Err!(Request(NotFound("PDU does not exist.")));

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,51 @@
use conduwuit_core::{
Result, err, implement,
matrix::event::Event,
utils::{self},
};
use ruma::EventId;
use super::ExtractBody;
use crate::rooms::short::ShortRoomId;
/// Replace a PDU with the redacted form.
#[implement(super::Service)]
#[tracing::instrument(name = "redact", level = "debug", skip(self))]
pub async fn redact_pdu<Pdu: Event + Send + Sync>(
&self,
event_id: &EventId,
reason: &Pdu,
shortroomid: ShortRoomId,
) -> Result {
// TODO: Don't reserialize, keep original json
let Ok(pdu_id) = self.get_pdu_id(event_id).await else {
// If event does not exist, just noop
return Ok(());
};
let mut pdu = self
.get_pdu_from_id(&pdu_id)
.await
.map(Event::into_pdu)
.map_err(|e| {
err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU.")))
})?;
if let Ok(content) = pdu.get_content::<ExtractBody>() {
if let Some(body) = content.body {
self.services
.search
.deindex_pdu(shortroomid, &pdu_id, &body);
}
}
let room_version_id = self.services.state.get_room_version(pdu.room_id()).await?;
pdu.redact(&room_version_id, reason.to_value())?;
let obj = utils::to_canonical_object(&pdu).map_err(|e| {
err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON")))
})?;
self.replace_pdu(&pdu_id, &obj).await
}