diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 56a91d0e..a33fb342 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use conduwuit::{ Result, arrayvec::ArrayVec, - at, checked, err, expected, utils, + at, checked, err, expected, implement, utils, utils::{bytes, math::usize_from_f64, stream::IterStream}, }; use database::Map; @@ -115,29 +115,30 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Returns a stack with info on shortstatehash, full state, added diff and - /// removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument(name = "load", level = "debug", skip(self))] - pub async fn load_shortstatehash_info( - &self, - shortstatehash: ShortStateHash, - ) -> Result { - if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) { - return Ok(r.clone()); - } - - let stack = self.new_shortstatehash_info(shortstatehash).await?; - - self.cache_shortstatehash_info(shortstatehash, stack.clone()) - .await?; - - Ok(stack) +/// Returns a stack with info on shortstatehash, full state, added diff and +/// removed diff for the selected shortstatehash and each parent layer. +#[implement(Service)] +#[tracing::instrument(name = "load", level = "debug", skip(self))] +pub async fn load_shortstatehash_info( + &self, + shortstatehash: ShortStateHash, +) -> Result { + if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) { + return Ok(r.clone()); } - /// Returns a stack with info on shortstatehash, full state, added diff and - /// removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument( + let stack = self.new_shortstatehash_info(shortstatehash).await?; + + self.cache_shortstatehash_info(shortstatehash, stack.clone()) + .await?; + + Ok(stack) +} + +/// Returns a stack with info on shortstatehash, full state, added diff and +/// removed diff for the selected shortstatehash and each parent layer. +#[implement(Service)] +#[tracing::instrument( name = "cache", level = "debug", skip_all, @@ -146,362 +147,365 @@ impl Service { stack = stack.len(), ), )] - async fn cache_shortstatehash_info( - &self, - shortstatehash: ShortStateHash, - stack: ShortStateInfoVec, - ) -> Result { - self.stateinfo_cache.lock()?.insert(shortstatehash, stack); +async fn cache_shortstatehash_info( + &self, + shortstatehash: ShortStateHash, + stack: ShortStateInfoVec, +) -> Result { + self.stateinfo_cache.lock()?.insert(shortstatehash, stack); - Ok(()) - } + Ok(()) +} - async fn new_shortstatehash_info( - &self, - shortstatehash: ShortStateHash, - ) -> Result { - let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?; +#[implement(Service)] +async fn new_shortstatehash_info( + &self, + shortstatehash: ShortStateHash, +) -> Result { + let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?; - let Some(parent) = parent else { - return Ok(vec![ShortStateInfo { - shortstatehash, - full_state: added.clone(), - added, - removed, - }]); - }; - - let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?; - let top = stack.last().expect("at least one frame"); - - let mut full_state = (*top.full_state).clone(); - full_state.extend(added.iter().copied()); - - let removed = (*removed).clone(); - for r in &removed { - full_state.remove(r); - } - - stack.push(ShortStateInfo { + let Some(parent) = parent else { + return Ok(vec![ShortStateInfo { shortstatehash, + full_state: added.clone(), added, - removed: Arc::new(removed), - full_state: Arc::new(full_state), - }); + removed, + }]); + }; - Ok(stack) + let mut stack = Box::pin(self.load_shortstatehash_info(parent)).await?; + let top = stack.last().expect("at least one frame"); + + let mut full_state = (*top.full_state).clone(); + full_state.extend(added.iter().copied()); + + let removed = (*removed).clone(); + for r in &removed { + full_state.remove(r); } - pub fn compress_state_events<'a, I>( - &'a self, - state: I, - ) -> impl Stream + Send + 'a - where - I: Iterator + Clone + Debug + Send + 'a, - { - let event_ids = state.clone().map(at!(1)); + stack.push(ShortStateInfo { + shortstatehash, + added, + removed: Arc::new(removed), + full_state: Arc::new(full_state), + }); - let short_event_ids = self - .services - .short - .multi_get_or_create_shorteventid(event_ids); + Ok(stack) +} - state - .stream() - .map(at!(0)) - .zip(short_event_ids) - .map(|(shortstatekey, shorteventid)| { - compress_state_event(*shortstatekey, shorteventid) - }) - } +#[implement(Service)] +pub fn compress_state_events<'a, I>( + &'a self, + state: I, +) -> impl Stream + Send + 'a +where + I: Iterator + Clone + Debug + Send + 'a, +{ + let event_ids = state.clone().map(at!(1)); - pub async fn compress_state_event( - &self, - shortstatekey: ShortStateKey, - event_id: &EventId, - ) -> CompressedStateEvent { - let shorteventid = self - .services - .short - .get_or_create_shorteventid(event_id) - .await; + let short_event_ids = self + .services + .short + .multi_get_or_create_shorteventid(event_ids); - compress_state_event(shortstatekey, shorteventid) - } + state + .stream() + .map(at!(0)) + .zip(short_event_ids) + .map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid)) +} - /// Creates a new shortstatehash that often is just a diff to an already - /// existing shortstatehash and therefore very efficient. - /// - /// There are multiple layers of diffs. The bottom layer 0 always contains - /// the full state. Layer 1 contains diffs to states of layer 0, layer 2 - /// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be - /// combined with layer n-1 to create a new diff on layer n-1 that's - /// based on layer n-2. If that layer is also too big, it will recursively - /// fix above layers too. - /// - /// * `shortstatehash` - Shortstatehash of this state - /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid - /// * `statediffremoved` - Removed from base. Each vec is - /// shortstatekey+shorteventid - /// * `diff_to_sibling` - Approximately how much the diff grows each time - /// for this layer - /// * `parent_states` - A stack with info on shortstatehash, full state, - /// added diff and removed diff for each parent layer - pub fn save_state_from_diff( - &self, - shortstatehash: ShortStateHash, - statediffnew: Arc, - statediffremoved: Arc, - diff_to_sibling: usize, - mut parent_states: ParentStatesVec, - ) -> Result { - let statediffnew_len = statediffnew.len(); - let statediffremoved_len = statediffremoved.len(); - let diffsum = checked!(statediffnew_len + statediffremoved_len)?; +#[implement(Service)] +pub async fn compress_state_event( + &self, + shortstatekey: ShortStateKey, + event_id: &EventId, +) -> CompressedStateEvent { + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - if parent_states.len() > 3 { - // Number of layers - // To many layers, we have to go deeper - let parent = parent_states.pop().expect("parent must have a state"); + compress_state_event(shortstatekey, shorteventid) +} - let mut parent_new = (*parent.added).clone(); - let mut parent_removed = (*parent.removed).clone(); - - for removed in statediffremoved.iter() { - if !parent_new.remove(removed) { - // It was not added in the parent and we removed it - parent_removed.insert(*removed); - } - // Else it was added in the parent and we removed it again. We - // can forget this change - } - - for new in statediffnew.iter() { - if !parent_removed.remove(new) { - // It was not touched in the parent and we added it - parent_new.insert(*new); - } - // Else it was removed in the parent and we added it again. We - // can forget this change - } - - self.save_state_from_diff( - shortstatehash, - Arc::new(parent_new), - Arc::new(parent_removed), - diffsum, - parent_states, - )?; - - return Ok(()); - } - - if parent_states.is_empty() { - // There is no parent layer, create a new state - self.save_statediff(shortstatehash, &StateDiff { - parent: None, - added: statediffnew, - removed: statediffremoved, - }); - - return Ok(()); - } - - // Else we have two options. - // 1. We add the current diff on top of the parent layer. - // 2. We replace a layer above +/// Creates a new shortstatehash that often is just a diff to an already +/// existing shortstatehash and therefore very efficient. +/// +/// There are multiple layers of diffs. The bottom layer 0 always contains +/// the full state. Layer 1 contains diffs to states of layer 0, layer 2 +/// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be +/// combined with layer n-1 to create a new diff on layer n-1 that's +/// based on layer n-2. If that layer is also too big, it will recursively +/// fix above layers too. +/// +/// * `shortstatehash` - Shortstatehash of this state +/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid +/// * `statediffremoved` - Removed from base. Each vec is +/// shortstatekey+shorteventid +/// * `diff_to_sibling` - Approximately how much the diff grows each time for +/// this layer +/// * `parent_states` - A stack with info on shortstatehash, full state, added +/// diff and removed diff for each parent layer +#[implement(Service)] +pub fn save_state_from_diff( + &self, + shortstatehash: ShortStateHash, + statediffnew: Arc, + statediffremoved: Arc, + diff_to_sibling: usize, + mut parent_states: ParentStatesVec, +) -> Result { + let statediffnew_len = statediffnew.len(); + let statediffremoved_len = statediffremoved.len(); + let diffsum = checked!(statediffnew_len + statediffremoved_len)?; + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper let parent = parent_states.pop().expect("parent must have a state"); - let parent_added_len = parent.added.len(); - let parent_removed_len = parent.removed.len(); - let parent_diff = checked!(parent_added_len + parent_removed_len)?; - if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { - // Diff too big, we replace above layer(s) - let mut parent_new = (*parent.added).clone(); - let mut parent_removed = (*parent.removed).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); - for removed in statediffremoved.iter() { - if !parent_new.remove(removed) { - // It was not added in the parent and we removed it - parent_removed.insert(*removed); - } - // Else it was added in the parent and we removed it again. We - // can forget this change + for removed in statediffremoved.iter() { + if !parent_new.remove(removed) { + // It was not added in the parent and we removed it + parent_removed.insert(*removed); } - - for new in statediffnew.iter() { - if !parent_removed.remove(new) { - // It was not touched in the parent and we added it - parent_new.insert(*new); - } - // Else it was removed in the parent and we added it again. We - // can forget this change - } - - self.save_state_from_diff( - shortstatehash, - Arc::new(parent_new), - Arc::new(parent_removed), - diffsum, - parent_states, - )?; - } else { - // Diff small enough, we add diff as layer on top of parent - self.save_statediff(shortstatehash, &StateDiff { - parent: Some(parent.shortstatehash), - added: statediffnew, - removed: statediffremoved, - }); + // Else it was added in the parent and we removed it again. We + // can forget this change } - Ok(()) + for new in statediffnew.iter() { + if !parent_removed.remove(new) { + // It was not touched in the parent and we added it + parent_new.insert(*new); + } + // Else it was removed in the parent and we added it again. We + // can forget this change + } + + self.save_state_from_diff( + shortstatehash, + Arc::new(parent_new), + Arc::new(parent_removed), + diffsum, + parent_states, + )?; + + return Ok(()); } - /// Returns the new shortstatehash, and the state diff from the previous - /// room state - #[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")] - pub async fn save_state( - &self, - room_id: &RoomId, - new_state_ids_compressed: Arc, - ) -> Result { - let previous_shortstatehash = self - .services - .state - .get_room_shortstatehash(room_id) - .await - .ok(); - - let state_hash = - utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..])); - - let (new_shortstatehash, already_existed) = self - .services - .short - .get_or_create_shortstatehash(&state_hash) - .await; - - if Some(new_shortstatehash) == previous_shortstatehash { - return Ok(HashSetCompressStateEvent { - shortstatehash: new_shortstatehash, - ..Default::default() - }); - } - - let states_parents = if let Some(p) = previous_shortstatehash { - self.load_shortstatehash_info(p).await.unwrap_or_default() - } else { - ShortStateInfoVec::new() - }; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: CompressedState = new_state_ids_compressed - .difference(&parent_stateinfo.full_state) - .copied() - .collect(); - - let statediffremoved: CompressedState = parent_stateinfo - .full_state - .difference(&new_state_ids_compressed) - .copied() - .collect(); - - (Arc::new(statediffnew), Arc::new(statediffremoved)) - } else { - (new_state_ids_compressed, Arc::new(CompressedState::new())) - }; - - if !already_existed { - self.save_state_from_diff( - new_shortstatehash, - statediffnew.clone(), - statediffremoved.clone(), - 2, // every state change is 2 event changes on average - states_parents, - )?; - } - - Ok(HashSetCompressStateEvent { - shortstatehash: new_shortstatehash, + if parent_states.is_empty() { + // There is no parent layer, create a new state + self.save_statediff(shortstatehash, &StateDiff { + parent: None, added: statediffnew, removed: statediffremoved, - }) + }); + + return Ok(()); } - #[tracing::instrument(skip(self), level = "debug", name = "get")] - async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result { - const BUFSIZE: usize = size_of::(); - const STRIDE: usize = size_of::(); + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above - let value = self - .db - .shortstatehash_statediff - .aqry::(&shortstatehash) - .await - .map_err(|e| { - err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")) - })?; + let parent = parent_states.pop().expect("parent must have a state"); + let parent_added_len = parent.added.len(); + let parent_removed_len = parent.removed.len(); + let parent_diff = checked!(parent_added_len + parent_removed_len)?; - let parent = utils::u64_from_bytes(&value[0..size_of::()]) - .ok() - .take_if(|parent| *parent != 0); + if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { + // Diff too big, we replace above layer(s) + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); - debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); - let _num_values = value.len() / STRIDE; - - let mut add_mode = true; - let mut added = CompressedState::new(); - let mut removed = CompressedState::new(); - - let mut i = STRIDE; - while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i = expected!(i + STRIDE); - continue; + for removed in statediffremoved.iter() { + if !parent_new.remove(removed) { + // It was not added in the parent and we removed it + parent_removed.insert(*removed); } - if add_mode { - added.insert(v.try_into()?); - } else { - removed.insert(v.try_into()?); - } - i = expected!(i + 2 * STRIDE); + // Else it was added in the parent and we removed it again. We + // can forget this change } - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) + for new in statediffnew.iter() { + if !parent_removed.remove(new) { + // It was not touched in the parent and we added it + parent_new.insert(*new); + } + // Else it was removed in the parent and we added it again. We + // can forget this change + } + + self.save_state_from_diff( + shortstatehash, + Arc::new(parent_new), + Arc::new(parent_removed), + diffsum, + parent_states, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + self.save_statediff(shortstatehash, &StateDiff { + parent: Some(parent.shortstatehash), + added: statediffnew, + removed: statediffremoved, + }); } - fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { - let mut value = Vec::::with_capacity( - 2_usize - .saturating_add(diff.added.len()) - .saturating_add(diff.removed.len()), - ); + Ok(()) +} - let parent = diff.parent.unwrap_or(0_u64); - value.extend_from_slice(&parent.to_be_bytes()); +/// Returns the new shortstatehash, and the state diff from the previous +/// room state +#[implement(Service)] +#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")] +pub async fn save_state( + &self, + room_id: &RoomId, + new_state_ids_compressed: Arc, +) -> Result { + let previous_shortstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .ok(); - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } + let state_hash = + utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..])); - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); - } - } + let (new_shortstatehash, already_existed) = self + .services + .short + .get_or_create_shortstatehash(&state_hash) + .await; - self.db - .shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value); + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + ..Default::default() + }); } + + let states_parents = if let Some(p) = previous_shortstatehash { + self.load_shortstatehash_info(p).await.unwrap_or_default() + } else { + ShortStateInfoVec::new() + }; + + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: CompressedState = new_state_ids_compressed + .difference(&parent_stateinfo.full_state) + .copied() + .collect(); + + let statediffremoved: CompressedState = parent_stateinfo + .full_state + .difference(&new_state_ids_compressed) + .copied() + .collect(); + + (Arc::new(statediffnew), Arc::new(statediffremoved)) + } else { + (new_state_ids_compressed, Arc::new(CompressedState::new())) + }; + + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + } + + Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + added: statediffnew, + removed: statediffremoved, + }) +} + +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug", name = "get")] +async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result { + const BUFSIZE: usize = size_of::(); + const STRIDE: usize = size_of::(); + + let value = self + .db + .shortstatehash_statediff + .aqry::(&shortstatehash) + .await + .map_err(|e| { + err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")) + })?; + + let parent = utils::u64_from_bytes(&value[0..size_of::()]) + .ok() + .take_if(|parent| *parent != 0); + + debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); + let _num_values = value.len() / STRIDE; + + let mut add_mode = true; + let mut added = CompressedState::new(); + let mut removed = CompressedState::new(); + + let mut i = STRIDE; + while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i = expected!(i + STRIDE); + continue; + } + if add_mode { + added.insert(v.try_into()?); + } else { + removed.insert(v.try_into()?); + } + i = expected!(i + 2 * STRIDE); + } + + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) +} + +#[implement(Service)] +fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { + let mut value = Vec::::with_capacity( + 2_usize + .saturating_add(diff.added.len()) + .saturating_add(diff.removed.len()), + ); + + let parent = diff.parent.unwrap_or(0_u64); + value.extend_from_slice(&parent.to_be_bytes()); + + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } + + self.db + .shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value); } #[inline]