diff --git a/src/api/client/message.rs b/src/api/client/message.rs index f8818ebb..4d489c2f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -8,7 +8,7 @@ use conduwuit::{ ref_at, utils::{ IterStream, ReadyExt, - result::{FlatOk, LogErr}, + result::LogErr, stream::{BroadbandExt, TryIgnore, WidebandExt}, }, }; @@ -35,6 +35,7 @@ use ruma::{ }; use tracing::warn; +use super::utils::{count_to_token, parse_pagination_token as parse_token}; use crate::Ruma; /// list of safe and common non-state events to ignore if the user is ignored @@ -84,14 +85,14 @@ pub(crate) async fn get_message_events_route( let from: PduCount = body .from .as_deref() - .map(str::parse) + .map(parse_token) .transpose()? .unwrap_or_else(|| match body.dir { | Direction::Forward => PduCount::min(), | Direction::Backward => PduCount::max(), }); - let to: Option = body.to.as_deref().map(str::parse).flat_ok(); + let to: Option = body.to.as_deref().map(parse_token).transpose()?; let limit: usize = body .limit @@ -180,8 +181,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.to_string(), - end: next_token.as_ref().map(ToString::to_string), + start: count_to_token(from), + end: next_token.map(count_to_token), chunk, state, }) diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index be54e65f..e4be20b7 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -36,6 +36,7 @@ pub(super) mod typing; pub(super) mod unstable; pub(super) mod unversioned; pub(super) mod user_directory; +pub(super) mod utils; pub(super) mod voip; pub(super) mod well_known; diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 1aa34ada..f6d8fe9e 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -18,6 +18,7 @@ use ruma::{ events::{TimelineEventType, relation::RelationType}, }; +use super::utils::{count_to_token, parse_pagination_token as parse_token}; use crate::Ruma; /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` @@ -110,14 +111,14 @@ async fn paginate_relations_with_filter( dir: Direction, ) -> Result { let start: PduCount = from - .map(str::parse) + .map(parse_token) .transpose()? .unwrap_or_else(|| match dir { | Direction::Forward => PduCount::min(), | Direction::Backward => PduCount::max(), }); - let to: Option = to.map(str::parse).flat_ok(); + let to: Option = to.map(parse_token).transpose()?; // Use limit or else 30, with maximum 100 let limit: usize = limit @@ -129,6 +130,11 @@ async fn paginate_relations_with_filter( // Spec (v1.10) recommends depth of at least 3 let depth: u8 = if recurse { 3 } else { 1 }; + // Check if this is a thread request + let is_thread = filter_rel_type + .as_ref() + .is_some_and(|rel| *rel == RelationType::Thread); + let events: Vec<_> = services .rooms .pdu_metadata @@ -152,23 +158,58 @@ async fn paginate_relations_with_filter( .collect() .await; - let next_batch = match dir { - | Direction::Forward => events.last(), - | Direction::Backward => events.first(), + // For threads, check if we should include the root event + let mut root_event = None; + if is_thread && dir == Direction::Backward { + // Check if we've reached the beginning of the thread + // (fewer events than requested means we've exhausted the thread) + if events.len() < limit { + // Try to get the thread root event + if let Ok(root_pdu) = services.rooms.timeline.get_pdu(target).await { + // Check visibility + if services + .rooms + .state_accessor + .user_can_see_event(sender_user, room_id, target) + .await + { + // Store the root event to add to the response + root_event = Some(root_pdu); + } + } + } } - .map(at!(0)) - .as_ref() - .map(ToString::to_string); + + // Determine if there are more events to fetch + let has_more = if root_event.is_some() { + false // We've included the root, no more events + } else { + // Check if we got a full page of results (might be more) + events.len() >= limit + }; + + let next_batch = if has_more { + match dir { + | Direction::Forward => events.last(), + | Direction::Backward => events.first(), + } + .map(|(count, _)| count_to_token(*count)) + } else { + None + }; + + // Build the response chunk with thread root if needed + let chunk: Vec<_> = root_event + .into_iter() + .map(Event::into_format) + .chain(events.into_iter().map(at!(1)).map(Event::into_format)) + .collect(); Ok(get_relating_events::v1::Response { next_batch, prev_batch: from.map(Into::into), recursion_depth: recurse.then_some(depth.into()), - chunk: events - .into_iter() - .map(at!(1)) - .map(Event::into_format) - .collect(), + chunk, }) } diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 992073c6..fe07e41d 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -198,8 +198,8 @@ pub(crate) async fn login_route( .clone() .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); - // Generate a new token for the device - let token = utils::random_string(TOKEN_LENGTH); + // Generate a new token for the device (ensuring no collisions) + let token = services.users.generate_unique_token().await; // Determine if device_id was provided and exists in the db for this user let device_exists = if body.device_id.is_some() { diff --git a/src/api/client/utils.rs b/src/api/client/utils.rs new file mode 100644 index 00000000..cc941b95 --- /dev/null +++ b/src/api/client/utils.rs @@ -0,0 +1,28 @@ +use conduwuit::{ + Result, err, + matrix::pdu::{PduCount, ShortEventId}, +}; + +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +pub(crate) fn parse_pagination_token(token: &str) -> Result { + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +pub(crate) fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 01254c32..44afc3ef 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -5,6 +5,14 @@ use axum_extra::{ typed_header::TypedHeaderRejectionReason, }; use conduwuit::{Err, Error, Result, debug_error, err, warn}; +use futures::{ + TryFutureExt, + future::{ + Either::{Left, Right}, + select_ok, + }, + pin_mut, +}; use ruma::{ CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, api::{ @@ -54,17 +62,7 @@ pub(super) async fn auth( | None => request.query.access_token.as_deref(), }; - let token = if let Some(token) = token { - match services.appservice.find_from_token(token).await { - | Some(reg_info) => Token::Appservice(Box::new(reg_info)), - | _ => match services.users.find_from_token(token).await { - | Ok((user_id, device_id)) => Token::User((user_id, device_id)), - | _ => Token::Invalid, - }, - } - } else { - Token::None - }; + let token = find_token(services, token).await?; if metadata.authentication == AuthScheme::None { match metadata { @@ -342,3 +340,25 @@ async fn parse_x_matrix(request: &mut Request) -> Result { Ok(x_matrix) } + +async fn find_token(services: &Services, token: Option<&str>) -> Result { + let Some(token) = token else { + return Ok(Token::None); + }; + + let user_token = services.users.find_from_token(token).map_ok(Token::User); + + let appservice_token = services + .appservice + .find_from_token(token) + .map_ok(Box::new) + .map_ok(Token::Appservice); + + pin_mut!(user_token, appservice_token); + // Returns Ok if either token type succeeds, Err only if both fail + match select_ok([Left(user_token), Right(appservice_token)]).await { + | Err(e) if !e.is_not_found() => Err(e), + | Ok((token, _)) => Ok(token), + | _ => Ok(Token::Invalid), + } +} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 7be8a471..ebd798f6 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -4,14 +4,14 @@ mod registration_info; use std::{collections::BTreeMap, iter::IntoIterator, sync::Arc}; use async_trait::async_trait; -use conduwuit::{Result, err, utils::stream::IterStream}; +use conduwuit::{Err, Result, err, utils::stream::IterStream}; use database::Map; use futures::{Future, FutureExt, Stream, TryStreamExt}; use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration}; use tokio::sync::{RwLock, RwLockReadGuard}; pub use self::{namespace_regex::NamespaceRegex, registration_info::RegistrationInfo}; -use crate::{Dep, sending}; +use crate::{Dep, globals, sending, users}; pub struct Service { registration_info: RwLock, @@ -20,7 +20,9 @@ pub struct Service { } struct Services { + globals: Dep, sending: Dep, + users: Dep, } struct Data { @@ -35,7 +37,9 @@ impl crate::Service for Service { Ok(Arc::new(Self { registration_info: RwLock::new(BTreeMap::new()), services: Services { + globals: args.depend::("globals"), sending: args.depend::("sending"), + users: args.depend::("users"), }, db: Data { id_appserviceregistrations: args.db["id_appserviceregistrations"].clone(), @@ -44,23 +48,89 @@ impl crate::Service for Service { } async fn worker(self: Arc) -> Result { - // Inserting registrations into cache - self.iter_db_ids() - .try_for_each(async |appservice| { - self.registration_info - .write() - .await - .insert(appservice.0, appservice.1.try_into()?); + // First, collect all appservices to check for token conflicts + let appservices: Vec<(String, Registration)> = self.iter_db_ids().try_collect().await?; - Ok(()) - }) - .await + // Check for appservice-to-appservice token conflicts + for i in 0..appservices.len() { + for j in i.saturating_add(1)..appservices.len() { + if appservices[i].1.as_token == appservices[j].1.as_token { + return Err!(Database(error!( + "Token collision detected: Appservices '{}' and '{}' have the same token", + appservices[i].0, appservices[j].0 + ))); + } + } + } + + // Process each appservice + for (id, registration) in appservices { + // During startup, resolve any token collisions in favour of appservices + // by logging out conflicting user devices + if let Ok((user_id, device_id)) = self + .services + .users + .find_from_token(®istration.as_token) + .await + { + conduwuit::warn!( + "Token collision detected during startup: Appservice '{}' token was also \ + used by user '{}' device '{}'. Logging out the user device to resolve \ + conflict.", + id, + user_id.localpart(), + device_id + ); + + self.services + .users + .remove_device(&user_id, &device_id) + .await; + } + + self.start_appservice(id, registration).await?; + } + + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { + /// Starts an appservice, ensuring its sender_localpart user exists and is + /// active. Creates the user if it doesn't exist, or reactivates it if it + /// was deactivated. Then registers the appservice in memory for request + /// handling. + async fn start_appservice(&self, id: String, registration: Registration) -> Result { + let appservice_user_id = UserId::parse_with_server_name( + registration.sender_localpart.as_str(), + self.services.globals.server_name(), + )?; + + if !self.services.users.exists(&appservice_user_id).await { + self.services.users.create(&appservice_user_id, None)?; + } else if self + .services + .users + .is_deactivated(&appservice_user_id) + .await + .unwrap_or(false) + { + // Reactivate the appservice user if it was accidentally deactivated + self.services + .users + .set_password(&appservice_user_id, None)?; + } + + self.registration_info + .write() + .await + .insert(id, registration.try_into()?); + + Ok(()) + } + /// Registers an appservice and returns the ID to the caller pub async fn register_appservice( &self, @@ -68,15 +138,40 @@ impl Service { appservice_config_body: &str, ) -> Result { //TODO: Check for collisions between exclusive appservice namespaces - self.registration_info - .write() + + // Check for token collision with other appservices (allow re-registration of + // same appservice) + if let Ok(existing) = self.find_from_token(®istration.as_token).await { + if existing.registration.id != registration.id { + return Err(err!(Request(InvalidParam( + "Cannot register appservice: Token is already used by appservice '{}'. \ + Please generate a different token.", + existing.registration.id + )))); + } + } + + // Prevent token collision with existing user tokens + if self + .services + .users + .find_from_token(®istration.as_token) .await - .insert(registration.id.clone(), registration.clone().try_into()?); + .is_ok() + { + return Err(err!(Request(InvalidParam( + "Cannot register appservice: The provided token is already in use by a user \ + device. Please generate a different token for the appservice." + )))); + } self.db .id_appserviceregistrations .insert(®istration.id, appservice_config_body); + self.start_appservice(registration.id.clone(), registration.clone()) + .await?; + Ok(()) } @@ -113,12 +208,14 @@ impl Service { .map(|info| info.registration) } - pub async fn find_from_token(&self, token: &str) -> Option { + /// Returns Result to match users::find_from_token for select_ok usage + pub async fn find_from_token(&self, token: &str) -> Result { self.read() .await .values() .find(|info| info.registration.as_token == token) .cloned() + .ok_or_else(|| err!(Request(NotFound("Appservice token not found")))) } /// Checks if a given user id matches any exclusive appservice regex diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index c1376cb0..a746b4cc 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -61,9 +61,12 @@ impl Data { from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { + // Query from exact position then filter excludes it (saturating_inc could skip + // events at min/max boundaries) + let from_unsigned = from.into_unsigned(); let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); - current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes()); + current.extend(from_unsigned.to_be_bytes()); let current = current.as_slice(); match dir { | Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(), @@ -73,6 +76,17 @@ impl Data { .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) .map(|to_from| u64_from_u8(&to_from[8..16])) .map(PduCount::from_unsigned) + .ready_filter(move |count| { + if from == PduCount::min() || from == PduCount::max() { + true + } else { + let count_unsigned = count.into_unsigned(); + match dir { + | Direction::Forward => count_unsigned > from_unsigned, + | Direction::Backward => count_unsigned < from_unsigned, + } + } + }) .wide_filter_map(move |shorteventid| async move { let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into(); diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index d2dfccd9..eb54660e 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -19,7 +19,7 @@ use ruma::{ use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::{Dep, account_data, admin, globals, rooms}; +use crate::{Dep, account_data, admin, appservice, globals, rooms}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UserSuspension { @@ -40,6 +40,7 @@ struct Services { server: Arc, account_data: Dep, admin: Dep, + appservice: Dep, globals: Dep, state_accessor: Dep, state_cache: Dep, @@ -76,6 +77,7 @@ impl crate::Service for Service { server: args.server.clone(), account_data: args.depend::("account_data"), admin: args.depend::("admin"), + appservice: args.depend::("appservice"), globals: args.depend::("globals"), state_accessor: args .depend::("rooms::state_accessor"), @@ -391,6 +393,31 @@ impl Service { self.db.userdeviceid_token.qry(&key).await.deserialized() } + /// Generate a unique access token that doesn't collide with existing tokens + pub async fn generate_unique_token(&self) -> String { + loop { + let token = utils::random_string(32); + + // Check for collision with appservice tokens + if self + .services + .appservice + .find_from_token(&token) + .await + .is_ok() + { + continue; + } + + // Check for collision with user tokens + if self.db.token_userdeviceid.get(&token).await.is_ok() { + continue; + } + + return token; + } + } + /// Replaces the access token of one device. pub async fn set_token( &self, @@ -407,6 +434,19 @@ impl Service { ))); } + // Check for token collision with appservices + if self + .services + .appservice + .find_from_token(token) + .await + .is_ok() + { + return Err!(Request(InvalidParam( + "Token conflicts with an existing appservice token" + ))); + } + // Remove old token if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { self.db.token_userdeviceid.remove(&old_token);