mod data; use std::{ collections::{HashMap, HashSet}, fmt::Write, sync::Arc, }; use conduit::{ utils::{calculate_hash, MutexMap, MutexMapGuard}, warn, Error, PduEvent, Result, }; use data::Data; use ruma::{ api::client::error::ErrorKind, events::{ room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, AnyStrippedStateEvent, StateEventType, TimelineEventType, }, serde::Raw, state_res::{self, StateMap}, EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, }; use super::state_compressor::CompressedStateEvent; use crate::{globals, rooms, Dep}; pub struct Service { services: Services, db: Data, pub mutex: RoomMutexMap, } struct Services { globals: Dep, short: Dep, spaces: Dep, state_cache: Dep, state_accessor: Dep, state_compressor: Dep, timeline: Dep, } type RoomMutexMap = MutexMap; pub type RoomMutexGuard = MutexMapGuard; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { services: Services { globals: args.depend::("globals"), short: args.depend::("rooms::short"), spaces: args.depend::("rooms::spaces"), state_cache: args.depend::("rooms::state_cache"), state_accessor: args.depend::("rooms::state_accessor"), state_compressor: args.depend::("rooms::state_compressor"), timeline: args.depend::("rooms::timeline"), }, db: Data::new(args.db), mutex: RoomMutexMap::new(), })) } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { let mutex = self.mutex.len(); writeln!(out, "state_mutex: {mutex}")?; Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { /// Set the room to the given statehash and update caches. pub async fn force_state( &self, room_id: &RoomId, shortstatehash: u64, statediffnew: Arc>, _statediffremoved: Arc>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { self.services .state_compressor .parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else { continue; }; let pdu: PduEvent = match serde_json::from_str( &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), ) { Ok(pdu) => pdu, Err(_) => continue, }; match pdu.kind { TimelineEventType::RoomMember => { let Ok(membership_event) = serde_json::from_str::(pdu.content.get()) else { continue; }; let Some(state_key) = pdu.state_key else { continue; }; let Ok(user_id) = UserId::parse(state_key) else { continue; }; self.services.state_cache.update_membership( room_id, &user_id, membership_event, &pdu.sender, None, None, false, )?; }, TimelineEventType::SpaceChild => { self.services .spaces .roomid_spacehierarchy_cache .lock() .await .remove(&pdu.room_id); }, _ => continue, } } self.services.state_cache.update_joined_count(room_id)?; self.db .set_room_state(room_id, shortstatehash, state_lock)?; Ok(()) } /// Generates a new StateHash and associates it with the incoming event. /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; let state_hash = calculate_hash( &state_ids_compressed .iter() .map(|s| &s[..]) .collect::>(), ); let (shortstatehash, already_existed) = self .services .short .get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), |p| self.services.state_compressor.load_shortstatehash_info(p), )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed .difference(&parent_stateinfo.1) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo .1 .difference(&state_ids_compressed) .copied() .collect(); (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { (state_ids_compressed, Arc::new(HashSet::new())) }; self.services.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, 1_000_000, // high number because no state will be based on this one states_parents, )?; } self.db.set_event_state(shorteventid, shortstatehash)?; Ok(shortstatehash) } /// Generates a new StateHash and associates it with the incoming event. /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu), level = "debug")] pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = self .services .short .get_or_create_shorteventid(&new_pdu.event_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { self.db.set_event_state(shorteventid, p)?; } if let Some(state_key) = &new_pdu.state_key { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), #[inline] |p| self.services.state_compressor.load_shortstatehash_info(p), )?; let shortstatekey = self .services .short .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; let new = self .services .state_compressor .compress_state_event(shortstatekey, &new_pdu.event_id)?; let replaces = states_parents .last() .map(|info| { info.1 .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) }) .unwrap_or_default(); if Some(&new) == replaces { return Ok(previous_shortstatehash.expect("must exist")); } // TODO: statehash with deterministic inputs let shortstatehash = self.services.globals.next_count()?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); let mut statediffremoved = HashSet::new(); if let Some(replaces) = replaces { statediffremoved.insert(*replaces); } self.services.state_compressor.save_state_from_diff( shortstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), 2, states_parents, )?; Ok(shortstatehash) } else { Ok(previous_shortstatehash.expect("first event in room must be a state event")) } } #[tracing::instrument(skip(self, invite_event), level = "debug")] pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { let mut state = Vec::new(); // Add recommended events if let Some(e) = self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.services.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomCanonicalAlias, "", )? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.services.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str(), )? { state.push(e.to_stripped_state_event()); } state.push(invite_event.to_stripped_state_event()); Ok(state) } /// Set the state hash to a new version, but does not update state_cache. #[tracing::instrument(skip(self, mutex_lock), level = "debug")] pub fn set_room_state( &self, room_id: &RoomId, shortstatehash: u64, mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.db.set_room_state(room_id, shortstatehash, mutex_lock) } /// Returns the room's version. #[tracing::instrument(skip(self), level = "debug")] pub fn get_room_version(&self, room_id: &RoomId) -> Result { let create_event = self .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event_content: RoomCreateEventContent = create_event .as_ref() .map(|create_event| { serde_json::from_str(create_event.content.get()).map_err(|e| { warn!("Invalid create event: {}", e); Error::bad_database("Invalid create event in db.") }) }) .transpose()? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; Ok(create_event_content.room_version) } #[inline] pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.db.get_room_shortstatehash(room_id) } pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { self.db.get_forward_extremities(room_id) } pub fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.db .set_forward_extremities(room_id, event_ids, state_lock) } /// This fetches auth events from the current state. #[tracing::instrument(skip(self), level = "debug")] pub fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result>> { let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? else { return Ok(HashMap::new()); }; let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object"); let mut sauthevents = auth_events .into_iter() .filter_map(|(event_type, state_key)| { self.services .short .get_shortstatekey(&event_type.to_string().into(), &state_key) .ok() .flatten() .map(|s| (s, (event_type, state_key))) }) .collect::>(); let full_state = self .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") .1; Ok(full_state .iter() .filter_map(|compressed| { self.services .state_compressor .parse_compressed_state_event(compressed) .ok() }) .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) .filter_map(|(k, event_id)| { self.services .timeline .get_pdu(&event_id) .ok() .flatten() .map(|pdu| (k, pdu)) }) .collect()) } }