parallelize state-res pre-gathering

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-29 23:07:12 +00:00
commit 3c8376d897
2 changed files with 123 additions and 113 deletions

View file

@ -1,18 +1,20 @@
use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
iter::Iterator,
sync::Arc,
};
use conduwuit::{
debug, err, implement,
result::LogErr,
utils::stream::{BroadbandExt, IterStream},
debug, err, implement, trace,
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
PduEvent, Result,
};
use futures::{FutureExt, StreamExt, TryStreamExt};
use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId};
use crate::rooms::short::ShortStateHash;
// TODO: if we know the prev_events of the incoming event we can avoid the
#[implement(super::Service)]
// request and build the state from a known point and resolve if > 1 prev_event
@ -70,86 +72,44 @@ pub(super) async fn state_at_incoming_resolved(
room_id: &RoomId,
room_version_id: &RoomVersionId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
debug!("Calculating state at event using state res");
let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len());
let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events {
let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else {
okay = false;
break;
};
let Ok(sstatehash) = self
.services
.state_accessor
.pdu_shortstatehash(prev_eventid)
.await
else {
okay = false;
break;
};
extremity_sstatehashes.insert(sstatehash, prev_event);
}
if !okay {
trace!("Calculating extremity statehashes...");
let Ok(extremity_sstatehashes) = incoming_pdu
.prev_events
.iter()
.try_stream()
.broad_and_then(|prev_eventid| {
self.services
.timeline
.get_pdu(prev_eventid)
.map_ok(move |prev_event| (prev_eventid, prev_event))
})
.broad_and_then(|(prev_eventid, prev_event)| {
self.services
.state_accessor
.pdu_shortstatehash(prev_eventid)
.map_ok(move |sstatehash| (sstatehash, prev_event))
})
.try_collect::<HashMap<_, _>>()
.await
else {
return Ok(None);
}
};
let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len());
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
for (sstatehash, prev_event) in extremity_sstatehashes {
let mut leaf_state: HashMap<_, _> = self
.services
.state_accessor
.state_full_ids(sstatehash)
.collect()
.await;
if let Some(state_key) = &prev_event.state_key {
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)
.await;
let event_id = &prev_event.event_id;
leaf_state.insert(shortstatekey, event_id.clone());
// Now it's the state after the pdu
}
let mut state = StateMap::with_capacity(leaf_state.len());
let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in &leaf_state {
if let Ok((ty, st_key)) = self
.services
.short
.get_statekey_from_short(*k)
.await
.log_err()
{
// FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone());
}
starting_events.push(id.borrow());
}
let auth_chain: HashSet<OwnedEventId> = self
.services
.auth_chain
.event_ids_iter(room_id, starting_events.into_iter())
trace!("Calculating fork states...");
let (fork_states, auth_chain_sets): (Vec<StateMap<_>>, Vec<HashSet<_>>) =
extremity_sstatehashes
.into_iter()
.try_stream()
.wide_and_then(|(sstatehash, prev_event)| {
self.state_at_incoming_fork(room_id, sstatehash, prev_event)
})
.try_collect()
.map_ok(Vec::into_iter)
.map_ok(Iterator::unzip)
.await?;
auth_chain_sets.push(auth_chain);
fork_states.push(state);
}
let Ok(new_state) = self
.state_resolution(room_version_id, &fork_states, &auth_chain_sets)
.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.boxed()
.await
else {
@ -157,16 +117,65 @@ pub(super) async fn state_at_incoming_resolved(
};
new_state
.iter()
.into_iter()
.stream()
.broad_then(|((event_type, state_key), event_id)| {
.broad_then(|((event_type, state_key), event_id)| async move {
self.services
.short
.get_or_create_shortstatekey(event_type, state_key)
.map(move |shortstatekey| (shortstatekey, event_id.clone()))
.get_or_create_shortstatekey(&event_type, &state_key)
.map(move |shortstatekey| (shortstatekey, event_id))
.await
})
.collect()
.map(Some)
.map(Ok)
.await
}
#[implement(super::Service)]
async fn state_at_incoming_fork(
&self,
room_id: &RoomId,
sstatehash: ShortStateHash,
prev_event: PduEvent,
) -> Result<(StateMap<OwnedEventId>, HashSet<OwnedEventId>)> {
let mut leaf_state: HashMap<_, _> = self
.services
.state_accessor
.state_full_ids(sstatehash)
.collect()
.await;
if let Some(state_key) = &prev_event.state_key {
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)
.await;
let event_id = &prev_event.event_id;
leaf_state.insert(shortstatekey, event_id.clone());
// Now it's the state after the pdu
}
let auth_chain = self
.services
.auth_chain
.event_ids_iter(room_id, leaf_state.values().map(Borrow::borrow))
.try_collect();
let fork_state = leaf_state
.iter()
.stream()
.broad_then(|(k, id)| {
self.services
.short
.get_statekey_from_short(*k)
.map_ok(|(ty, sk)| ((ty, sk), id.clone()))
})
.ready_filter_map(Result::ok)
.collect()
.map(Ok);
try_join(fork_state, auth_chain).await
}