refactor to iterator inputs for auth_chain/short batch functions

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-22 12:25:46 +00:00
commit 3789d60b6a
10 changed files with 76 additions and 71 deletions

View file

@ -1,7 +1,7 @@
use std::{mem::size_of_val, sync::Arc};
use std::{fmt::Debug, mem::size_of_val, sync::Arc};
pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId};
use conduit::{err, implement, utils, Result};
use conduit::{err, implement, utils, utils::stream::ReadyExt, Result};
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{events::StateEventType, EventId, RoomId};
@ -51,52 +51,46 @@ impl crate::Service for Service {
#[implement(Service)]
pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
const BUFSIZE: usize = size_of::<ShortEventId>();
if let Ok(shorteventid) = self.get_shorteventid(event_id).await {
return shorteventid;
}
let shorteventid = self.services.globals.next_count().unwrap();
debug_assert!(size_of_val(&shorteventid) == BUFSIZE, "buffer requirement changed");
self.db
.eventid_shorteventid
.raw_aput::<BUFSIZE, _, _>(event_id, shorteventid);
self.db
.shorteventid_eventid
.aput_raw::<BUFSIZE, _, _>(shorteventid, event_id);
shorteventid
self.create_shorteventid(event_id)
}
#[implement(Service)]
pub fn multi_get_or_create_shorteventid<'a>(
&'a self, event_ids: &'a [&EventId],
) -> impl Stream<Item = ShortEventId> + Send + 'a {
pub fn multi_get_or_create_shorteventid<'a, I>(&'a self, event_ids: I) -> impl Stream<Item = ShortEventId> + Send + '_
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
<I as Iterator>::Item: AsRef<[u8]> + Send + Sync + 'a,
{
self.db
.eventid_shorteventid
.get_batch(event_ids.iter())
.enumerate()
.map(|(i, result)| match result {
Ok(ref short) => utils::u64_from_u8(short),
Err(_) => {
const BUFSIZE: usize = size_of::<ShortEventId>();
let short = self.services.globals.next_count().unwrap();
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed");
self.db
.eventid_shorteventid
.raw_aput::<BUFSIZE, _, _>(event_ids[i], short);
self.db
.shorteventid_eventid
.aput_raw::<BUFSIZE, _, _>(short, event_ids[i]);
short
},
.get_batch(event_ids.clone())
.ready_scan(event_ids, |event_ids, result| {
event_ids.next().map(|event_id| (event_id, result))
})
.map(|(event_id, result)| match result {
Ok(ref short) => utils::u64_from_u8(short),
Err(_) => self.create_shorteventid(event_id),
})
}
#[implement(Service)]
fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
const BUFSIZE: usize = size_of::<ShortEventId>();
let short = self.services.globals.next_count().unwrap();
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed");
self.db
.eventid_shorteventid
.raw_aput::<BUFSIZE, _, _>(event_id, short);
self.db
.shorteventid_eventid
.aput_raw::<BUFSIZE, _, _>(short, event_id);
short
}
#[implement(Service)]
@ -154,13 +148,13 @@ pub async fn get_eventid_from_short(&self, shorteventid: ShortEventId) -> Result
}
#[implement(Service)]
pub async fn multi_get_eventid_from_short(&self, shorteventid: &[ShortEventId]) -> Vec<Result<Arc<EventId>>> {
pub async fn multi_get_eventid_from_short<I>(&self, shorteventid: I) -> Vec<Result<Arc<EventId>>>
where
I: Iterator<Item = ShortEventId> + Send,
{
const BUFSIZE: usize = size_of::<ShortEventId>();
let keys: Vec<[u8; BUFSIZE]> = shorteventid
.iter()
.map(|short| short.to_be_bytes())
.collect();
let keys: Vec<[u8; BUFSIZE]> = shorteventid.map(u64::to_be_bytes).collect();
self.db
.shorteventid_eventid