From e820551f6290bcaef46a544fe1eb6a11476df41c Mon Sep 17 00:00:00 2001 From: Tom Foster Date: Sun, 10 Aug 2025 14:37:26 +0100 Subject: [PATCH 1/4] fix(appservice): create sender_localpart user during appservice startup Fixes #813: Application services were unable to work because their sender_localpart user was never created in the database, preventing authentication. This fix ensures the appservice user account is created when: - The server starts up and loads existing appservices from the database - A new appservice is registered via the admin command Additionally, if an appservice user has been accidentally deactivated, it will be automatically reactivated when the appservice starts. The solution centralises all appservice startup logic into a single `start_appservice` helper method, eliminating code duplication between the registration and startup paths. --- src/service/appservice/mod.rs | 89 +++++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 7be8a471..de6fcfac 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -11,7 +11,7 @@ use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration}; use tokio::sync::{RwLock, RwLockReadGuard}; pub use self::{namespace_regex::NamespaceRegex, registration_info::RegistrationInfo}; -use crate::{Dep, sending}; +use crate::{Dep, globals, sending, users}; pub struct Service { registration_info: RwLock, @@ -20,7 +20,9 @@ pub struct Service { } struct Services { + globals: Dep, sending: Dep, + users: Dep, } struct Data { @@ -35,7 +37,9 @@ impl crate::Service for Service { Ok(Arc::new(Self { registration_info: RwLock::new(BTreeMap::new()), services: Services { + globals: args.depend::("globals"), sending: args.depend::("sending"), + users: args.depend::("users"), }, db: Data { id_appserviceregistrations: args.db["id_appserviceregistrations"].clone(), @@ -44,15 +48,34 @@ impl crate::Service for Service { } async fn worker(self: Arc) -> Result { - // Inserting registrations into cache self.iter_db_ids() .try_for_each(async |appservice| { - self.registration_info - .write() - .await - .insert(appservice.0, appservice.1.try_into()?); + let (id, registration) = appservice; - Ok(()) + // During startup, resolve any token collisions in favour of appservices + // by logging out conflicting user devices + if let Ok((user_id, device_id)) = self + .services + .users + .find_from_token(®istration.as_token) + .await + { + conduwuit::warn!( + "Token collision detected during startup: Appservice '{}' token was \ + also used by user '{}' device '{}'. Logging out the user device to \ + resolve conflict.", + id, + user_id.localpart(), + device_id + ); + + self.services + .users + .remove_device(&user_id, &device_id) + .await; + } + + self.start_appservice(id, registration).await }) .await } @@ -61,6 +84,39 @@ impl crate::Service for Service { } impl Service { + /// Starts an appservice, ensuring its sender_localpart user exists and is + /// active. Creates the user if it doesn't exist, or reactivates it if it + /// was deactivated. Then registers the appservice in memory for request + /// handling. + async fn start_appservice(&self, id: String, registration: Registration) -> Result { + let appservice_user_id = UserId::parse_with_server_name( + registration.sender_localpart.as_str(), + self.services.globals.server_name(), + )?; + + if !self.services.users.exists(&appservice_user_id).await { + self.services.users.create(&appservice_user_id, None)?; + } else if self + .services + .users + .is_deactivated(&appservice_user_id) + .await + .unwrap_or(false) + { + // Reactivate the appservice user if it was accidentally deactivated + self.services + .users + .set_password(&appservice_user_id, None)?; + } + + self.registration_info + .write() + .await + .insert(id, registration.try_into()?); + + Ok(()) + } + /// Registers an appservice and returns the ID to the caller pub async fn register_appservice( &self, @@ -68,15 +124,28 @@ impl Service { appservice_config_body: &str, ) -> Result { //TODO: Check for collisions between exclusive appservice namespaces - self.registration_info - .write() + + // Prevent token collision with existing user tokens + if self + .services + .users + .find_from_token(®istration.as_token) .await - .insert(registration.id.clone(), registration.clone().try_into()?); + .is_ok() + { + return err!(Request(InvalidParam( + "Cannot register appservice: The provided token is already in use by a user \ + device. Please generate a different token for the appservice." + ))); + } self.db .id_appserviceregistrations .insert(®istration.id, appservice_config_body); + self.start_appservice(registration.id.clone(), registration.clone()) + .await?; + Ok(()) } From d1ebcfaf0bdcc039df3e484acf6eb2428d81ba26 Mon Sep 17 00:00:00 2001 From: Tom Foster Date: Sun, 10 Aug 2025 14:38:54 +0100 Subject: [PATCH 2/4] fix(auth): prevent token collisions and optimise lookups Ensures access tokens are unique across both user and appservice tables to prevent authentication ambiguity and potential security issues. Changes: - On startup, automatically logout any user devices using tokens that conflict with appservice tokens (resolves in favour of appservices) and log a warning with affected user/device details - When creating new user tokens, check for conflicts with appservice tokens and generate a new token if a collision would occur - When registering new appservices, reject registration if the token is already in use by a user device - Use futures::select_ok to race token lookups concurrently for better performance (adapted from tuwunel commit 066097a8) This fix-forward approach resolves existing token collisions on startup whilst preventing new ones from being created, without breaking existing valid authentications. The find_token optimisation is adapted from tuwunel (matrix-construct/tuwunel) commit 066097a8: "Optimize user and appservice token queries" by Jason Volk. --- src/api/router/auth.rs | 41 +++++++++++++++++++++++++---------- src/service/appservice/mod.rs | 7 +++--- src/service/users/mod.rs | 28 +++++++++++++++++++++--- 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 01254c32..5088e699 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -5,6 +5,14 @@ use axum_extra::{ typed_header::TypedHeaderRejectionReason, }; use conduwuit::{Err, Error, Result, debug_error, err, warn}; +use futures::{ + TryFutureExt, + future::{ + Either::{Left, Right}, + select_ok, + }, + pin_mut, +}; use ruma::{ CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, api::{ @@ -54,17 +62,7 @@ pub(super) async fn auth( | None => request.query.access_token.as_deref(), }; - let token = if let Some(token) = token { - match services.appservice.find_from_token(token).await { - | Some(reg_info) => Token::Appservice(Box::new(reg_info)), - | _ => match services.users.find_from_token(token).await { - | Ok((user_id, device_id)) => Token::User((user_id, device_id)), - | _ => Token::Invalid, - }, - } - } else { - Token::None - }; + let token = find_token(services, token).await?; if metadata.authentication == AuthScheme::None { match metadata { @@ -342,3 +340,24 @@ async fn parse_x_matrix(request: &mut Request) -> Result { Ok(x_matrix) } + +async fn find_token(services: &Services, token: Option<&str>) -> Result { + let Some(token) = token else { + return Ok(Token::None); + }; + + let user_token = services.users.find_from_token(token).map_ok(Token::User); + + let appservice_token = services + .appservice + .find_from_token(token) + .map_ok(Box::new) + .map_ok(Token::Appservice); + + pin_mut!(user_token, appservice_token); + match select_ok([Left(user_token), Right(appservice_token)]).await { + | Err(e) if !e.is_not_found() => Err(e), + | Ok((token, _)) => Ok(token), + | _ => Ok(Token::Invalid), + } +} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index de6fcfac..ad9c4a3f 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -133,10 +133,10 @@ impl Service { .await .is_ok() { - return err!(Request(InvalidParam( + return Err(err!(Request(InvalidParam( "Cannot register appservice: The provided token is already in use by a user \ device. Please generate a different token for the appservice." - ))); + )))); } self.db @@ -182,12 +182,13 @@ impl Service { .map(|info| info.registration) } - pub async fn find_from_token(&self, token: &str) -> Option { + pub async fn find_from_token(&self, token: &str) -> Result { self.read() .await .values() .find(|info| info.registration.as_token == token) .cloned() + .ok_or_else(|| err!(Request(NotFound("Appservice token not found")))) } /// Checks if a given user id matches any exclusive appservice regex diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index d2dfccd9..0aacc0e1 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -19,7 +19,7 @@ use ruma::{ use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::{Dep, account_data, admin, globals, rooms}; +use crate::{Dep, account_data, admin, appservice, globals, rooms}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UserSuspension { @@ -40,6 +40,7 @@ struct Services { server: Arc, account_data: Dep, admin: Dep, + appservice: Dep, globals: Dep, state_accessor: Dep, state_cache: Dep, @@ -76,6 +77,7 @@ impl crate::Service for Service { server: args.server.clone(), account_data: args.depend::("account_data"), admin: args.depend::("admin"), + appservice: args.depend::("appservice"), globals: args.depend::("globals"), state_accessor: args .depend::("rooms::state_accessor"), @@ -407,6 +409,26 @@ impl Service { ))); } + // Prevent token collisions with appservice tokens + let final_token = if self + .services + .appservice + .find_from_token(token) + .await + .is_ok() + { + let new_token = utils::random_string(32); + conduwuit::debug_warn!( + "Token collision prevented: Generated new token for user '{}' device '{}' \ + (original token conflicted with an appservice)", + user_id.localpart(), + device_id + ); + new_token + } else { + token.to_owned() + }; + // Remove old token if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { self.db.token_userdeviceid.remove(&old_token); @@ -414,8 +436,8 @@ impl Service { } // Assign token to user device combination - self.db.userdeviceid_token.put_raw(key, token); - self.db.token_userdeviceid.raw_put(token, key); + self.db.userdeviceid_token.put_raw(key, &final_token); + self.db.token_userdeviceid.raw_put(&final_token, key); Ok(()) } From 9286838d23aa5b60bb0a030d403879406732bbc9 Mon Sep 17 00:00:00 2001 From: Tom Foster Date: Sun, 10 Aug 2025 18:22:20 +0100 Subject: [PATCH 3/4] fix(relations): improve thread pagination and include root event Replace unreliable PduCount pagination tokens with ShortEventId throughout the relations and messages endpoints. ShortEventId provides stable, unique identifiers that persist across server restarts and database operations. Key improvements: - Add token parsing helpers that try ShortEventId first, fall back to PduCount for backwards compatibility - Include thread root event when paginating backwards to thread start - Fix off-by-one error in get_relations that was returning the starting event in results - Only return next_batch/prev_batch tokens when more events are available, preventing clients from making unnecessary requests at thread boundaries - Ensure consistent token format between /relations, /messages, and /sync endpoints for interoperability This fixes duplicate events when scrolling at thread boundaries and ensures the thread root message is visible when viewing a thread, matching expected client behaviour. --- src/api/client/message.rs | 60 +++++++++--- src/api/client/relations.rs | 129 +++++++++++++++++++++---- src/service/rooms/pdu_metadata/data.rs | 14 ++- 3 files changed, 168 insertions(+), 35 deletions(-) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index f8818ebb..95a135e1 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,14 +1,14 @@ use axum::extract::State; use conduwuit::{ - Err, Result, at, + Err, Result, at, err, matrix::{ event::{Event, Matches}, - pdu::PduCount, + pdu::{PduCount, ShortEventId}, }, ref_at, utils::{ IterStream, ReadyExt, - result::{FlatOk, LogErr}, + result::LogErr, stream::{BroadbandExt, TryIgnore, WidebandExt}, }, }; @@ -61,6 +61,39 @@ const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[ const LIMIT_MAX: usize = 100; const LIMIT_DEFAULT: usize = 10; +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +async fn parse_pagination_token( + _services: &Services, + _room_id: &RoomId, + token: Option<&str>, + default: PduCount, +) -> Result { + let Some(token) = token else { + return Ok(default); + }; + + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} + /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// /// Allows paginating through room history. @@ -81,17 +114,18 @@ pub(crate) async fn get_message_events_route( return Err!(Request(Forbidden("Room does not exist to this server"))); } - let from: PduCount = body - .from - .as_deref() - .map(str::parse) - .transpose()? - .unwrap_or_else(|| match body.dir { + let from: PduCount = + parse_pagination_token(&services, room_id, body.from.as_deref(), match body.dir { | Direction::Forward => PduCount::min(), | Direction::Backward => PduCount::max(), - }); + }) + .await?; - let to: Option = body.to.as_deref().map(str::parse).flat_ok(); + let to: Option = if let Some(to_str) = body.to.as_deref() { + Some(parse_pagination_token(&services, room_id, Some(to_str), PduCount::min()).await?) + } else { + None + }; let limit: usize = body .limit @@ -180,8 +214,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.to_string(), - end: next_token.as_ref().map(ToString::to_string), + start: count_to_token(from), + end: next_token.map(count_to_token), chunk, state, }) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 1aa34ada..48bcde20 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,7 +1,11 @@ use axum::extract::State; use conduwuit::{ - Result, at, - matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, + Result, at, err, + matrix::{ + Event, + event::RelationTypeEqual, + pdu::{PduCount, ShortEventId}, + }, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; use conduwuit_service::Services; @@ -20,6 +24,40 @@ use ruma::{ use crate::Ruma; +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +async fn parse_pagination_token( + _services: &Services, + _room_id: &RoomId, + token: Option<&str>, + default: PduCount, +) -> Result { + let Some(token) = token else { + return Ok(default); + }; + + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + // The shorteventid IS the count value, just need to wrap it + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} + /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, @@ -109,15 +147,17 @@ async fn paginate_relations_with_filter( recurse: bool, dir: Direction, ) -> Result { - let start: PduCount = from - .map(str::parse) - .transpose()? - .unwrap_or_else(|| match dir { - | Direction::Forward => PduCount::min(), - | Direction::Backward => PduCount::max(), - }); + let start: PduCount = parse_pagination_token(services, room_id, from, match dir { + | Direction::Forward => PduCount::min(), + | Direction::Backward => PduCount::max(), + }) + .await?; - let to: Option = to.map(str::parse).flat_ok(); + let to: Option = if let Some(to_str) = to { + Some(parse_pagination_token(services, room_id, Some(to_str), PduCount::min()).await?) + } else { + None + }; // Use limit or else 30, with maximum 100 let limit: usize = limit @@ -129,6 +169,11 @@ async fn paginate_relations_with_filter( // Spec (v1.10) recommends depth of at least 3 let depth: u8 = if recurse { 3 } else { 1 }; + // Check if this is a thread request + let is_thread = filter_rel_type + .as_ref() + .is_some_and(|rel| *rel == RelationType::Thread); + let events: Vec<_> = services .rooms .pdu_metadata @@ -152,23 +197,65 @@ async fn paginate_relations_with_filter( .collect() .await; - let next_batch = match dir { - | Direction::Forward => events.last(), - | Direction::Backward => events.first(), + // For threads, check if we should include the root event + let mut root_event = None; + if is_thread && dir == Direction::Backward { + // Check if we've reached the beginning of the thread + // (fewer events than requested means we've exhausted the thread) + if events.len() < limit { + // Try to get the thread root event + if let Ok(root_pdu) = services.rooms.timeline.get_pdu(target).await { + // Check visibility + if services + .rooms + .state_accessor + .user_can_see_event(sender_user, room_id, target) + .await + { + // Store the root event to add to the response + root_event = Some(root_pdu); + } + } + } } - .map(at!(0)) - .as_ref() - .map(ToString::to_string); + + // Determine if there are more events to fetch + let has_more = if root_event.is_some() { + false // We've included the root, no more events + } else { + // Check if we got a full page of results (might be more) + events.len() >= limit + }; + + let next_batch = if has_more { + match dir { + | Direction::Forward => events.last(), + | Direction::Backward => events.first(), + } + .map(|(count, _)| count_to_token(*count)) + } else { + None + }; + + // Build the response chunk with thread root if needed + let chunk: Vec<_> = if let Some(root) = root_event { + // Add root event at the beginning for backward pagination + std::iter::once(root.into_format()) + .chain(events.into_iter().map(at!(1)).map(Event::into_format)) + .collect() + } else { + events + .into_iter() + .map(at!(1)) + .map(Event::into_format) + .collect() + }; Ok(get_relating_events::v1::Response { next_batch, prev_batch: from.map(Into::into), recursion_depth: recurse.then_some(depth.into()), - chunk: events - .into_iter() - .map(at!(1)) - .map(Event::into_format) - .collect(), + chunk, }) } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index c1376cb0..f9cc80a0 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -61,9 +61,10 @@ impl Data { from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { + let from_unsigned = from.into_unsigned(); let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); - current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes()); + current.extend(from_unsigned.to_be_bytes()); let current = current.as_slice(); match dir { | Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(), @@ -73,6 +74,17 @@ impl Data { .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) .map(|to_from| u64_from_u8(&to_from[8..16])) .map(PduCount::from_unsigned) + .ready_filter(move |count| { + if from == PduCount::min() || from == PduCount::max() { + true + } else { + let count_unsigned = count.into_unsigned(); + match dir { + | Direction::Forward => count_unsigned > from_unsigned, + | Direction::Backward => count_unsigned < from_unsigned, + } + } + }) .wide_filter_map(move |shorteventid| async move { let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into(); From 583cb924f1510960702f626c7397cb9e87e1722a Mon Sep 17 00:00:00 2001 From: Tom Foster Date: Mon, 11 Aug 2025 06:24:29 +0100 Subject: [PATCH 4/4] refactor: address code review feedback for auth and pagination improvements - Extract duplicated thread/message pagination functions to shared utils module - Refactor pagination token parsing to use Option combinators instead of defaults - Split access token generation from assignment for clearer error handling - Add appservice token collision detection at startup and registration - Allow appservice re-registration with same token (for config updates) - Simplify thread relation chunk building using iterator chaining - Fix saturating_inc edge case in relation queries with explicit filtering - Add concise comments explaining non-obvious behaviour choices --- src/api/client/message.rs | 55 ++++------------- src/api/client/mod.rs | 1 + src/api/client/relations.rs | 78 +++++------------------- src/api/client/session.rs | 4 +- src/api/client/utils.rs | 28 +++++++++ src/api/router/auth.rs | 1 + src/service/appservice/mod.rs | 83 +++++++++++++++++--------- src/service/rooms/pdu_metadata/data.rs | 2 + src/service/users/mod.rs | 48 ++++++++++----- 9 files changed, 149 insertions(+), 151 deletions(-) create mode 100644 src/api/client/utils.rs diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 95a135e1..4d489c2f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,9 +1,9 @@ use axum::extract::State; use conduwuit::{ - Err, Result, at, err, + Err, Result, at, matrix::{ event::{Event, Matches}, - pdu::{PduCount, ShortEventId}, + pdu::PduCount, }, ref_at, utils::{ @@ -35,6 +35,7 @@ use ruma::{ }; use tracing::warn; +use super::utils::{count_to_token, parse_pagination_token as parse_token}; use crate::Ruma; /// list of safe and common non-state events to ignore if the user is ignored @@ -61,39 +62,6 @@ const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[ const LIMIT_MAX: usize = 100; const LIMIT_DEFAULT: usize = 10; -/// Parse a pagination token, trying ShortEventId first, then falling back to -/// PduCount -async fn parse_pagination_token( - _services: &Services, - _room_id: &RoomId, - token: Option<&str>, - default: PduCount, -) -> Result { - let Some(token) = token else { - return Ok(default); - }; - - // Try parsing as ShortEventId first - if let Ok(shorteventid) = token.parse::() { - // ShortEventId maps directly to a PduCount in our database - Ok(PduCount::Normal(shorteventid)) - } else if let Ok(count) = token.parse::() { - // Fallback to PduCount for backwards compatibility - Ok(PduCount::Normal(count)) - } else if let Ok(count) = token.parse::() { - // Also handle negative counts for backfilled events - Ok(PduCount::from_signed(count)) - } else { - Err(err!(Request(InvalidParam("Invalid pagination token")))) - } -} - -/// Convert a PduCount to a token string (using the underlying ShortEventId) -fn count_to_token(count: PduCount) -> String { - // The PduCount's unsigned value IS the ShortEventId - count.into_unsigned().to_string() -} - /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// /// Allows paginating through room history. @@ -114,18 +82,17 @@ pub(crate) async fn get_message_events_route( return Err!(Request(Forbidden("Room does not exist to this server"))); } - let from: PduCount = - parse_pagination_token(&services, room_id, body.from.as_deref(), match body.dir { + let from: PduCount = body + .from + .as_deref() + .map(parse_token) + .transpose()? + .unwrap_or_else(|| match body.dir { | Direction::Forward => PduCount::min(), | Direction::Backward => PduCount::max(), - }) - .await?; + }); - let to: Option = if let Some(to_str) = body.to.as_deref() { - Some(parse_pagination_token(&services, room_id, Some(to_str), PduCount::min()).await?) - } else { - None - }; + let to: Option = body.to.as_deref().map(parse_token).transpose()?; let limit: usize = body .limit diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index be54e65f..e4be20b7 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -36,6 +36,7 @@ pub(super) mod typing; pub(super) mod unstable; pub(super) mod unversioned; pub(super) mod user_directory; +pub(super) mod utils; pub(super) mod voip; pub(super) mod well_known; diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 48bcde20..f6d8fe9e 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,11 +1,7 @@ use axum::extract::State; use conduwuit::{ - Result, at, err, - matrix::{ - Event, - event::RelationTypeEqual, - pdu::{PduCount, ShortEventId}, - }, + Result, at, + matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, }; use conduwuit_service::Services; @@ -22,42 +18,9 @@ use ruma::{ events::{TimelineEventType, relation::RelationType}, }; +use super::utils::{count_to_token, parse_pagination_token as parse_token}; use crate::Ruma; -/// Parse a pagination token, trying ShortEventId first, then falling back to -/// PduCount -async fn parse_pagination_token( - _services: &Services, - _room_id: &RoomId, - token: Option<&str>, - default: PduCount, -) -> Result { - let Some(token) = token else { - return Ok(default); - }; - - // Try parsing as ShortEventId first - if let Ok(shorteventid) = token.parse::() { - // ShortEventId maps directly to a PduCount in our database - // The shorteventid IS the count value, just need to wrap it - Ok(PduCount::Normal(shorteventid)) - } else if let Ok(count) = token.parse::() { - // Fallback to PduCount for backwards compatibility - Ok(PduCount::Normal(count)) - } else if let Ok(count) = token.parse::() { - // Also handle negative counts for backfilled events - Ok(PduCount::from_signed(count)) - } else { - Err(err!(Request(InvalidParam("Invalid pagination token")))) - } -} - -/// Convert a PduCount to a token string (using the underlying ShortEventId) -fn count_to_token(count: PduCount) -> String { - // The PduCount's unsigned value IS the ShortEventId - count.into_unsigned().to_string() -} - /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, @@ -147,17 +110,15 @@ async fn paginate_relations_with_filter( recurse: bool, dir: Direction, ) -> Result { - let start: PduCount = parse_pagination_token(services, room_id, from, match dir { - | Direction::Forward => PduCount::min(), - | Direction::Backward => PduCount::max(), - }) - .await?; + let start: PduCount = from + .map(parse_token) + .transpose()? + .unwrap_or_else(|| match dir { + | Direction::Forward => PduCount::min(), + | Direction::Backward => PduCount::max(), + }); - let to: Option = if let Some(to_str) = to { - Some(parse_pagination_token(services, room_id, Some(to_str), PduCount::min()).await?) - } else { - None - }; + let to: Option = to.map(parse_token).transpose()?; // Use limit or else 30, with maximum 100 let limit: usize = limit @@ -238,18 +199,11 @@ async fn paginate_relations_with_filter( }; // Build the response chunk with thread root if needed - let chunk: Vec<_> = if let Some(root) = root_event { - // Add root event at the beginning for backward pagination - std::iter::once(root.into_format()) - .chain(events.into_iter().map(at!(1)).map(Event::into_format)) - .collect() - } else { - events - .into_iter() - .map(at!(1)) - .map(Event::into_format) - .collect() - }; + let chunk: Vec<_> = root_event + .into_iter() + .map(Event::into_format) + .chain(events.into_iter().map(at!(1)).map(Event::into_format)) + .collect(); Ok(get_relating_events::v1::Response { next_batch, diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 992073c6..fe07e41d 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -198,8 +198,8 @@ pub(crate) async fn login_route( .clone() .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); - // Generate a new token for the device - let token = utils::random_string(TOKEN_LENGTH); + // Generate a new token for the device (ensuring no collisions) + let token = services.users.generate_unique_token().await; // Determine if device_id was provided and exists in the db for this user let device_exists = if body.device_id.is_some() { diff --git a/src/api/client/utils.rs b/src/api/client/utils.rs new file mode 100644 index 00000000..cc941b95 --- /dev/null +++ b/src/api/client/utils.rs @@ -0,0 +1,28 @@ +use conduwuit::{ + Result, err, + matrix::pdu::{PduCount, ShortEventId}, +}; + +/// Parse a pagination token, trying ShortEventId first, then falling back to +/// PduCount +pub(crate) fn parse_pagination_token(token: &str) -> Result { + // Try parsing as ShortEventId first + if let Ok(shorteventid) = token.parse::() { + // ShortEventId maps directly to a PduCount in our database + Ok(PduCount::Normal(shorteventid)) + } else if let Ok(count) = token.parse::() { + // Fallback to PduCount for backwards compatibility + Ok(PduCount::Normal(count)) + } else if let Ok(count) = token.parse::() { + // Also handle negative counts for backfilled events + Ok(PduCount::from_signed(count)) + } else { + Err(err!(Request(InvalidParam("Invalid pagination token")))) + } +} + +/// Convert a PduCount to a token string (using the underlying ShortEventId) +pub(crate) fn count_to_token(count: PduCount) -> String { + // The PduCount's unsigned value IS the ShortEventId + count.into_unsigned().to_string() +} diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 5088e699..44afc3ef 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -355,6 +355,7 @@ async fn find_token(services: &Services, token: Option<&str>) -> Result { .map_ok(Token::Appservice); pin_mut!(user_token, appservice_token); + // Returns Ok if either token type succeeds, Err only if both fail match select_ok([Left(user_token), Right(appservice_token)]).await { | Err(e) if !e.is_not_found() => Err(e), | Ok((token, _)) => Ok(token), diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index ad9c4a3f..ebd798f6 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -4,7 +4,7 @@ mod registration_info; use std::{collections::BTreeMap, iter::IntoIterator, sync::Arc}; use async_trait::async_trait; -use conduwuit::{Result, err, utils::stream::IterStream}; +use conduwuit::{Err, Result, err, utils::stream::IterStream}; use database::Map; use futures::{Future, FutureExt, Stream, TryStreamExt}; use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration}; @@ -48,36 +48,50 @@ impl crate::Service for Service { } async fn worker(self: Arc) -> Result { - self.iter_db_ids() - .try_for_each(async |appservice| { - let (id, registration) = appservice; + // First, collect all appservices to check for token conflicts + let appservices: Vec<(String, Registration)> = self.iter_db_ids().try_collect().await?; - // During startup, resolve any token collisions in favour of appservices - // by logging out conflicting user devices - if let Ok((user_id, device_id)) = self - .services - .users - .find_from_token(®istration.as_token) - .await - { - conduwuit::warn!( - "Token collision detected during startup: Appservice '{}' token was \ - also used by user '{}' device '{}'. Logging out the user device to \ - resolve conflict.", - id, - user_id.localpart(), - device_id - ); - - self.services - .users - .remove_device(&user_id, &device_id) - .await; + // Check for appservice-to-appservice token conflicts + for i in 0..appservices.len() { + for j in i.saturating_add(1)..appservices.len() { + if appservices[i].1.as_token == appservices[j].1.as_token { + return Err!(Database(error!( + "Token collision detected: Appservices '{}' and '{}' have the same token", + appservices[i].0, appservices[j].0 + ))); } + } + } - self.start_appservice(id, registration).await - }) - .await + // Process each appservice + for (id, registration) in appservices { + // During startup, resolve any token collisions in favour of appservices + // by logging out conflicting user devices + if let Ok((user_id, device_id)) = self + .services + .users + .find_from_token(®istration.as_token) + .await + { + conduwuit::warn!( + "Token collision detected during startup: Appservice '{}' token was also \ + used by user '{}' device '{}'. Logging out the user device to resolve \ + conflict.", + id, + user_id.localpart(), + device_id + ); + + self.services + .users + .remove_device(&user_id, &device_id) + .await; + } + + self.start_appservice(id, registration).await?; + } + + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -125,6 +139,18 @@ impl Service { ) -> Result { //TODO: Check for collisions between exclusive appservice namespaces + // Check for token collision with other appservices (allow re-registration of + // same appservice) + if let Ok(existing) = self.find_from_token(®istration.as_token).await { + if existing.registration.id != registration.id { + return Err(err!(Request(InvalidParam( + "Cannot register appservice: Token is already used by appservice '{}'. \ + Please generate a different token.", + existing.registration.id + )))); + } + } + // Prevent token collision with existing user tokens if self .services @@ -182,6 +208,7 @@ impl Service { .map(|info| info.registration) } + /// Returns Result to match users::find_from_token for select_ok usage pub async fn find_from_token(&self, token: &str) -> Result { self.read() .await diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index f9cc80a0..a746b4cc 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -61,6 +61,8 @@ impl Data { from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { + // Query from exact position then filter excludes it (saturating_inc could skip + // events at min/max boundaries) let from_unsigned = from.into_unsigned(); let mut current = ArrayVec::::new(); current.extend(target.to_be_bytes()); diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 0aacc0e1..eb54660e 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -393,6 +393,31 @@ impl Service { self.db.userdeviceid_token.qry(&key).await.deserialized() } + /// Generate a unique access token that doesn't collide with existing tokens + pub async fn generate_unique_token(&self) -> String { + loop { + let token = utils::random_string(32); + + // Check for collision with appservice tokens + if self + .services + .appservice + .find_from_token(&token) + .await + .is_ok() + { + continue; + } + + // Check for collision with user tokens + if self.db.token_userdeviceid.get(&token).await.is_ok() { + continue; + } + + return token; + } + } + /// Replaces the access token of one device. pub async fn set_token( &self, @@ -409,25 +434,18 @@ impl Service { ))); } - // Prevent token collisions with appservice tokens - let final_token = if self + // Check for token collision with appservices + if self .services .appservice .find_from_token(token) .await .is_ok() { - let new_token = utils::random_string(32); - conduwuit::debug_warn!( - "Token collision prevented: Generated new token for user '{}' device '{}' \ - (original token conflicted with an appservice)", - user_id.localpart(), - device_id - ); - new_token - } else { - token.to_owned() - }; + return Err!(Request(InvalidParam( + "Token conflicts with an existing appservice token" + ))); + } // Remove old token if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { @@ -436,8 +454,8 @@ impl Service { } // Assign token to user device combination - self.db.userdeviceid_token.put_raw(key, &final_token); - self.db.token_userdeviceid.raw_put(&final_token, key); + self.db.userdeviceid_token.put_raw(key, token); + self.db.token_userdeviceid.raw_put(token, key); Ok(()) }