apply new rustfmt.toml changes, fix some clippy lints

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-12-15 00:05:47 -05:00
commit 77e0b76408
No known key found for this signature in database
296 changed files with 7147 additions and 4300 deletions

View file

@ -9,8 +9,8 @@ use database::{Deserialized, Handle, Interfix, Json, Map};
use futures::{Stream, StreamExt, TryFutureExt};
use ruma::{
events::{
AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, GlobalAccountDataEventType,
RoomAccountDataEventType,
AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent,
GlobalAccountDataEventType, RoomAccountDataEventType,
},
serde::Raw,
RoomId, UserId,
@ -54,7 +54,11 @@ impl crate::Service for Service {
#[allow(clippy::needless_pass_by_value)]
#[implement(Service)]
pub async fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value,
&self,
room_id: Option<&RoomId>,
user_id: &UserId,
event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
if data.get("type").is_none() || data.get("content").is_none() {
return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
@ -91,7 +95,12 @@ where
/// Searches the global account data for a specific kind.
#[implement(Service)]
pub async fn get_room<T>(&self, room_id: &RoomId, user_id: &UserId, kind: RoomAccountDataEventType) -> Result<T>
pub async fn get_room<T>(
&self,
room_id: &RoomId,
user_id: &UserId,
kind: RoomAccountDataEventType,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
@ -101,7 +110,12 @@ where
}
#[implement(Service)]
pub async fn get_raw(&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &str) -> Result<Handle<'_>> {
pub async fn get_raw(
&self,
room_id: Option<&RoomId>,
user_id: &UserId,
kind: &str,
) -> Result<Handle<'_>> {
let key = (room_id, user_id, kind.to_owned());
self.db
.roomusertype_roomuserdataid
@ -113,7 +127,10 @@ pub async fn get_raw(&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &s
/// Returns all changes to the account data that happened after `since`.
#[implement(Service)]
pub fn changes_since<'a>(
&'a self, room_id: Option<&'a RoomId>, user_id: &'a UserId, since: u64,
&'a self,
room_id: Option<&'a RoomId>,
user_id: &'a UserId,
since: u64,
) -> impl Stream<Item = AnyRawAccountDataEvent> + Send + 'a {
let prefix = (room_id, user_id, Interfix);
let prefix = database::serialize_key(prefix).expect("failed to serialize prefix");
@ -128,8 +145,10 @@ pub fn changes_since<'a>(
.ready_take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(_, v)| {
match room_id {
Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(v).map(AnyRawAccountDataEvent::Room),
None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(v).map(AnyRawAccountDataEvent::Global),
| Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(v)
.map(AnyRawAccountDataEvent::Room),
| None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(v)
.map(AnyRawAccountDataEvent::Global),
}
.map_err(|e| err!(Database("Database contains invalid account data: {e}")))
.log_err()

View file

@ -94,15 +94,16 @@ impl Console {
debug!("session starting");
while self.server.running() {
match self.readline().await {
Ok(event) => match event {
ReadlineEvent::Line(string) => self.clone().handle(string).await,
ReadlineEvent::Interrupted => continue,
ReadlineEvent::Eof => break,
ReadlineEvent::Quit => self.server.shutdown().unwrap_or_else(error::default_log),
| Ok(event) => match event {
| ReadlineEvent::Line(string) => self.clone().handle(string).await,
| ReadlineEvent::Interrupted => continue,
| ReadlineEvent::Eof => break,
| ReadlineEvent::Quit =>
self.server.shutdown().unwrap_or_else(error::default_log),
},
Err(error) => match error {
ReadlineError::Closed => break,
ReadlineError::IO(error) => {
| Err(error) => match error {
| ReadlineError::Closed => break,
| ReadlineError::IO(error) => {
error!("console I/O: {error:?}");
break;
},
@ -158,9 +159,9 @@ impl Console {
async fn process(self: Arc<Self>, line: String) {
match self.admin.command_in_place(line, None).await {
Ok(Some(ref content)) => self.output(content),
Err(ref content) => self.output_err(content),
_ => unreachable!(),
| Ok(Some(ref content)) => self.output(content),
| Err(ref content) => self.output_err(content),
| _ => unreachable!(),
}
}

View file

@ -42,8 +42,9 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
let create_content = {
use RoomVersionId::*;
match room_version {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(server_user.clone()),
_ => RoomCreateEventContent::new_v11(),
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 =>
RoomCreateEventContent::new_v1(server_user.clone()),
| _ => RoomCreateEventContent::new_v11(),
}
};
@ -52,15 +53,12 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomCreateEventContent {
federate: true,
predecessor: None,
room_version: room_version.clone(),
..create_content
},
),
PduBuilder::state(String::new(), &RoomCreateEventContent {
federate: true,
predecessor: None,
room_version: room_version.clone(),
..create_content
}),
server_user,
&room_id,
&state_lock,
@ -72,7 +70,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(server_user.to_string(), &RoomMemberEventContent::new(MembershipState::Join)),
PduBuilder::state(
server_user.to_string(),
&RoomMemberEventContent::new(MembershipState::Join),
),
server_user,
&room_id,
&state_lock,
@ -86,13 +87,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomPowerLevelsEventContent {
users,
..Default::default()
},
),
PduBuilder::state(String::new(), &RoomPowerLevelsEventContent {
users,
..Default::default()
}),
server_user,
&room_id,
&state_lock,
@ -131,7 +129,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(String::new(), &RoomGuestAccessEventContent::new(GuestAccess::Forbidden)),
PduBuilder::state(
String::new(),
&RoomGuestAccessEventContent::new(GuestAccess::Forbidden),
),
server_user,
&room_id,
&state_lock,
@ -155,12 +156,9 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomTopicEventContent {
topic: format!("Manage {}", services.globals.server_name()),
},
),
PduBuilder::state(String::new(), &RoomTopicEventContent {
topic: format!("Manage {}", services.globals.server_name()),
}),
server_user,
&room_id,
&state_lock,
@ -174,13 +172,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomCanonicalAliasEventContent {
alias: Some(alias.clone()),
alt_aliases: Vec::new(),
},
),
PduBuilder::state(String::new(), &RoomCanonicalAliasEventContent {
alias: Some(alias.clone()),
alt_aliases: Vec::new(),
}),
server_user,
&room_id,
&state_lock,
@ -197,12 +192,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomPreviewUrlsEventContent {
disabled: true,
},
),
PduBuilder::state(String::new(), &RoomPreviewUrlsEventContent { disabled: true }),
server_user,
&room_id,
&state_lock,

View file

@ -34,7 +34,10 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
self.services
.timeline
.build_and_append_pdu(
PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Invite)),
PduBuilder::state(
user_id.to_string(),
&RoomMemberEventContent::new(MembershipState::Invite),
),
server_user,
&room_id,
&state_lock,
@ -43,7 +46,10 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
self.services
.timeline
.build_and_append_pdu(
PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Join)),
PduBuilder::state(
user_id.to_string(),
&RoomMemberEventContent::new(MembershipState::Join),
),
user_id,
&room_id,
&state_lock,
@ -51,18 +57,18 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
.await?;
// Set power level
let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]);
let users = BTreeMap::from_iter([
(server_user.clone(), 100.into()),
(user_id.to_owned(), 100.into()),
]);
self.services
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomPowerLevelsEventContent {
users,
..Default::default()
},
),
PduBuilder::state(String::new(), &RoomPowerLevelsEventContent {
users,
..Default::default()
}),
server_user,
&room_id,
&state_lock,
@ -103,9 +109,7 @@ async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> R
.get_room(room_id, user_id, RoomAccountDataEventType::Tag)
.await
.unwrap_or_else(|_| TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
},
content: TagEventContent { tags: BTreeMap::new() },
});
event

View file

@ -10,7 +10,9 @@ use std::{
};
use async_trait::async_trait;
use conduwuit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server};
use conduwuit::{
debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server,
};
pub use create::create_admin_room;
use futures::{FutureExt, TryFutureExt};
use loole::{Receiver, Sender};
@ -158,21 +160,19 @@ impl Service {
/// the queue is full.
pub fn command(&self, command: String, reply_id: Option<OwnedEventId>) -> Result<()> {
self.sender
.send(CommandInput {
command,
reply_id,
})
.send(CommandInput { command, reply_id })
.map_err(|e| err!("Failed to enqueue admin command: {e:?}"))
}
/// Dispatches a comamnd to the processor on the current task and waits for
/// completion.
pub async fn command_in_place(&self, command: String, reply_id: Option<OwnedEventId>) -> ProcessorResult {
self.process_command(CommandInput {
command,
reply_id,
})
.await
pub async fn command_in_place(
&self,
command: String,
reply_id: Option<OwnedEventId>,
) -> ProcessorResult {
self.process_command(CommandInput { command, reply_id })
.await
}
/// Invokes the tab-completer to complete the command. When unavailable,
@ -191,8 +191,8 @@ impl Service {
async fn handle_command(&self, command: CommandInput) {
match self.process_command(command).await {
Ok(None) => debug!("Command successful with no response"),
Ok(Some(output)) | Err(output) => self
| Ok(None) => debug!("Command successful with no response"),
| Ok(Some(output)) | Err(output) => self
.handle_response(output)
.await
.unwrap_or_else(default_log),
@ -250,10 +250,7 @@ impl Service {
}
async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> {
let Some(Relation::Reply {
in_reply_to,
}) = content.relates_to.as_ref()
else {
let Some(Relation::Reply { in_reply_to }) = content.relates_to.as_ref() else {
return Ok(());
};
@ -277,7 +274,10 @@ impl Service {
}
async fn respond_to_room(
&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId,
&self,
content: RoomMessageEventContent,
room_id: &RoomId,
user_id: &UserId,
) -> Result<()> {
assert!(self.user_is_admin(user_id).await, "sender is not admin");
@ -298,12 +298,16 @@ impl Service {
}
async fn handle_response_error(
&self, e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard,
&self,
e: Error,
room_id: &RoomId,
user_id: &UserId,
state_lock: &RoomMutexGuard,
) -> Result<()> {
error!("Failed to build and append admin room response PDU: \"{e}\"");
let content = RoomMessageEventContent::text_plain(format!(
"Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \
successfully, but we could not return the output."
"Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command \
may have finished successfully, but we could not return the output."
));
self.services
@ -321,7 +325,8 @@ impl Service {
// Admin command with public echo (in admin room)
let server_user = &self.services.globals.server_user;
let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str());
let is_public_prefix =
body.starts_with("!admin") || body.starts_with(server_user.as_str());
// Expected backward branch
if !is_public_escape && !is_public_prefix {

View file

@ -65,9 +65,9 @@ async fn startup_execute_command(&self, i: usize, command: String) -> Result<()>
debug!("Startup command #{i}: executing {command:?}");
match self.command_in_place(command, None).await {
Ok(Some(output)) => Self::startup_command_output(i, &output),
Err(output) => Self::startup_command_error(i, &output),
Ok(None) => {
| Ok(Some(output)) => Self::startup_command_output(i, &output),
| Err(output) => Self::startup_command_error(i, &output),
| Ok(None) => {
info!("Startup command #{i} completed (no output).");
Ok(())
},

View file

@ -61,7 +61,11 @@ impl crate::Service for Service {
impl Service {
/// Registers an appservice and returns the ID to the caller
pub async fn register_appservice(&self, registration: &Registration, appservice_config_body: &str) -> Result {
pub async fn register_appservice(
&self,
registration: &Registration,
appservice_config_body: &str,
) -> Result {
//TODO: Check for collisions between exclusive appservice namespaces
self.registration_info
.write()
@ -152,7 +156,10 @@ impl Service {
.any(|info| info.rooms.is_exclusive_match(room_id.as_str()))
}
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
pub fn read(
&self,
) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>>
{
self.registration_info.read()
}

View file

@ -15,13 +15,15 @@ pub struct RegistrationInfo {
impl RegistrationInfo {
#[must_use]
pub fn is_user_match(&self, user_id: &UserId) -> bool {
self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart()
self.users.is_match(user_id.as_str())
|| self.registration.sender_localpart == user_id.localpart()
}
#[inline]
#[must_use]
pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool {
self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart()
self.users.is_exclusive_match(user_id.as_str())
|| self.registration.sender_localpart == user_id.localpart()
}
}

View file

@ -42,7 +42,9 @@ impl crate::Service for Service {
.build()?,
url_preview: base(config)
.and_then(|builder| builder_interface(builder, url_preview_bind_iface.as_deref()))?
.and_then(|builder| {
builder_interface(builder, url_preview_bind_iface.as_deref())
})?
.local_address(url_preview_bind_addr)
.dns_resolver(resolver.resolver.clone())
.redirect(redirect::Policy::limited(3))
@ -178,7 +180,10 @@ fn base(config: &Config) -> Result<reqwest::ClientBuilder> {
}
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
fn builder_interface(builder: reqwest::ClientBuilder, config: Option<&str>) -> Result<reqwest::ClientBuilder> {
fn builder_interface(
builder: reqwest::ClientBuilder,
config: Option<&str>,
) -> Result<reqwest::ClientBuilder> {
if let Some(iface) = config {
Ok(builder.interface(iface))
} else {
@ -187,7 +192,10 @@ fn builder_interface(builder: reqwest::ClientBuilder, config: Option<&str>) -> R
}
#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
fn builder_interface(builder: reqwest::ClientBuilder, config: Option<&str>) -> Result<reqwest::ClientBuilder> {
fn builder_interface(
builder: reqwest::ClientBuilder,
config: Option<&str>,
) -> Result<reqwest::ClientBuilder> {
use conduwuit::Err;
if let Some(iface) = config {

View file

@ -3,7 +3,9 @@ use std::sync::Arc;
use async_trait::async_trait;
use conduwuit::{error, warn, Result};
use ruma::{
events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType},
events::{
push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType,
},
push::Ruleset,
};
@ -31,16 +33,14 @@ impl crate::Service for Service {
}))
}
async fn worker(self: Arc<Self>) -> Result<()> {
async fn worker(self: Arc<Self>) -> Result {
if self.services.globals.is_read_only() {
return Ok(());
}
self.set_emergency_access()
.await
.inspect_err(|e| error!("Could not set the configured emergency password for the server user: {e}"))?;
Ok(())
self.set_emergency_access().await.inspect_err(|e| {
error!("Could not set the configured emergency password for the server user: {e}");
})
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
@ -49,7 +49,7 @@ impl crate::Service for Service {
impl Service {
/// Sets the emergency password and push rules for the server user account
/// in case emergency password is set
async fn set_emergency_access(&self) -> Result<bool> {
async fn set_emergency_access(&self) -> Result {
let server_user = &self.services.globals.server_user;
self.services
@ -57,8 +57,8 @@ impl Service {
.set_password(server_user, self.services.globals.emergency_password().as_deref())?;
let (ruleset, pwd_set) = match self.services.globals.emergency_password() {
Some(_) => (Ruleset::server_default(server_user), true),
None => (Ruleset::new(), false),
| Some(_) => (Ruleset::server_default(server_user), true),
| None => (Ruleset::new(), false),
};
self.services
@ -68,9 +68,7 @@ impl Service {
server_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
content: PushRulesEventContent { global: ruleset },
})
.expect("to json value always works"),
)
@ -78,14 +76,14 @@ impl Service {
if pwd_set {
warn!(
"The server account emergency password is set! Please unset it as soon as you finish admin account \
recovery! You will be logged out of the server service account when you finish."
"The server account emergency password is set! Please unset it as soon as you \
finish admin account recovery! You will be logged out of the server service \
account when you finish."
);
Ok(())
} else {
// logs out any users still in the server service account and removes sessions
self.services.users.deactivate_account(server_user).await?;
self.services.users.deactivate_account(server_user).await
}
Ok(pwd_set)
}
}

View file

@ -16,7 +16,9 @@ impl Data {
let db = &args.db;
Self {
global: db["global"].clone(),
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
counter: RwLock::new(
Self::stored_count(&db["global"]).expect("initialized global counter"),
),
db: args.db.clone(),
}
}

View file

@ -10,7 +10,9 @@ use std::{
use conduwuit::{error, Config, Result};
use data::Data;
use regex::RegexSet;
use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, ServerName, UserId};
use ruma::{
OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, ServerName, UserId,
};
use tokio::sync::Mutex;
use crate::service;
@ -40,31 +42,31 @@ impl crate::Service for Service {
.as_ref()
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
let turn_secret = config
.turn_secret_file
.as_ref()
.map_or(config.turn_secret.clone(), |path| {
std::fs::read_to_string(path).unwrap_or_else(|e| {
error!("Failed to read the TURN secret file: {e}");
config.turn_secret.clone()
})
});
let registration_token =
let turn_secret =
config
.registration_token_file
.turn_secret_file
.as_ref()
.map_or(config.registration_token.clone(), |path| {
let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| {
error!("Failed to read the registration token file: {e}");
}) else {
return config.registration_token.clone();
};
.map_or(config.turn_secret.clone(), |path| {
std::fs::read_to_string(path).unwrap_or_else(|e| {
error!("Failed to read the TURN secret file: {e}");
Some(token)
config.turn_secret.clone()
})
});
let registration_token = config.registration_token_file.as_ref().map_or(
config.registration_token.clone(),
|path| {
let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| {
error!("Failed to read the registration token file: {e}");
}) else {
return config.registration_token.clone();
};
Some(token)
},
);
let mut s = Self {
db,
config: config.clone(),
@ -73,8 +75,11 @@ impl crate::Service for Service {
stateres_mutex: Arc::new(Mutex::new(())),
admin_alias: RoomAliasId::parse(format!("#admins:{}", &config.server_name))
.expect("#admins:server_name is valid alias name"),
server_user: UserId::parse_with_server_name(String::from("conduit"), &config.server_name)
.expect("@conduit:server_name is valid"),
server_user: UserId::parse_with_server_name(
String::from("conduit"),
&config.server_name,
)
.expect("@conduit:server_name is valid"),
turn_secret,
registration_token,
};
@ -125,7 +130,9 @@ impl Service {
pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration }
pub fn allow_guests_auto_join_rooms(&self) -> bool { self.config.allow_guests_auto_join_rooms }
pub fn allow_guests_auto_join_rooms(&self) -> bool {
self.config.allow_guests_auto_join_rooms
}
pub fn log_guest_registrations(&self) -> bool { self.config.log_guest_registrations }
@ -137,17 +144,23 @@ impl Service {
self.config.allow_public_room_directory_over_federation
}
pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation }
pub fn allow_device_name_federation(&self) -> bool {
self.config.allow_device_name_federation
}
pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation }
pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix }
pub fn new_user_displayname_suffix(&self) -> &String {
&self.config.new_user_displayname_suffix
}
pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates }
pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> {
self.jwt_decoding_key.as_ref()
}
pub fn turn_password(&self) -> &String { &self.config.turn_password }
@ -173,11 +186,15 @@ impl Service {
&self.config.url_preview_domain_explicit_denylist
}
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> { &self.config.url_preview_url_contains_allowlist }
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> {
&self.config.url_preview_url_contains_allowlist
}
pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size }
pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain }
pub fn url_preview_check_root_domain(&self) -> bool {
self.config.url_preview_check_root_domain
}
pub fn forbidden_alias_names(&self) -> &RegexSet { &self.config.forbidden_alias_names }
@ -189,18 +206,26 @@ impl Service {
pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence }
pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts }
pub fn allow_incoming_read_receipts(&self) -> bool {
self.config.allow_incoming_read_receipts
}
pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts }
pub fn allow_outgoing_read_receipts(&self) -> bool {
self.config.allow_outgoing_read_receipts
}
pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites }
/// checks if `user_id` is local to us via server_name comparison
#[inline]
pub fn user_is_local(&self, user_id: &UserId) -> bool { self.server_is_ours(user_id.server_name()) }
pub fn user_is_local(&self, user_id: &UserId) -> bool {
self.server_is_ours(user_id.server_name())
}
#[inline]
pub fn server_is_ours(&self, server_name: &ServerName) -> bool { server_name == self.config.server_name }
pub fn server_is_ours(&self, server_name: &ServerName) -> bool {
server_name == self.config.server_name
}
#[inline]
pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() }

View file

@ -48,7 +48,11 @@ impl crate::Service for Service {
}
#[implement(Service)]
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
pub fn create_backup(
&self,
user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
let version = self.services.globals.next_count()?.to_string();
let count = self.services.globals.next_count()?;
@ -71,13 +75,18 @@ pub async fn delete_backup(&self, user_id: &UserId, version: &str) {
.backupkeyid_backup
.keys_prefix_raw(&key)
.ignore_err()
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
.ready_for_each(|outdated_key| {
self.db.backupkeyid_backup.remove(outdated_key);
})
.await;
}
#[implement(Service)]
pub async fn update_backup<'a>(
&self, user_id: &UserId, version: &'a str, backup_metadata: &Raw<BackupAlgorithm>,
&self,
user_id: &UserId,
version: &'a str,
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<&'a str> {
let key = (user_id, version);
if self.db.backupid_algorithm.qry(&key).await.is_err() {
@ -110,7 +119,10 @@ pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String
}
#[implement(Service)]
pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw<BackupAlgorithm>)> {
pub async fn get_latest_backup(
&self,
user_id: &UserId,
) -> Result<(String, Raw<BackupAlgorithm>)> {
type Key<'a> = (&'a UserId, &'a str);
type KeyVal<'a> = (Key<'a>, Raw<BackupAlgorithm>);
@ -134,7 +146,12 @@ pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<Ba
#[implement(Service)]
pub async fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
session_id: &str,
key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let key = (user_id, version);
if self.db.backupid_algorithm.qry(&key).await.is_err() {
@ -176,14 +193,16 @@ pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String {
}
#[implement(Service)]
pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
pub async fn get_all(
&self,
user_id: &UserId,
version: &str,
) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str);
type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
let default = || RoomKeyBackup {
sessions: BTreeMap::new(),
};
let default = || RoomKeyBackup { sessions: BTreeMap::new() };
let prefix = (user_id, version, Interfix);
self.db
@ -204,7 +223,10 @@ pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap<OwnedRo
#[implement(Service)]
pub async fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
) -> BTreeMap<String, Raw<KeyBackupData>> {
type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw<KeyBackupData>);
@ -213,14 +235,20 @@ pub async fn get_room(
.backupkeyid_backup
.stream_prefix(&prefix)
.ignore_err()
.map(|((.., session_id), key_backup_data): KeyVal<'_>| (session_id.to_owned(), key_backup_data))
.map(|((.., session_id), key_backup_data): KeyVal<'_>| {
(session_id.to_owned(), key_backup_data)
})
.collect()
.await
}
#[implement(Service)]
pub async fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
session_id: &str,
) -> Result<Raw<KeyBackupData>> {
let key = (user_id, version, room_id, session_id);
@ -245,17 +273,27 @@ pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &
.backupkeyid_backup
.keys_prefix_raw(&key)
.ignore_err()
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
.ready_for_each(|outdated_key| {
self.db.backupkeyid_backup.remove(outdated_key);
})
.await;
}
#[implement(Service)]
pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) {
pub async fn delete_room_key(
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
session_id: &str,
) {
let key = (user_id, version, room_id, session_id);
self.db
.backupkeyid_backup
.keys_prefix_raw(&key)
.ignore_err()
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
.ready_for_each(|outdated_key| {
self.db.backupkeyid_backup.remove(outdated_key);
})
.await;
}

View file

@ -102,21 +102,32 @@ impl Manager {
unimplemented!("unexpected worker task abort {error:?}");
}
async fn handle_result(&self, workers: &mut WorkersLocked<'_>, result: WorkerResult) -> Result<()> {
async fn handle_result(
&self,
workers: &mut WorkersLocked<'_>,
result: WorkerResult,
) -> Result<()> {
let (service, result) = result;
match result {
Ok(()) => self.handle_finished(workers, &service).await,
Err(error) => self.handle_error(workers, &service, error).await,
| Ok(()) => self.handle_finished(workers, &service).await,
| Err(error) => self.handle_error(workers, &service, error).await,
}
}
async fn handle_finished(&self, _workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>) -> Result<()> {
async fn handle_finished(
&self,
_workers: &mut WorkersLocked<'_>,
service: &Arc<dyn Service>,
) -> Result<()> {
debug!("service {:?} worker finished", service.name());
Ok(())
}
async fn handle_error(
&self, workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>, error: Error,
&self,
workers: &mut WorkersLocked<'_>,
service: &Arc<dyn Service>,
error: Error,
) -> Result<()> {
let name = service.name();
error!("service {name:?} aborted: {error}");
@ -138,9 +149,16 @@ impl Manager {
}
/// Start the worker in a task for the service.
async fn start_worker(&self, workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>) -> Result<()> {
async fn start_worker(
&self,
workers: &mut WorkersLocked<'_>,
service: &Arc<dyn Service>,
) -> Result<()> {
if !self.server.running() {
return Err!("Service {:?} worker not starting during server shutdown.", service.name());
return Err!(
"Service {:?} worker not starting during server shutdown.",
service.name()
);
}
debug!("Service {:?} worker starting...", service.name());

View file

@ -34,7 +34,11 @@ impl Data {
}
pub(super) fn create_file_metadata(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, dim: &Dim, content_disposition: Option<&ContentDisposition>,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
dim: &Dim,
content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let dim: &[u32] = &[dim.width, dim.height];
@ -63,7 +67,10 @@ impl Data {
.stream_prefix_raw(&prefix)
.ignore_err()
.ready_for_each(|(key, val)| {
debug_assert!(key.starts_with(mxc.to_string().as_bytes()), "key should start with the mxc");
debug_assert!(
key.starts_with(mxc.to_string().as_bytes()),
"key should start with the mxc"
);
let user = str_from_bytes(val).unwrap_or_default();
debug_info!("Deleting key {key:?} which was uploaded by user {user}");
@ -95,7 +102,11 @@ impl Data {
Ok(keys)
}
pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Metadata> {
pub(super) async fn search_file_metadata(
&self,
mxc: &Mxc<'_>,
dim: &Dim,
) -> Result<Metadata> {
let dim: &[u32] = &[dim.width, dim.height];
let prefix = (mxc, dim, Interfix);
@ -113,8 +124,9 @@ impl Data {
let content_type = parts
.next()
.map(|bytes| {
string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
string_from_bytes(bytes).map_err(|_| {
Error::bad_database("Content type in mediaid_file is invalid unicode.")
})
})
.transpose()?;
@ -127,16 +139,16 @@ impl Data {
} else {
Some(
string_from_bytes(content_disposition_bytes)
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?
.map_err(|_| {
Error::bad_database(
"Content Disposition in mediaid_file is invalid unicode.",
)
})?
.parse()?,
)
};
Ok(Metadata {
content_disposition,
content_type,
key,
})
Ok(Metadata { content_disposition, content_type, key })
}
/// Gets all the MXCs associated with a user
@ -144,7 +156,9 @@ impl Data {
self.mediaid_user
.stream()
.ignore_err()
.ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into()))
.ready_filter_map(|(key, user): (&str, &UserId)| {
(user == user_id).then(|| key.into())
})
.collect()
.await
}
@ -166,7 +180,12 @@ impl Data {
Ok(())
}
pub(super) fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: Duration) -> Result<()> {
pub(super) fn set_url_preview(
&self,
url: &str,
data: &UrlPreviewData,
timestamp: Duration,
) -> Result<()> {
let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF);
@ -218,43 +237,43 @@ impl Data {
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
| Some(s) if s.is_empty() => None,
| x => x,
};
let description = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
| Some(s) if s.is_empty() => None,
| x => x,
};
let image = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
| Some(s) if s.is_empty() => None,
| x => x,
};
let image_size = match values
.next()
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
| Some(0) => None,
| x => x,
};
let image_width = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
| Some(0) => None,
| x => x,
};
let image_height = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
| Some(0) => None,
| x => x,
};
Ok(UrlPreviewData {

View file

@ -83,7 +83,8 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> {
for key in media.db.get_all_media_keys().await {
let new_path = media.get_media_file_sha256(&key).into_os_string();
let old_path = media.get_media_file_b64(&key).into_os_string();
if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await {
if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await
{
error!(
media_id = ?encode_key(&key), ?new_path, ?old_path,
"Failed to resolve media check failure: {e}"
@ -100,8 +101,12 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> {
}
async fn handle_media_check(
dbs: &(&Arc<database::Map>, &Arc<database::Map>), config: &Config, files: &HashSet<OsString>, key: &[u8],
new_path: &OsStr, old_path: &OsStr,
dbs: &(&Arc<database::Map>, &Arc<database::Map>),
config: &Config,
files: &HashSet<OsString>,
key: &[u8],
new_path: &OsStr,
old_path: &OsStr,
) -> Result<()> {
use crate::media::encode_key;

View file

@ -80,13 +80,21 @@ impl crate::Service for Service {
impl Service {
/// Uploads a file.
pub async fn create(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>, file: &[u8],
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>,
file: &[u8],
) -> Result<()> {
// Width, Height = 0 if it's not a thumbnail
let key = self
.db
.create_file_metadata(mxc, user, &Dim::default(), content_disposition, content_type)?;
let key = self.db.create_file_metadata(
mxc,
user,
&Dim::default(),
content_disposition,
content_type,
)?;
//TODO: Dangling metadata in database if creation fails
let mut f = self.create_media_file(&key).await?;
@ -132,10 +140,10 @@ impl Service {
debug_info!(%deletion_count, "Deleting MXC {mxc} by user {user} from database and filesystem");
match self.delete(&mxc).await {
Ok(()) => {
| Ok(()) => {
deletion_count = deletion_count.saturating_add(1);
},
Err(e) => {
| Err(e) => {
debug_error!(%deletion_count, "Failed to delete {mxc} from user {user}, ignoring error: {e}");
},
}
@ -146,11 +154,8 @@ impl Service {
/// Downloads a file.
pub async fn get(&self, mxc: &Mxc<'_>) -> Result<Option<FileMeta>> {
if let Ok(Metadata {
content_disposition,
content_type,
key,
}) = self.db.search_file_metadata(mxc, &Dim::default()).await
if let Ok(Metadata { content_disposition, content_type, key }) =
self.db.search_file_metadata(mxc, &Dim::default()).await
{
let mut content = Vec::new();
let path = self.get_media_file(&key);
@ -181,13 +186,19 @@ impl Service {
let mxc = parts
.next()
.map(|bytes| {
utils::string_from_bytes(bytes)
.map_err(|e| err!(Database(error!("Failed to parse MXC unicode bytes from our database: {e}"))))
utils::string_from_bytes(bytes).map_err(|e| {
err!(Database(error!(
"Failed to parse MXC unicode bytes from our database: {e}"
)))
})
})
.transpose()?;
let Some(mxc_s) = mxc else {
debug_warn!(?mxc, "Parsed MXC URL unicode bytes from database but is still invalid");
debug_warn!(
?mxc,
"Parsed MXC URL unicode bytes from database but is still invalid"
);
continue;
};
@ -207,7 +218,11 @@ impl Service {
/// Deletes all remote only media files in the given at or after
/// time/duration. Returns a usize with the amount of media files deleted.
pub async fn delete_all_remote_media_at_after_time(
&self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool,
&self,
time: SystemTime,
before: bool,
after: bool,
yes_i_want_to_delete_local_media: bool,
) -> Result<usize> {
let all_keys = self.db.get_all_media_keys().await;
let mut remote_mxcs = Vec::with_capacity(all_keys.len());
@ -218,19 +233,26 @@ impl Service {
let mxc = parts
.next()
.map(|bytes| {
utils::string_from_bytes(bytes)
.map_err(|e| err!(Database(error!("Failed to parse MXC unicode bytes from our database: {e}"))))
utils::string_from_bytes(bytes).map_err(|e| {
err!(Database(error!(
"Failed to parse MXC unicode bytes from our database: {e}"
)))
})
})
.transpose()?;
let Some(mxc_s) = mxc else {
debug_warn!(?mxc, "Parsed MXC URL unicode bytes from database but is still invalid");
debug_warn!(
?mxc,
"Parsed MXC URL unicode bytes from database but is still invalid"
);
continue;
};
trace!("Parsed MXC key to URL: {mxc_s}");
let mxc = OwnedMxcUri::from(mxc_s);
if (mxc.server_name() == Ok(self.services.globals.server_name()) && !yes_i_want_to_delete_local_media)
if (mxc.server_name() == Ok(self.services.globals.server_name())
&& !yes_i_want_to_delete_local_media)
|| !mxc.is_valid()
{
debug!("Ignoring local or broken media MXC: {mxc}");
@ -240,9 +262,12 @@ impl Service {
let path = self.get_media_file(&key);
let file_metadata = match fs::metadata(path.clone()).await {
Ok(file_metadata) => file_metadata,
Err(e) => {
error!("Failed to obtain file metadata for MXC {mxc} at file path \"{path:?}\", skipping: {e}");
| Ok(file_metadata) => file_metadata,
| Err(e) => {
error!(
"Failed to obtain file metadata for MXC {mxc} at file path \
\"{path:?}\", skipping: {e}"
);
continue;
},
};
@ -250,12 +275,12 @@ impl Service {
trace!(%mxc, ?path, "File metadata: {file_metadata:?}");
let file_created_at = match file_metadata.created() {
Ok(value) => value,
Err(err) if err.kind() == std::io::ErrorKind::Unsupported => {
| Ok(value) => value,
| Err(err) if err.kind() == std::io::ErrorKind::Unsupported => {
debug!("btime is unsupported, using mtime instead");
file_metadata.modified()?
},
Err(err) => {
| Err(err) => {
error!("Could not delete MXC {mxc} at path {path:?}: {err:?}. Skipping...");
continue;
},
@ -264,10 +289,16 @@ impl Service {
debug!("File created at: {file_created_at:?}");
if file_created_at >= time && before {
debug!("File is within (before) user duration, pushing to list of file paths and keys to delete.");
debug!(
"File is within (before) user duration, pushing to list of file paths and \
keys to delete."
);
remote_mxcs.push(mxc.to_string());
} else if file_created_at <= time && after {
debug!("File is not within (after) user duration, pushing to list of file paths and keys to delete.");
debug!(
"File is not within (after) user duration, pushing to list of file paths \
and keys to delete."
);
remote_mxcs.push(mxc.to_string());
}
}
@ -289,10 +320,10 @@ impl Service {
debug_info!("Deleting MXC {mxc} from database and filesystem");
match self.delete(&mxc).await {
Ok(()) => {
| Ok(()) => {
deletion_count = deletion_count.saturating_add(1);
},
Err(e) => {
| Err(e) => {
warn!("Failed to delete {mxc}, ignoring error and skipping: {e}");
continue;
},

View file

@ -53,10 +53,10 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
self.create(&mxc, None, None, None, &image).await?;
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
Err(_) => (None, None),
Ok(reader) => match reader.into_dimensions() {
Err(_) => (None, None),
Ok((width, height)) => (Some(width), Some(height)),
| Err(_) => (None, None),
| Ok(reader) => match reader.into_dimensions() {
| Err(_) => (None, None),
| Ok((width, height)) => (Some(width), Some(height)),
},
};
@ -79,8 +79,8 @@ pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
let _request_lock = self.url_preview_mutex.lock(url.as_str()).await;
match self.db.get_url_preview(url.as_str()).await {
Ok(preview) => Ok(preview),
Err(_) => self.request_url_preview(url).await,
| Ok(preview) => Ok(preview),
| Err(_) => self.request_url_preview(url).await,
}
}
@ -111,9 +111,9 @@ async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
return Err!(Request(Unknown("Unknown Content-Type")));
};
let data = match content_type {
html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
img if img.starts_with("image/") => self.download_image(url.as_str()).await?,
_ => return Err!(Request(Unknown("Unsupported Content-Type"))),
| html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
| img if img.starts_with("image/") => self.download_image(url.as_str()).await?,
| _ => return Err!(Request(Unknown("Unsupported Content-Type"))),
};
self.set_url_preview(url.as_str(), &data).await?;
@ -131,8 +131,9 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
bytes.extend_from_slice(&chunk);
if bytes.len() > self.services.globals.url_preview_max_spider_size() {
debug!(
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
response body and assuming our necessary data is in this range.",
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not \
processing the rest of the response body and assuming our necessary data is in \
this range.",
url,
self.services.globals.url_preview_max_spider_size()
);
@ -145,8 +146,8 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
};
let mut data = match html.opengraph.images.first() {
None => UrlPreviewData::default(),
Some(obj) => self.download_image(&obj.url).await?,
| None => UrlPreviewData::default(),
| Some(obj) => self.download_image(&obj.url).await?,
};
let props = html.opengraph.properties;
@ -169,11 +170,11 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
}
let host = match url.host_str() {
None => {
| None => {
debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
return false;
},
Some(h) => h.to_owned(),
| Some(h) => h.to_owned(),
};
let allowlist_domain_contains = self
@ -205,7 +206,10 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
}
if allowlist_domain_explicit.contains(&host) {
debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)", &host);
debug!(
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)",
&host
);
return true;
}
@ -213,7 +217,10 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
.iter()
.any(|domain_s| domain_s.contains(&host.clone()))
{
debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)", &host);
debug!(
"Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)",
&host
);
return true;
}
@ -229,11 +236,12 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
if self.services.globals.url_preview_check_root_domain() {
debug!("Checking root domain");
match host.split_once('.') {
None => return false,
Some((_, root_domain)) => {
| None => return false,
| Some((_, root_domain)) => {
if denylist_domain_explicit.contains(&root_domain.to_owned()) {
debug!(
"Root domain {} is not allowed by url_preview_domain_explicit_denylist (check 1/3)",
"Root domain {} is not allowed by \
url_preview_domain_explicit_denylist (check 1/3)",
&root_domain
);
return true;
@ -241,7 +249,8 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
debug!(
"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 2/3)",
"Root domain {} is allowed by url_preview_domain_explicit_allowlist \
(check 2/3)",
&root_domain
);
return true;
@ -252,7 +261,8 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
{
debug!(
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 3/3)",
"Root domain {} is allowed by url_preview_domain_contains_allowlist \
(check 3/3)",
&root_domain
);
return true;

View file

@ -1,6 +1,9 @@
use std::{fmt::Debug, time::Duration};
use conduwuit::{debug_warn, err, implement, utils::content_disposition::make_content_disposition, Err, Error, Result};
use conduwuit::{
debug_warn, err, implement, utils::content_disposition::make_content_disposition, Err, Error,
Result,
};
use http::header::{HeaderValue, CONTENT_DISPOSITION, CONTENT_TYPE};
use ruma::{
api::{
@ -19,7 +22,12 @@ use super::{Dim, FileMeta};
#[implement(super::Service)]
pub async fn fetch_remote_thumbnail(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration, dim: &Dim,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
timeout_ms: Duration,
dim: &Dim,
) -> Result<FileMeta> {
self.check_fetch_authorized(mxc)?;
@ -38,7 +46,11 @@ pub async fn fetch_remote_thumbnail(
#[implement(super::Service)]
pub async fn fetch_remote_content(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
timeout_ms: Duration,
) -> Result<FileMeta> {
self.check_fetch_authorized(mxc)?;
@ -57,7 +69,12 @@ pub async fn fetch_remote_content(
#[implement(super::Service)]
async fn fetch_thumbnail_authenticated(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration, dim: &Dim,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
timeout_ms: Duration,
dim: &Dim,
) -> Result<FileMeta> {
use federation::authenticated_media::get_content_thumbnail::v1::{Request, Response};
@ -70,20 +87,22 @@ async fn fetch_thumbnail_authenticated(
timeout_ms,
};
let Response {
content,
..
} = self.federation_request(mxc, user, server, request).await?;
let Response { content, .. } = self.federation_request(mxc, user, server, request).await?;
match content {
FileOrLocation::File(content) => self.handle_thumbnail_file(mxc, user, dim, content).await,
FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
| FileOrLocation::File(content) =>
self.handle_thumbnail_file(mxc, user, dim, content).await,
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
}
}
#[implement(super::Service)]
async fn fetch_content_authenticated(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
timeout_ms: Duration,
) -> Result<FileMeta> {
use federation::authenticated_media::get_content::v1::{Request, Response};
@ -92,21 +111,23 @@ async fn fetch_content_authenticated(
timeout_ms,
};
let Response {
content,
..
} = self.federation_request(mxc, user, server, request).await?;
let Response { content, .. } = self.federation_request(mxc, user, server, request).await?;
match content {
FileOrLocation::File(content) => self.handle_content_file(mxc, user, content).await,
FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
| FileOrLocation::File(content) => self.handle_content_file(mxc, user, content).await,
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
}
}
#[allow(deprecated)]
#[implement(super::Service)]
async fn fetch_thumbnail_unauthenticated(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration, dim: &Dim,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
timeout_ms: Duration,
dim: &Dim,
) -> Result<FileMeta> {
use media::get_content_thumbnail::v3::{Request, Response};
@ -123,17 +144,10 @@ async fn fetch_thumbnail_unauthenticated(
};
let Response {
file,
content_type,
content_disposition,
..
file, content_type, content_disposition, ..
} = self.federation_request(mxc, user, server, request).await?;
let content = Content {
file,
content_type,
content_disposition,
};
let content = Content { file, content_type, content_disposition };
self.handle_thumbnail_file(mxc, user, dim, content).await
}
@ -141,7 +155,11 @@ async fn fetch_thumbnail_unauthenticated(
#[allow(deprecated)]
#[implement(super::Service)]
async fn fetch_content_unauthenticated(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
timeout_ms: Duration,
) -> Result<FileMeta> {
use media::get_content::v3::{Request, Response};
@ -154,27 +172,27 @@ async fn fetch_content_unauthenticated(
};
let Response {
file,
content_type,
content_disposition,
..
file, content_type, content_disposition, ..
} = self.federation_request(mxc, user, server, request).await?;
let content = Content {
file,
content_type,
content_disposition,
};
let content = Content { file, content_type, content_disposition };
self.handle_content_file(mxc, user, content).await
}
#[implement(super::Service)]
async fn handle_thumbnail_file(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, dim: &Dim, content: Content,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
dim: &Dim,
content: Content,
) -> Result<FileMeta> {
let content_disposition =
make_content_disposition(content.content_disposition.as_ref(), content.content_type.as_deref(), None);
let content_disposition = make_content_disposition(
content.content_disposition.as_ref(),
content.content_type.as_deref(),
None,
);
self.upload_thumbnail(
mxc,
@ -193,9 +211,17 @@ async fn handle_thumbnail_file(
}
#[implement(super::Service)]
async fn handle_content_file(&self, mxc: &Mxc<'_>, user: Option<&UserId>, content: Content) -> Result<FileMeta> {
let content_disposition =
make_content_disposition(content.content_disposition.as_ref(), content.content_type.as_deref(), None);
async fn handle_content_file(
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
content: Content,
) -> Result<FileMeta> {
let content_disposition = make_content_disposition(
content.content_disposition.as_ref(),
content.content_type.as_deref(),
None,
);
self.create(
mxc,
@ -213,7 +239,12 @@ async fn handle_content_file(&self, mxc: &Mxc<'_>, user: Option<&UserId>, conten
}
#[implement(super::Service)]
async fn handle_location(&self, mxc: &Mxc<'_>, user: Option<&UserId>, location: &str) -> Result<FileMeta> {
async fn handle_location(
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
location: &str,
) -> Result<FileMeta> {
self.location_request(location).await.map_err(|error| {
err!(Request(NotFound(
debug_warn!(%mxc, ?user, ?location, ?error, "Fetching media from location failed")
@ -263,7 +294,11 @@ async fn location_request(&self, location: &str) -> Result<FileMeta> {
#[implement(super::Service)]
async fn federation_request<Request>(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, request: Request,
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
request: Request,
) -> Result<Request::IncomingResponse>
where
Request: OutgoingRequest + Send + Debug,
@ -277,7 +312,12 @@ where
// Handles and adjusts the error for the caller to determine if they should
// request the fallback endpoint or give up.
fn handle_federation_error(mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, error: Error) -> Error {
fn handle_federation_error(
mxc: &Mxc<'_>,
user: Option<&UserId>,
server: Option<&ServerName>,
error: Error,
) -> Error {
let fallback = || {
err!(Request(NotFound(
debug_error!(%mxc, ?user, ?server, ?error, "Remote media not found")
@ -303,7 +343,8 @@ fn handle_federation_error(mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<
#[implement(super::Service)]
#[allow(deprecated)]
pub async fn fetch_remote_thumbnail_legacy(
&self, body: &media::get_content_thumbnail::v3::Request,
&self,
body: &media::get_content_thumbnail::v3::Request,
) -> Result<media::get_content_thumbnail::v3::Response> {
let mxc = Mxc {
server_name: &body.server_name,
@ -315,20 +356,17 @@ pub async fn fetch_remote_thumbnail_legacy(
let reponse = self
.services
.sending
.send_federation_request(
mxc.server_name,
media::get_content_thumbnail::v3::Request {
allow_remote: body.allow_remote,
height: body.height,
width: body.width,
method: body.method.clone(),
server_name: body.server_name.clone(),
media_id: body.media_id.clone(),
timeout_ms: body.timeout_ms,
allow_redirect: body.allow_redirect,
animated: body.animated,
},
)
.send_federation_request(mxc.server_name, media::get_content_thumbnail::v3::Request {
allow_remote: body.allow_remote,
height: body.height,
width: body.width,
method: body.method.clone(),
server_name: body.server_name.clone(),
media_id: body.media_id.clone(),
timeout_ms: body.timeout_ms,
allow_redirect: body.allow_redirect,
animated: body.animated,
})
.await?;
let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?;
@ -341,27 +379,30 @@ pub async fn fetch_remote_thumbnail_legacy(
#[implement(super::Service)]
#[allow(deprecated)]
pub async fn fetch_remote_content_legacy(
&self, mxc: &Mxc<'_>, allow_redirect: bool, timeout_ms: Duration,
&self,
mxc: &Mxc<'_>,
allow_redirect: bool,
timeout_ms: Duration,
) -> Result<media::get_content::v3::Response, Error> {
self.check_legacy_freeze()?;
self.check_fetch_authorized(mxc)?;
let response = self
.services
.sending
.send_federation_request(
mxc.server_name,
media::get_content::v3::Request {
allow_remote: true,
server_name: mxc.server_name.into(),
media_id: mxc.media_id.into(),
timeout_ms,
allow_redirect,
},
)
.send_federation_request(mxc.server_name, media::get_content::v3::Request {
allow_remote: true,
server_name: mxc.server_name.into(),
media_id: mxc.media_id.into(),
timeout_ms,
allow_redirect,
})
.await?;
let content_disposition =
make_content_disposition(response.content_disposition.as_ref(), response.content_type.as_deref(), None);
let content_disposition = make_content_disposition(
response.content_disposition.as_ref(),
response.content_type.as_deref(),
None,
);
self.create(
mxc,

View file

@ -13,7 +13,12 @@ async fn long_file_names_works() {
impl Data for MockedKVDatabase {
fn create_file_metadata(
&self, _sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
&self,
_sender_user: Option<&str>,
mxc: String,
width: u32,
height: u32,
content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
// copied from src/database/key_value/media.rs
@ -46,14 +51,22 @@ async fn long_file_names_works() {
fn get_all_media_keys(&self) -> Vec<Vec<u8>> { todo!() }
fn search_file_metadata(
&self, _mxc: String, _width: u32, _height: u32,
&self,
_mxc: String,
_width: u32,
_height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
todo!()
}
fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() }
fn set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> {
fn set_url_preview(
&self,
_url: &str,
_data: &UrlPreviewData,
_timestamp: std::time::Duration,
) -> Result<()> {
todo!()
}
@ -64,11 +77,18 @@ async fn long_file_names_works() {
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
let width = 100;
let height = 100;
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \
characters like äöüß and even emoji like 🦀.png\"";
let content_disposition = "attachment; filename=\"this is a very long file name with spaces \
and special characters like äöüß and even emoji like 🦀.png\"";
let content_type = "image/png";
let key = db
.create_file_metadata(None, mxc, width, height, Some(content_disposition), Some(content_type))
.create_file_metadata(
None,
mxc,
width,
height,
Some(content_disposition),
Some(content_type),
)
.unwrap();
let mut r = PathBuf::from("/tmp/media");
// r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));

View file

@ -22,12 +22,17 @@ impl super::Service {
/// Uploads or replaces a file thumbnail.
#[allow(clippy::too_many_arguments)]
pub async fn upload_thumbnail(
&self, mxc: &Mxc<'_>, user: Option<&UserId>, content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>, dim: &Dim, file: &[u8],
&self,
mxc: &Mxc<'_>,
user: Option<&UserId>,
content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>,
dim: &Dim,
file: &[u8],
) -> Result<()> {
let key = self
.db
.create_file_metadata(mxc, user, dim, content_disposition, content_type)?;
let key =
self.db
.create_file_metadata(mxc, user, dim, content_disposition, content_type)?;
//TODO: Dangling metadata in database if creation fails
let mut f = self.create_media_file(&key).await?;
@ -78,7 +83,12 @@ impl super::Service {
/// Generate a thumbnail
#[tracing::instrument(skip(self), name = "generate", level = "debug")]
async fn get_thumbnail_generate(&self, mxc: &Mxc<'_>, dim: &Dim, data: Metadata) -> Result<Option<FileMeta>> {
async fn get_thumbnail_generate(
&self,
mxc: &Mxc<'_>,
dim: &Dim,
data: Metadata,
) -> Result<Option<FileMeta>> {
let mut content = Vec::new();
let path = self.get_media_file(&data.key);
fs::File::open(path)
@ -117,11 +127,7 @@ impl super::Service {
fn thumbnail_generate(image: &DynamicImage, requested: &Dim) -> Result<DynamicImage> {
let thumbnail = if !requested.crop() {
let Dim {
width,
height,
..
} = requested.scaled(&Dim {
let Dim { width, height, .. } = requested.scaled(&Dim {
width: image.width(),
height: image.height(),
..Dim::default()
@ -202,12 +208,12 @@ impl Dim {
#[must_use]
pub fn normalized(&self) -> Self {
match (self.width, self.height) {
(0..=32, 0..=32) => Self::new(32, 32, Some(Method::Crop)),
(0..=96, 0..=96) => Self::new(96, 96, Some(Method::Crop)),
(0..=320, 0..=240) => Self::new(320, 240, Some(Method::Scale)),
(0..=640, 0..=480) => Self::new(640, 480, Some(Method::Scale)),
(0..=800, 0..=600) => Self::new(800, 600, Some(Method::Scale)),
_ => Self::default(),
| (0..=32, 0..=32) => Self::new(32, 32, Some(Method::Crop)),
| (0..=96, 0..=96) => Self::new(96, 96, Some(Method::Crop)),
| (0..=320, 0..=240) => Self::new(320, 240, Some(Method::Scale)),
| (0..=640, 0..=480) => Self::new(640, 480, Some(Method::Scale)),
| (0..=800, 0..=600) => Self::new(800, 600, Some(Method::Scale)),
| _ => Self::default(),
}
}

View file

@ -12,7 +12,9 @@ use conduwuit::{
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use ruma::{
events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType},
events::{
push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType,
},
push::Ruleset,
OwnedUserId, RoomId, UserId,
};
@ -45,7 +47,8 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> {
if !services.users.exists(server_user).await {
error!("The {server_user} server user does not exist, and the database is not new.");
return Err!(Database(
"Cannot reuse an existing database after changing the server name, please delete the old one first.",
"Cannot reuse an existing database after changing the server name, please \
delete the old one first.",
));
}
}
@ -144,7 +147,8 @@ async fn migrate(services: &Services) -> Result<()> {
assert!(
version_match,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
"Failed asserting local database version {} is equal to known latest conduwuit database \
version {}",
services.globals.db.database_version().await,
DATABASE_VERSION,
);
@ -192,7 +196,8 @@ async fn migrate(services: &Services) -> Result<()> {
let matches = patterns.matches(room_alias.alias());
if matches.matched_any() {
warn!(
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
"Room with alias {} ({}) matches the following forbidden room \
name patterns: {}",
room_alias,
&room_id,
matches
@ -223,8 +228,8 @@ async fn db_lt_12(services: &Services) -> Result<()> {
.await
{
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
Ok(u) => u,
Err(e) => {
| Ok(u) => u,
| Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
@ -240,7 +245,8 @@ async fn db_lt_12(services: &Services) -> Result<()> {
//content rule
{
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let content_rule_transformation =
[".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let rule = rules_list.content.get(content_rule_transformation[0]);
if rule.is_some() {
@ -301,8 +307,8 @@ async fn db_lt_13(services: &Services) -> Result<()> {
.await
{
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
Ok(u) => u,
Err(e) => {
| Ok(u) => u,
| Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
@ -413,7 +419,9 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
.rooms
.state_accessor
.get_member(room_id, user_id)
.map(|member| member.is_ok_and(|member| member.membership == MembershipState::Join))
.map(|member| {
member.is_ok_and(|member| member.membership == MembershipState::Join)
})
})
.collect::<Vec<_>>()
.await;
@ -426,7 +434,9 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
.rooms
.state_accessor
.get_member(room_id, user_id)
.map(|member| member.is_ok_and(|member| member.membership == MembershipState::Join))
.map(|member| {
member.is_ok_and(|member| member.membership == MembershipState::Join)
})
})
.collect::<Vec<_>>()
.await;
@ -444,7 +454,8 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
for room_id in &room_ids {
debug_info!(
"Updating joined count for room {room_id} to fix servers in room after correcting membership states"
"Updating joined count for room {room_id} to fix servers in room after correcting \
membership states"
);
services

View file

@ -53,18 +53,22 @@ impl Data {
}
pub(super) async fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
&self,
user_id: &UserId,
presence_state: &PresenceState,
currently_active: Option<bool>,
last_active_ago: Option<UInt>,
status_msg: Option<String>,
) -> Result<()> {
let last_presence = self.get_presence(user_id).await;
let state_changed = match last_presence {
Err(_) => true,
Ok(ref presence) => presence.1.content.presence != *presence_state,
| Err(_) => true,
| Ok(ref presence) => presence.1.content.presence != *presence_state,
};
let status_msg_changed = match last_presence {
Err(_) => true,
Ok(ref last_presence) => {
| Err(_) => true,
| Ok(ref last_presence) => {
let old_msg = last_presence
.1
.content
@ -80,18 +84,22 @@ impl Data {
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
Err(_) => 0,
Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
| Err(_) => 0,
| Ok((_, ref presence)) =>
now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
};
let last_active_ts = match last_active_ago {
None => now,
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
| None => now,
| Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
};
// TODO: tighten for state flicker?
if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts {
debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",);
debug_warn!(
"presence spam {user_id:?} last_active_ts:{last_active_ts:?} < \
{last_last_active_ts:?}",
);
return Ok(());
}
@ -138,7 +146,10 @@ impl Data {
}
#[inline]
pub(super) fn presence_since(&self, since: u64) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
pub(super) fn presence_since(
&self,
since: u64,
) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
self.presenceid_presence
.raw_stream()
.ignore_err()

View file

@ -99,13 +99,14 @@ impl Service {
let last_presence = self.db.get_presence(user_id).await;
let state_changed = match last_presence {
Err(_) => true,
Ok((_, ref presence)) => presence.content.presence != *new_state,
| Err(_) => true,
| Ok((_, ref presence)) => presence.content.presence != *new_state,
};
let last_last_active_ago = match last_presence {
Err(_) => 0_u64,
Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(),
| Err(_) => 0_u64,
| Ok((_, ref presence)) =>
presence.content.last_active_ago.unwrap_or_default().into(),
};
if !state_changed && last_last_active_ago < REFRESH_TIMEOUT {
@ -113,8 +114,8 @@ impl Service {
}
let status_msg = match last_presence {
Ok((_, ref presence)) => presence.content.status_msg.clone(),
Err(_) => Some(String::new()),
| Ok((_, ref presence)) => presence.content.status_msg.clone(),
| Err(_) => Some(String::new()),
};
let last_active_ago = UInt::new(0);
@ -125,12 +126,16 @@ impl Service {
/// Adds a presence event which will be saved until a new event replaces it.
pub async fn set_presence(
&self, user_id: &UserId, state: &PresenceState, currently_active: Option<bool>, last_active_ago: Option<UInt>,
&self,
user_id: &UserId,
state: &PresenceState,
currently_active: Option<bool>,
last_active_ago: Option<UInt>,
status_msg: Option<String>,
) -> Result<()> {
let presence_state = match state.as_str() {
"" => &PresenceState::Offline, // default an empty string to 'offline'
&_ => state,
| "" => &PresenceState::Offline, // default an empty string to 'offline'
| &_ => state,
};
self.db
@ -141,8 +146,8 @@ impl Service {
&& user_id != self.services.globals.server_user
{
let timeout = match presence_state {
PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
_ => self.services.server.config.presence_offline_timeout_s,
| PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
| _ => self.services.server.config.presence_offline_timeout_s,
};
self.timer_sender
@ -160,16 +165,25 @@ impl Service {
///
/// TODO: Why is this not used?
#[allow(dead_code)]
pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await }
pub async fn remove_presence(&self, user_id: &UserId) {
self.db.remove_presence(user_id).await;
}
/// Returns the most recent presence updates that happened after the event
/// with id `since`.
pub fn presence_since(&self, since: u64) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
pub fn presence_since(
&self,
since: u64,
) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
self.db.presence_since(since)
}
#[inline]
pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
pub async fn from_json_bytes_to_event(
&self,
bytes: &[u8],
user_id: &UserId,
) -> Result<PresenceEvent> {
let presence = Presence::from_json_bytes(bytes)?;
let event = presence
.to_presence_event(user_id, &self.services.users)
@ -192,13 +206,16 @@ impl Service {
}
let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
(PresenceState::Online, Some(ago)) if ago >= self.idle_timeout => Some(PresenceState::Unavailable),
(PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout => Some(PresenceState::Offline),
_ => None,
| (PresenceState::Online, Some(ago)) if ago >= self.idle_timeout =>
Some(PresenceState::Unavailable),
| (PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout =>
Some(PresenceState::Offline),
| _ => None,
};
debug!(
"Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"
"Processed presence timer for user '{user_id}': Old state = {presence_state}, New \
state = {new_state:?}"
);
if let Some(new_state) = new_state {

View file

@ -21,7 +21,10 @@ pub(super) struct Presence {
impl Presence {
#[must_use]
pub(super) fn new(
state: PresenceState, currently_active: bool, last_active_ts: u64, status_msg: Option<String>,
state: PresenceState,
currently_active: bool,
last_active_ts: u64,
status_msg: Option<String>,
) -> Self {
Self {
state,
@ -32,11 +35,16 @@ impl Presence {
}
pub(super) fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
serde_json::from_slice(bytes)
.map_err(|_| Error::bad_database("Invalid presence data in database"))
}
/// Creates a PresenceEvent from available data.
pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent {
pub(super) async fn to_presence_event(
&self,
user_id: &UserId,
users: &users::Service,
) -> PresenceEvent {
let now = utils::millis_since_unix_epoch();
let last_active_ago = if self.currently_active {
None

View file

@ -19,9 +19,12 @@ use ruma::{
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
},
events::{
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType,
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType,
TimelineEventType,
},
push::{
Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak,
},
push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
serde::Raw,
uint, RoomId, UInt, UserId,
};
@ -55,7 +58,8 @@ impl crate::Service for Service {
services: Services {
globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
users: args.depend::<users::Service>("users"),
sending: args.depend::<sending::Service>("sending"),
@ -67,23 +71,31 @@ impl crate::Service for Service {
}
impl Service {
pub async fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result {
pub async fn set_pusher(
&self,
sender: &UserId,
pusher: &set_pusher::v3::PusherAction,
) -> Result {
match pusher {
set_pusher::v3::PusherAction::Post(data) => {
| set_pusher::v3::PusherAction::Post(data) => {
let pushkey = data.pusher.ids.pushkey.as_str();
if pushkey.len() > 512 {
return Err!(Request(InvalidParam("Push key length cannot be greater than 512 bytes.")));
return Err!(Request(InvalidParam(
"Push key length cannot be greater than 512 bytes."
)));
}
if data.pusher.ids.app_id.as_str().len() > 64 {
return Err!(Request(InvalidParam("App ID length cannot be greater than 64 bytes.")));
return Err!(Request(InvalidParam(
"App ID length cannot be greater than 64 bytes."
)));
}
let key = (sender, data.pusher.ids.pushkey.as_str());
self.db.senderkey_pusher.put(key, Json(pusher));
},
set_pusher::v3::PusherAction::Delete(ids) => {
| set_pusher::v3::PusherAction::Delete(ids) => {
let key = (sender, ids.pushkey.as_str());
self.db.senderkey_pusher.del(key);
@ -118,7 +130,10 @@ impl Service {
.await
}
pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream<Item = &str> + Send + 'a {
pub fn get_pushkeys<'a>(
&'a self,
sender: &'a UserId,
) -> impl Stream<Item = &str> + Send + 'a {
let prefix = (sender, Interfix);
self.db
.senderkey_pusher
@ -160,14 +175,16 @@ impl Service {
let response = self.services.client.pusher.execute(reqwest_request).await;
match response {
Ok(mut response) => {
| Ok(mut response) => {
// reqwest::Response -> http::Response conversion
trace!("Checking response destination's IP");
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !self.services.client.valid_cidr_range(&ip) {
return Err!(BadServerResponse("Not allowed to send requests to this IP"));
return Err!(BadServerResponse(
"Not allowed to send requests to this IP"
));
}
}
}
@ -197,10 +214,13 @@ impl Service {
.body(body)
.expect("reqwest body is valid http body"),
);
response
.map_err(|e| err!(BadServerResponse(warn!("Push gateway {dest} returned invalid response: {e}"))))
response.map_err(|e| {
err!(BadServerResponse(warn!(
"Push gateway {dest} returned invalid response: {e}"
)))
})
},
Err(e) => {
| Err(e) => {
warn!("Could not send request to pusher {dest}: {e}");
Err(e.into())
},
@ -209,7 +229,12 @@ impl Service {
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
pub async fn send_push_notice(
&self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
&self,
user: &UserId,
unread: UInt,
pusher: &Pusher,
ruleset: Ruleset,
pdu: &PduEvent,
) -> Result<()> {
let mut notify = None;
let mut tweaks = Vec::new();
@ -220,8 +245,9 @@ impl Service {
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")
.await
.and_then(|ev| {
serde_json::from_str(ev.content.get())
.map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}"))))
serde_json::from_str(ev.content.get()).map_err(|e| {
err!(Database(error!("invalid m.room.power_levels event: {e:?}")))
})
})
.unwrap_or_default();
@ -230,12 +256,12 @@ impl Service {
.await
{
let n = match action {
Action::Notify => true,
Action::SetTweak(tweak) => {
| Action::Notify => true,
| Action::SetTweak(tweak) => {
tweaks.push(tweak.clone());
continue;
},
_ => false,
| _ => false,
};
if notify.is_some() {
@ -257,8 +283,12 @@ impl Service {
#[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")]
pub async fn get_actions<'a>(
&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
&self,
user: &UserId,
ruleset: &'a Ruleset,
power_levels: &RoomPowerLevelsEventContent,
pdu: &Raw<AnySyncTimelineEvent>,
room_id: &RoomId,
) -> &'a [Action] {
let power_levels = PushConditionPowerLevelsCtx {
users: power_levels.users.clone(),
@ -294,14 +324,21 @@ impl Service {
}
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, event: &PduEvent) -> Result<()> {
async fn send_notice(
&self,
unread: UInt,
pusher: &Pusher,
tweaks: Vec<Tweak>,
event: &PduEvent,
) -> Result<()> {
// TODO: email
match &pusher.kind {
PusherKind::Http(http) => {
| PusherKind::Http(http) => {
// TODO (timo): can pusher/devices have conflicting formats
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
let mut device =
Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
device.data.default_payload = http.default_payload.clone();
device.data.format.clone_from(&http.format);
@ -319,8 +356,11 @@ impl Service {
notifi.counts = NotificationCounts::new(unread, uint!(0));
if event_id_only {
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;
self.send_request(
&http.url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
} else {
if event.kind == TimelineEventType::RoomEncrypted
|| tweaks
@ -336,10 +376,12 @@ impl Service {
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
if event.kind == TimelineEventType::RoomMember {
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
notifi.user_is_target =
event.state_key.as_deref() == Some(event.sender.as_str());
}
notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok();
notifi.sender_display_name =
self.services.users.displayname(&event.sender).await.ok();
notifi.room_name = self
.services
@ -355,15 +397,18 @@ impl Service {
.await
.ok();
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;
self.send_request(
&http.url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
}
Ok(())
},
// TODO: Handle email
//PusherKind::Email(_) => Ok(()),
_ => Ok(()),
| _ => Ok(()),
}
}
}

View file

@ -35,17 +35,9 @@ impl super::Service {
(self.resolve_actual_dest(server_name, true).await?, false)
};
let CachedDest {
dest,
host,
..
} = result;
let CachedDest { dest, host, .. } = result;
Ok(ActualDest {
dest,
host,
cached,
})
Ok(ActualDest { dest, host, cached })
}
/// Returns: `actual_destination`, host header
@ -53,12 +45,16 @@ impl super::Service {
/// Numbers in comments below refer to bullet points in linked section of
/// specification
#[tracing::instrument(skip_all, name = "actual")]
pub async fn resolve_actual_dest(&self, dest: &ServerName, cache: bool) -> Result<CachedDest> {
pub async fn resolve_actual_dest(
&self,
dest: &ServerName,
cache: bool,
) -> Result<CachedDest> {
trace!("Finding actual destination for {dest}");
let mut host = dest.as_str().to_owned();
let actual_dest = match get_ip_with_port(dest.as_str()) {
Some(host_port) => Self::actual_dest_1(host_port)?,
None => {
| Some(host_port) => Self::actual_dest_1(host_port)?,
| None =>
if let Some(pos) = dest.as_str().find(':') {
self.actual_dest_2(dest, cache, pos).await?
} else if let Some(delegated) = self.request_well_known(dest.as_str()).await? {
@ -67,8 +63,7 @@ impl super::Service {
self.actual_dest_4(&host, cache, overrider).await?
} else {
self.actual_dest_5(dest, cache).await?
}
},
},
};
// Can't use get_ip_with_port here because we don't want to add a port
@ -79,7 +74,10 @@ impl super::Service {
FedDest::Named(addr.to_string(), FedDest::default_port())
} else if let Some(pos) = host.find(':') {
let (host, port) = host.split_at(pos);
FedDest::Named(host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port()))
FedDest::Named(
host.to_owned(),
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
)
} else {
FedDest::Named(host, FedDest::default_port())
};
@ -100,20 +98,30 @@ impl super::Service {
async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> {
debug!("2: Hostname with included port");
let (host, port) = dest.as_str().split_at(pos);
self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache)
.await?;
self.conditional_query_and_cache_override(
host,
host,
port.parse::<u16>().unwrap_or(8448),
cache,
)
.await?;
Ok(FedDest::Named(
host.to_owned(),
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
))
}
async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result<FedDest> {
async fn actual_dest_3(
&self,
host: &mut String,
cache: bool,
delegated: String,
) -> Result<FedDest> {
debug!("3: A .well-known file is available");
*host = add_port_to_hostname(&delegated).uri_string();
match get_ip_with_port(&delegated) {
Some(host_and_port) => Self::actual_dest_3_1(host_and_port),
None => {
| Some(host_and_port) => Self::actual_dest_3_1(host_and_port),
| None =>
if let Some(pos) = delegated.find(':') {
self.actual_dest_3_2(cache, delegated, pos).await
} else {
@ -123,8 +131,7 @@ impl super::Service {
} else {
self.actual_dest_3_4(cache, delegated).await
}
}
},
},
}
}
@ -133,22 +140,42 @@ impl super::Service {
Ok(host_and_port)
}
async fn actual_dest_3_2(&self, cache: bool, delegated: String, pos: usize) -> Result<FedDest> {
async fn actual_dest_3_2(
&self,
cache: bool,
delegated: String,
pos: usize,
) -> Result<FedDest> {
debug!("3.2: Hostname with port in .well-known file");
let (host, port) = delegated.split_at(pos);
self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache)
.await?;
self.conditional_query_and_cache_override(
host,
host,
port.parse::<u16>().unwrap_or(8448),
cache,
)
.await?;
Ok(FedDest::Named(
host.to_owned(),
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
))
}
async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result<FedDest> {
async fn actual_dest_3_3(
&self,
cache: bool,
delegated: String,
overrider: FedDest,
) -> Result<FedDest> {
debug!("3.3: SRV lookup successful");
let force_port = overrider.port();
self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache)
.await?;
self.conditional_query_and_cache_override(
&delegated,
&overrider.hostname(),
force_port.unwrap_or(8448),
cache,
)
.await?;
if let Some(port) = force_port {
Ok(FedDest::Named(
delegated,
@ -169,11 +196,21 @@ impl super::Service {
Ok(add_port_to_hostname(&delegated))
}
async fn actual_dest_4(&self, host: &str, cache: bool, overrider: FedDest) -> Result<FedDest> {
async fn actual_dest_4(
&self,
host: &str,
cache: bool,
overrider: FedDest,
) -> Result<FedDest> {
debug!("4: No .well-known; SRV record found");
let force_port = overrider.port();
self.conditional_query_and_cache_override(host, &overrider.hostname(), force_port.unwrap_or(8448), cache)
.await?;
self.conditional_query_and_cache_override(
host,
&overrider.hostname(),
force_port.unwrap_or(8448),
cache,
)
.await?;
if let Some(port) = force_port {
let port = format!(":{port}");
Ok(FedDest::Named(
@ -245,7 +282,11 @@ impl super::Service {
#[inline]
async fn conditional_query_and_cache_override(
&self, overname: &str, hostname: &str, port: u16, cache: bool,
&self,
overname: &str,
hostname: &str,
port: u16,
cache: bool,
) -> Result<()> {
if cache {
self.query_and_cache_override(overname, hostname, port)
@ -256,22 +297,24 @@ impl super::Service {
}
#[tracing::instrument(skip_all, name = "ip")]
async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
async fn query_and_cache_override(
&self,
overname: &'_ str,
hostname: &'_ str,
port: u16,
) -> Result<()> {
match self.resolver.resolver.lookup_ip(hostname.to_owned()).await {
Err(e) => Self::handle_resolve_error(&e, hostname),
Ok(override_ip) => {
| Err(e) => Self::handle_resolve_error(&e, hostname),
| Ok(override_ip) => {
if hostname != overname {
debug_info!("{overname:?} overriden by {hostname:?}");
}
self.set_cached_override(
overname,
CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(),
port,
expire: CachedOverride::default_expire(),
},
);
self.set_cached_override(overname, CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(),
port,
expire: CachedOverride::default_expire(),
});
Ok(())
},
@ -280,14 +323,15 @@ impl super::Service {
#[tracing::instrument(skip_all, name = "srv")]
async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> {
let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
let hostnames =
[format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
for hostname in hostnames {
debug!("querying SRV for {hostname:?}");
let hostname = hostname.trim_end_matches('.');
match self.resolver.resolver.srv_lookup(hostname).await {
Err(e) => Self::handle_resolve_error(&e, hostname)?,
Ok(result) => {
| Err(e) => Self::handle_resolve_error(&e, hostname)?,
| Ok(result) =>
return Ok(result.iter().next().map(|result| {
FedDest::Named(
result.target().to_string().trim_end_matches('.').to_owned(),
@ -296,8 +340,7 @@ impl super::Service {
.try_into()
.unwrap_or_else(|_| FedDest::default_port()),
)
}))
},
})),
}
}
@ -308,25 +351,24 @@ impl super::Service {
use hickory_resolver::error::ResolveErrorKind;
match *e.kind() {
ResolveErrorKind::NoRecordsFound {
..
} => {
| ResolveErrorKind::NoRecordsFound { .. } => {
// Raise to debug_warn if we can find out the result wasn't from cache
debug!(%host, "No DNS records found: {e}");
Ok(())
},
ResolveErrorKind::Timeout => {
| ResolveErrorKind::Timeout => {
Err!(warn!(%host, "DNS {e}"))
},
ResolveErrorKind::NoConnections => {
| ResolveErrorKind::NoConnections => {
error!(
"Your DNS server is overloaded and has ran out of connections. It is strongly recommended you \
remediate this issue to ensure proper federation connectivity."
"Your DNS server is overloaded and has ran out of connections. It is \
strongly recommended you remediate this issue to ensure proper federation \
connectivity."
);
Err!(error!(%host, "DNS error: {e}"))
},
_ => Err!(error!(%host, "DNS error: {e}")),
| _ => Err!(error!(%host, "DNS error: {e}")),
}
}
@ -349,8 +391,9 @@ impl super::Service {
dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
"Destination is not an IP literal."
);
let ip = IPAddress::parse(dest.host())
.map_err(|e| err!(BadServerResponse(debug_error!("Failed to parse IP literal from string: {e}"))))?;
let ip = IPAddress::parse(dest.host()).map_err(|e| {
err!(BadServerResponse(debug_error!("Failed to parse IP literal from string: {e}")))
})?;
self.validate_ip(&ip)?;

View file

@ -46,7 +46,11 @@ impl Cache {
}
impl super::Service {
pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
pub fn set_cached_destination(
&self,
name: OwnedServerName,
dest: CachedDest,
) -> Option<CachedDest> {
trace!(?name, ?dest, "set cached destination");
self.cache
.destinations
@ -65,7 +69,11 @@ impl super::Service {
.cloned()
}
pub fn set_cached_override(&self, name: &str, over: CachedOverride) -> Option<CachedOverride> {
pub fn set_cached_override(
&self,
name: &str,
over: CachedOverride,
) -> Option<CachedOverride> {
trace!(?name, ?over, "set cached override");
self.cache
.overrides
@ -102,7 +110,9 @@ impl CachedDest {
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
#[must_use]
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) }
pub(crate) fn default_expire() -> SystemTime {
rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36)
}
}
impl CachedOverride {
@ -113,5 +123,7 @@ impl CachedOverride {
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
#[must_use]
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) }
pub(crate) fn default_expire() -> SystemTime {
rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12)
}
}

View file

@ -60,27 +60,26 @@ impl Resolver {
opts.shuffle_dns_servers = true;
opts.rotate = true;
opts.ip_strategy = match config.ip_lookup_strategy {
1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
_ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
| 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
| 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
| 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
| 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
| _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
};
opts.authentic_data = false;
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
Ok(Arc::new(Self {
resolver: resolver.clone(),
hooked: Arc::new(Hooked {
resolver,
cache,
}),
hooked: Arc::new(Hooked { resolver, cache }),
}))
}
}
impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name).boxed() }
fn resolve(&self, name: Name) -> Resolving {
resolve_to_reqwest(self.resolver.clone(), name).boxed()
}
}
impl Resolve for Hooked {

View file

@ -29,8 +29,8 @@ pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> {
pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest {
let (host, port) = match dest.find(':') {
None => (dest, DEFAULT_PORT),
Some(pos) => dest.split_at(pos),
| None => (dest, DEFAULT_PORT),
| Some(pos) => dest.split_at(pos),
};
FedDest::Named(
@ -42,23 +42,23 @@ pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest {
impl FedDest {
pub(crate) fn https_string(&self) -> String {
match self {
Self::Literal(addr) => format!("https://{addr}"),
Self::Named(host, port) => format!("https://{host}{port}"),
| Self::Literal(addr) => format!("https://{addr}"),
| Self::Named(host, port) => format!("https://{host}{port}"),
}
}
pub(crate) fn uri_string(&self) -> String {
match self {
Self::Literal(addr) => addr.to_string(),
Self::Named(host, port) => format!("{host}{port}"),
| Self::Literal(addr) => addr.to_string(),
| Self::Named(host, port) => format!("{host}{port}"),
}
}
#[inline]
pub(crate) fn hostname(&self) -> Cow<'_, str> {
match &self {
Self::Literal(addr) => addr.ip().to_string().into(),
Self::Named(host, _) => host.into(),
| Self::Literal(addr) => addr.ip().to_string().into(),
| Self::Named(host, _) => host.into(),
}
}
@ -66,16 +66,20 @@ impl FedDest {
#[allow(clippy::string_slice)]
pub(crate) fn port(&self) -> Option<u16> {
match &self {
Self::Literal(addr) => Some(addr.port()),
Self::Named(_, port) => port[1..].parse().ok(),
| Self::Literal(addr) => Some(addr.port()),
| Self::Named(_, port) => port[1..].parse().ok(),
}
}
#[inline]
#[must_use]
pub fn default_port() -> PortString { PortString::from(DEFAULT_PORT).expect("default port string") }
pub fn default_port() -> PortString {
PortString::from(DEFAULT_PORT).expect("default port string")
}
}
impl fmt::Display for FedDest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.uri_string().as_str()) }
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.uri_string().as_str())
}
}

View file

@ -52,7 +52,8 @@ impl crate::Service for Service {
appservice: args.depend::<appservice::Service>("appservice"),
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
},
}))
}
@ -62,8 +63,15 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
pub fn set_alias(
&self,
alias: &RoomAliasId,
room_id: &RoomId,
user_id: &UserId,
) -> Result<()> {
if alias == self.services.globals.admin_alias
&& user_id != self.services.globals.server_user
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Only the server user can set this alias",
@ -120,7 +128,9 @@ impl Service {
}
pub async fn resolve_with_servers(
&self, room: &RoomOrAliasId, servers: Option<Vec<OwnedServerName>>,
&self,
room: &RoomOrAliasId,
servers: Option<Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
if room.is_room_id() {
let room_id = RoomId::parse(room).expect("valid RoomId");
@ -133,14 +143,16 @@ impl Service {
#[tracing::instrument(skip(self), name = "resolve")]
pub async fn resolve_alias(
&self, room_alias: &RoomAliasId, servers: Option<Vec<OwnedServerName>>,
&self,
room_alias: &RoomAliasId,
servers: Option<Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
let server_name = room_alias.server_name();
let server_is_ours = self.services.globals.server_is_ours(server_name);
let servers_contains_ours = || {
servers
.as_ref()
.is_some_and(|servers| servers.contains(&self.services.globals.config.server_name))
servers.as_ref().is_some_and(|servers| {
servers.contains(&self.services.globals.config.server_name)
})
};
if !server_is_ours && !servers_contains_ours() {
@ -150,8 +162,8 @@ impl Service {
}
let room_id = match self.resolve_local_alias(room_alias).await {
Ok(r) => Some(r),
Err(_) => self.resolve_appservice_alias(room_alias).await?,
| Ok(r) => Some(r),
| Err(_) => self.resolve_appservice_alias(room_alias).await?,
};
room_id.map_or_else(
@ -166,7 +178,10 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
pub fn local_aliases_for_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.aliasid_alias
@ -208,10 +223,15 @@ impl Service {
if let Ok(content) = self
.services
.state_accessor
.room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "")
.room_state_get_content::<RoomPowerLevelsEventContent>(
&room_id,
&StateEventType::RoomPowerLevels,
"",
)
.await
{
return Ok(RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias));
return Ok(RoomPowerLevels::from(content)
.user_can_send_state(user_id, StateEventType::RoomCanonicalAlias));
}
// If there is no power levels event, only the room creator can change
@ -232,7 +252,10 @@ impl Service {
self.db.alias_userid.get(alias.alias()).await.deserialized()
}
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
async fn resolve_appservice_alias(
&self,
room_alias: &RoomAliasId,
) -> Result<Option<OwnedRoomId>> {
use ruma::api::appservice::query::query_room_alias;
for appservice in self.services.appservice.read().await.values() {
@ -242,9 +265,7 @@ impl Service {
.sending
.send_appservice_request(
appservice.registration.clone(),
query_room_alias::v1::Request {
room_alias: room_alias.to_owned(),
},
query_room_alias::v1::Request { room_alias: room_alias.to_owned() },
)
.await,
Ok(Some(_opt_result))
@ -261,19 +282,27 @@ impl Service {
}
pub async fn appservice_checks(
&self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>,
&self,
room_alias: &RoomAliasId,
appservice_info: &Option<RegistrationInfo>,
) -> Result<()> {
if !self
.services
.globals
.server_is_ours(room_alias.server_name())
{
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Alias is from another server.",
));
}
if let Some(ref info) = appservice_info {
if !info.aliases.is_match(room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Room alias is not in namespace.",
));
}
} else if self
.services
@ -281,7 +310,10 @@ impl Service {
.is_exclusive_alias(room_alias)
.await
{
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Room alias reserved by appservice.",
));
}
Ok(())

View file

@ -6,7 +6,9 @@ use ruma::{api::federation, OwnedRoomId, OwnedServerName, RoomAliasId, ServerNam
#[implement(super::Service)]
pub(super) async fn remote_resolve(
&self, room_alias: &RoomAliasId, servers: Vec<OwnedServerName>,
&self,
room_alias: &RoomAliasId,
servers: Vec<OwnedServerName>,
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
debug!(?room_alias, servers = ?servers, "resolve");
let servers = once(room_alias.server_name())
@ -17,12 +19,12 @@ pub(super) async fn remote_resolve(
let mut resolved_room_id: Option<OwnedRoomId> = None;
for server in servers {
match self.remote_request(room_alias, &server).await {
Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"),
Ok(Response {
room_id,
servers,
}) => {
debug!("Server {server} answered with {room_id:?} for {room_alias:?} servers: {servers:?}");
| Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"),
| Ok(Response { room_id, servers }) => {
debug!(
"Server {server} answered with {room_id:?} for {room_alias:?} servers: \
{servers:?}"
);
resolved_room_id.get_or_insert(room_id);
add_server(&mut resolved_servers, server);
@ -37,16 +39,20 @@ pub(super) async fn remote_resolve(
resolved_room_id
.map(|room_id| (room_id, resolved_servers))
.ok_or_else(|| err!(Request(NotFound("No servers could assist in resolving the room alias"))))
.ok_or_else(|| {
err!(Request(NotFound("No servers could assist in resolving the room alias")))
})
}
#[implement(super::Service)]
async fn remote_request(&self, room_alias: &RoomAliasId, server: &ServerName) -> Result<Response> {
async fn remote_request(
&self,
room_alias: &RoomAliasId,
server: &ServerName,
) -> Result<Response> {
use federation::query::get_room_information::v1::Request;
let request = Request {
room_alias: room_alias.to_owned(),
};
let request = Request { room_alias: room_alias.to_owned() };
self.services
.sending

View file

@ -19,14 +19,18 @@ impl Data {
let db = &args.db;
let config = &args.server.config;
let cache_size = f64::from(config.auth_chain_cache_capacity);
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size");
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier)
.expect("valid cache size");
Self {
shorteventid_authchain: db["shorteventid_authchain"].clone(),
auth_chain_cache: Mutex::new(LruCache::new(cache_size)),
}
}
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
pub(super) async fn get_cached_eventid_authchain(
&self,
key: &[u64],
) -> Result<Arc<[ShortEventId]>> {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Check RAM cache

View file

@ -43,7 +43,9 @@ impl crate::Service for Service {
impl Service {
pub async fn event_ids_iter<'a, I>(
&'a self, room_id: &RoomId, starting_events: I,
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
@ -57,7 +59,11 @@ impl Service {
Ok(stream)
}
pub async fn get_event_ids<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<Arc<EventId>>>
pub async fn get_event_ids<'a, I>(
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<Vec<Arc<EventId>>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
@ -74,7 +80,11 @@ impl Service {
}
#[tracing::instrument(skip_all, name = "auth_chain")]
pub async fn get_auth_chain<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<ShortEventId>>
pub async fn get_auth_chain<'a, I>(
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<Vec<ShortEventId>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
@ -110,7 +120,8 @@ impl Service {
continue;
}
let chunk_key: Vec<ShortEventId> = chunk.iter().map(|(short, _)| short).copied().collect();
let chunk_key: Vec<ShortEventId> =
chunk.iter().map(|(short, _)| short).copied().collect();
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
trace!("Found cache entry for whole chunk");
full_auth_chain.extend(cached.iter().copied());
@ -169,7 +180,11 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<ShortEventId>> {
async fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<ShortEventId>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
@ -177,8 +192,10 @@ impl Service {
trace!(?event_id, "processing auth event");
match self.services.timeline.get_pdu(&event_id).await {
Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"),
Ok(pdu) => {
| Err(e) => {
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events");
},
| Ok(pdu) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
@ -196,7 +213,11 @@ impl Service {
.await;
if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
trace!(
?event_id,
?auth_event,
"adding auth event to processing queue"
);
todo.push(auth_event.clone());
}
}

View file

@ -32,10 +32,14 @@ pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_i
pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); }
#[implement(Service)]
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send { self.db.publicroomids.keys().ignore_err() }
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
self.db.publicroomids.keys().ignore_err()
}
#[implement(Service)]
pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.visibility(room_id).await == Visibility::Public }
pub async fn is_public_room(&self, room_id: &RoomId) -> bool {
self.visibility(room_id).await == Visibility::Public
}
#[implement(Service)]
pub async fn visibility(&self, room_id: &RoomId) -> Visibility {

View file

@ -5,10 +5,14 @@ use std::{
};
use conduwuit::{
debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent,
debug, debug_error, implement, info, pdu, trace,
utils::math::continue_exponential_backoff_secs, warn, PduEvent,
};
use futures::TryFutureExt;
use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName};
use ruma::{
api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId,
ServerName,
};
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
/// it is appended to the outliers Tree.
@ -21,7 +25,11 @@ use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomI
/// d. TODO: Ask other servers over federation?
#[implement(super::Service)]
pub(super) async fn fetch_and_handle_outliers<'a>(
&self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId,
&self,
origin: &'a ServerName,
events: &'a [Arc<EventId>],
create_event: &'a PduEvent,
room_id: &'a RoomId,
room_version_id: &'a RoomVersionId,
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| match self
@ -32,10 +40,12 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
.expect("locked")
.entry(id)
{
hash_map::Entry::Vacant(e) => {
| hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)),
| hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1.saturating_add(1));
},
};
let mut events_with_auth_events = Vec::with_capacity(events.len());
@ -67,7 +77,12 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
time.elapsed(),
*tries,
) {
info!("Backing off from {next_id}");
continue;
}
@ -86,18 +101,16 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
match self
.services
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
include_unredacted_content: None,
},
)
.send_federation_request(origin, get_event::v1::Request {
event_id: (*next_id).to_owned(),
include_unredacted_content: None,
})
.await
{
Ok(res) => {
| Ok(res) => {
debug!("Got {next_id} over federation");
let Ok((calculated_event_id, value)) = pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
let Ok((calculated_event_id, value)) =
pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
else {
back_off((*next_id).to_owned());
continue;
@ -105,15 +118,18 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
if calculated_event_id != *next_id {
warn!(
"Server didn't return event id we requested: requested: {next_id}, we got \
{calculated_event_id}. Event: {:?}",
"Server didn't return event id we requested: requested: {next_id}, \
we got {calculated_event_id}. Event: {:?}",
&res.pdu
);
}
if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) {
if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
@ -127,7 +143,7 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
},
Err(e) => {
| Err(e) => {
debug_error!("Failed to fetch event {next_id}: {e}");
back_off((*next_id).to_owned());
},
@ -158,20 +174,32 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
time.elapsed(),
*tries,
) {
debug!("Backing off from {next_id}");
continue;
}
}
match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)).await
match Box::pin(self.handle_outlier_pdu(
origin,
create_event,
&next_id,
room_id,
value.clone(),
true,
))
.await
{
Ok((pdu, json)) => {
| Ok((pdu, json)) =>
if next_id == *id {
pdus.push((pdu, Some(json)));
}
},
Err(e) => {
},
| Err(e) => {
warn!("Authentication of event {next_id} failed: {e:?}");
back_off(next_id.into());
},

View file

@ -8,7 +8,8 @@ use futures::{future, FutureExt};
use ruma::{
int,
state_res::{self},
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName,
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId,
ServerName,
};
use super::check_room_id;
@ -17,7 +18,11 @@ use super::check_room_id;
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)]
pub(super) async fn fetch_prev(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
&self,
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
room_version_id: &RoomVersionId,
initial_set: Vec<Arc<EventId>>,
) -> Result<(
Vec<Arc<EventId>>,
@ -35,7 +40,13 @@ pub(super) async fn fetch_prev(
self.services.server.check_running()?;
if let Some((pdu, mut json_opt)) = self
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id)
.fetch_and_handle_outliers(
origin,
&[prev_event_id.clone()],
create_event,
room_id,
room_version_id,
)
.boxed()
.await
.pop()
@ -67,7 +78,8 @@ pub(super) async fn fetch_prev(
}
}
graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect());
graph
.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect());
} else {
// Time based check failed
graph.insert(prev_event_id.clone(), HashSet::new());

View file

@ -6,7 +6,8 @@ use std::{
use conduwuit::{debug, implement, warn, Err, Error, PduEvent, Result};
use futures::FutureExt;
use ruma::{
api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId, RoomVersionId, ServerName,
api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId,
RoomVersionId, ServerName,
};
/// Call /state_ids to find out what the state at this pdu is. We trust the
@ -15,20 +16,21 @@ use ruma::{
#[implement(super::Service)]
#[tracing::instrument(skip(self, create_event, room_version_id))]
pub(super) async fn fetch_state(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
&self,
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
room_version_id: &RoomVersionId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Fetching state ids");
let res = self
.services
.sending
.send_federation_request(
origin,
get_room_state_ids::v1::Request {
room_id: room_id.to_owned(),
event_id: (*event_id).to_owned(),
},
)
.send_federation_request(origin, get_room_state_ids::v1::Request {
room_id: room_id.to_owned(),
event_id: (*event_id).to_owned(),
})
.await
.inspect_err(|e| warn!("Fetching state for event failed: {e}"))?;
@ -58,14 +60,13 @@ pub(super) async fn fetch_state(
.await;
match state.entry(shortstatekey) {
hash_map::Entry::Vacant(v) => {
| hash_map::Entry::Vacant(v) => {
v.insert(Arc::from(&*pdu.event_id));
},
hash_map::Entry::Occupied(_) => {
| hash_map::Entry::Occupied(_) =>
return Err(Error::bad_database(
"State event's type and state_key combination exists multiple times.",
))
},
)),
}
}

View file

@ -7,7 +7,8 @@ use std::{
use conduwuit::{debug, err, implement, warn, Error, Result};
use futures::{FutureExt, TryFutureExt};
use ruma::{
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId,
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId,
ServerName, UserId,
};
use super::{check_room_id, get_room_version_id};
@ -43,8 +44,12 @@ use crate::rooms::timeline::RawPduId;
#[implement(super::Service)]
#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")]
pub async fn handle_incoming_pdu<'a>(
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
&self,
origin: &'a ServerName,
room_id: &'a RoomId,
event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>,
is_timeline_event: bool,
) -> Result<Option<RawPduId>> {
// 1. Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await {
@ -144,10 +149,10 @@ pub async fn handle_incoming_pdu<'a>(
.expect("locked")
.entry(prev_id.into())
{
Entry::Vacant(e) => {
| Entry::Vacant(e) => {
e.insert((now, 1));
},
Entry::Occupied(mut e) => {
| Entry::Occupied(mut e) => {
*e.get_mut() = (now, e.get().1.saturating_add(1));
},
};

View file

@ -17,8 +17,13 @@ use super::{check_room_id, get_room_version_id, to_room_version};
#[implement(super::Service)]
#[allow(clippy::too_many_arguments)]
pub(super) async fn handle_outlier_pdu<'a>(
&self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId,
mut value: CanonicalJsonObject, auth_events_known: bool,
&self,
origin: &'a ServerName,
create_event: &'a PduEvent,
event_id: &'a EventId,
room_id: &'a RoomId,
mut value: CanonicalJsonObject,
auth_events_known: bool,
) -> Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)> {
// 1. Remove unsigned field
value.remove("unsigned");
@ -34,8 +39,8 @@ pub(super) async fn handle_outlier_pdu<'a>(
.verify_event(&value, Some(&room_version_id))
.await
{
Ok(ruma::signatures::Verified::All) => value,
Ok(ruma::signatures::Verified::Signatures) => {
| Ok(ruma::signatures::Verified::All) => value,
| Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else {
@ -44,24 +49,26 @@ pub(super) async fn handle_outlier_pdu<'a>(
// Skip the PDU if it is redacted and we already have it as an outlier event
if self.services.timeline.pdu_exists(event_id).await {
return Err!(Request(InvalidParam("Event was redacted and we already knew about it")));
return Err!(Request(InvalidParam(
"Event was redacted and we already knew about it"
)));
}
obj
},
Err(e) => {
| Err(e) =>
return Err!(Request(InvalidParam(debug_error!(
"Signature verification failed for {event_id}: {e}"
))))
},
)))),
};
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
let incoming_pdu =
serde_json::from_value::<PduEvent>(serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"))
.map_err(|_| Error::bad_database("Event is not a valid PDU."))?;
let incoming_pdu = serde_json::from_value::<PduEvent>(
serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"),
)
.map_err(|_| Error::bad_database("Event is not a valid PDU."))?;
check_room_id(room_id, &incoming_pdu)?;
@ -108,10 +115,10 @@ pub(super) async fn handle_outlier_pdu<'a>(
.clone()
.expect("all auth events have state keys"),
)) {
hash_map::Entry::Vacant(v) => {
| hash_map::Entry::Vacant(v) => {
v.insert(auth_event);
},
hash_map::Entry::Occupied(_) => {
| hash_map::Entry::Occupied(_) => {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Auth event's type and state_key combination exists multiple times.",

View file

@ -4,7 +4,9 @@ use std::{
time::Instant,
};
use conduwuit::{debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result};
use conduwuit::{
debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result,
};
use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, ServerName};
#[implement(super::Service)]
@ -15,15 +17,23 @@ use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, S
name = "prev"
)]
pub(super) async fn handle_prev_pdu<'a>(
&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
&self,
origin: &'a ServerName,
event_id: &'a EventId,
room_id: &'a RoomId,
eventid_info: &mut HashMap<
Arc<EventId>,
(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>),
>,
create_event: &Arc<PduEvent>,
first_pdu_in_room: &Arc<PduEvent>,
prev_id: &EventId,
) -> Result {
// Check for disabled again because it might have changed
if self.services.metadata.is_disabled(room_id).await {
debug!(
"Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and event \
ID {event_id}"
"Federaton of room {room_id} is currently disabled on this server. Request by \
origin {origin} and event ID {event_id}"
);
return Err(Error::BadRequest(
ErrorKind::forbidden(),

View file

@ -23,8 +23,8 @@ use conduwuit::{
};
use futures::TryFutureExt;
use ruma::{
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId,
RoomVersionId,
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId,
OwnedRoomId, RoomId, RoomVersionId,
};
use crate::{globals, rooms, sending, server_keys, Dep};
@ -69,8 +69,10 @@ impl crate::Service for Service {
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
server: args.server.clone(),
},
@ -95,7 +97,9 @@ impl crate::Service for Service {
}
impl Service {
async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await }
async fn event_exists(&self, event_id: Arc<EventId>) -> bool {
self.services.timeline.pdu_exists(&event_id).await
}
async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> {
self.services

View file

@ -3,9 +3,13 @@ use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, R
use serde_json::value::RawValue as RawJsonValue;
#[implement(super::Service)]
pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get())
.map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?;
pub async fn parse_incoming_pdu(
&self,
pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()).map_err(|e| {
err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}")))
})?;
let room_id: OwnedRoomId = value
.get("room_id")
@ -20,8 +24,9 @@ pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEvent
.await
.map_err(|_| err!("Server is not in room {room_id}"))?;
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id)
.map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?;
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id).map_err(|e| {
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
})?;
Ok((event_id, value, room_id))
}

View file

@ -20,7 +20,10 @@ use crate::rooms::state_compressor::CompressedStateEvent;
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "resolve")]
pub async fn resolve_state(
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
&self,
room_id: &RoomId,
room_version_id: &RoomVersionId,
incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids");
let current_sstatehash = self
@ -76,10 +79,16 @@ pub async fn resolve_state(
let event_fetch = |event_id| self.event_fetch(event_id);
let event_exists = |event_id| self.event_exists(event_id);
let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
.boxed()
.await
.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
let state = state_res::resolve(
room_version_id,
&fork_states,
&auth_chain_sets,
&event_fetch,
&event_exists,
)
.boxed()
.await
.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
drop(lock);

View file

@ -21,7 +21,8 @@ use ruma::{
// request and build the state from a known point and resolve if > 1 prev_event
#[tracing::instrument(skip_all, name = "state")]
pub(super) async fn state_at_incoming_degree_one(
&self, incoming_pdu: &Arc<PduEvent>,
&self,
incoming_pdu: &Arc<PduEvent>,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
let prev_event = &*incoming_pdu.prev_events[0];
let Ok(prev_event_sstatehash) = self
@ -70,7 +71,10 @@ pub(super) async fn state_at_incoming_degree_one(
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "state")]
pub(super) async fn state_at_incoming_resolved(
&self, incoming_pdu: &Arc<PduEvent>, room_id: &RoomId, room_version_id: &RoomVersionId,
&self,
incoming_pdu: &Arc<PduEvent>,
room_id: &RoomId,
room_version_id: &RoomVersionId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Calculating state at event using state res");
let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len());
@ -157,10 +161,16 @@ pub(super) async fn state_at_incoming_resolved(
let event_fetch = |event_id| self.event_fetch(event_id);
let event_exists = |event_id| self.event_exists(event_id);
let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
.boxed()
.await
.map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed."))));
let result = state_res::resolve(
room_version_id,
&fork_states,
&auth_chain_sets,
&event_fetch,
&event_exists,
)
.boxed()
.await
.map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed."))));
drop(lock);

View file

@ -19,8 +19,12 @@ use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPdu
#[implement(super::Service)]
pub(super) async fn upgrade_outlier_to_timeline_pdu(
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
origin: &ServerName, room_id: &RoomId,
&self,
incoming_pdu: Arc<PduEvent>,
val: BTreeMap<String, CanonicalJsonValue>,
create_event: &PduEvent,
origin: &ServerName,
room_id: &RoomId,
) -> Result<Option<RawPduId>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
@ -63,7 +67,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
.await?;
}
let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above");
let state_at_incoming_event =
state_at_incoming_event.expect("we always set this to some above");
let room_version = to_room_version(&room_version_id);
debug!("Performing auth check");
@ -124,24 +129,34 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
!auth_check
|| incoming_pdu.kind == TimelineEventType::RoomRedaction
&& match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &incoming_pdu.redacts {
!self
.services
.state_accessor
.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
.user_can_redact(
redact_id,
&incoming_pdu.sender,
&incoming_pdu.room_id,
true,
)
.await?
} else {
false
}
},
_ => {
| _ => {
let content: RoomRedactionEventContent = incoming_pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
!self
.services
.state_accessor
.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
.user_can_redact(
redact_id,
&incoming_pdu.sender,
&incoming_pdu.room_id,
true,
)
.await?
} else {
false
@ -229,11 +244,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
// Set the new room state to the resolved state
debug!("Forcing new room state");
let HashSetCompressStateEvent {
shortstatehash,
added,
removed,
} = self
let HashSetCompressStateEvent { shortstatehash, added, removed } = self
.services
.state_compressor
.save_state(room_id, new_room_state)

View file

@ -51,7 +51,11 @@ impl crate::Service for Service {
#[tracing::instrument(skip(self), level = "debug")]
#[inline]
pub async fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
ll_user: &UserId,
) -> bool {
let key = (user_id, device_id, room_id, ll_user);
self.db.lazyloadedids.qry(&key).await.is_ok()
@ -60,7 +64,12 @@ pub async fn lazy_load_was_sent_before(
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount,
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count);
@ -72,7 +81,13 @@ pub fn lazy_load_mark_sent(
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) {
pub fn lazy_load_confirm_delivery(
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
since: PduCount,
) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since);
let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else {

View file

@ -58,7 +58,9 @@ pub async fn exists(&self, room_id: &RoomId) -> bool {
}
#[implement(Service)]
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() }
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
self.db.roomid_shortroomid.keys().ignore_err()
}
#[implement(Service)]
#[inline]
@ -81,12 +83,18 @@ pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
}
#[implement(Service)]
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() }
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
self.db.bannedroomids.keys().ignore_err()
}
#[implement(Service)]
#[inline]
pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.get(room_id).await.is_ok() }
pub async fn is_disabled(&self, room_id: &RoomId) -> bool {
self.db.disabledroomids.get(room_id).await.is_ok()
}
#[implement(Service)]
#[inline]
pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.get(room_id).await.is_ok() }
pub async fn is_banned(&self, room_id: &RoomId) -> bool {
self.db.bannedroomids.get(room_id).await.is_ok()
}

View file

@ -56,26 +56,27 @@ impl Data {
}
pub(super) fn get_relations<'a>(
&'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction,
&'a self,
user_id: &'a UserId,
shortroomid: ShortRoomId,
target: ShortEventId,
from: PduCount,
dir: Direction,
) -> impl Stream<Item = PdusIterItem> + Send + '_ {
let mut current = ArrayVec::<u8, 16>::new();
current.extend(target.to_be_bytes());
current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes());
let current = current.as_slice();
match dir {
Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
| Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
| Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
}
.ignore_err()
.ready_take_while(move |key| key.starts_with(&target.to_be_bytes()))
.map(|to_from| u64_from_u8(&to_from[8..16]))
.map(PduCount::from_unsigned)
.wide_filter_map(move |shorteventid| async move {
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid,
}
.into();
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
@ -99,7 +100,9 @@ impl Data {
self.referencedevents.qry(&key).await.is_ok()
}
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { self.softfailedeventids.insert(event_id, []); }
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) {
self.softfailedeventids.insert(event_id, []);
}
pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.softfailedeventids.get(event_id).await.is_ok()

View file

@ -36,8 +36,8 @@ impl Service {
#[tracing::instrument(skip(self, from, to), level = "debug")]
pub fn add_relation(&self, from: PduCount, to: PduCount) {
match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
_ => {
| (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
| _ => {
// TODO: Relations with backfilled pdus
},
}
@ -45,15 +45,21 @@ impl Service {
#[allow(clippy::too_many_arguments)]
pub async fn get_relations(
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8,
&self,
user_id: &UserId,
room_id: &RoomId,
target: &EventId,
from: PduCount,
limit: usize,
max_depth: u8,
dir: Direction,
) -> Vec<PdusIterItem> {
let room_id = self.services.short.get_or_create_shortroomid(room_id).await;
let target = match self.services.timeline.get_pdu_count(target).await {
Ok(PduCount::Normal(c)) => c,
| Ok(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator
| _ => 0, // This will result in an empty iterator
};
let mut pdus: Vec<_> = self
@ -66,9 +72,9 @@ impl Service {
'limit: while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
| PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
| PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
let relations: Vec<_> = self
@ -106,7 +112,9 @@ impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) }
pub fn mark_event_soft_failed(&self, event_id: &EventId) {
self.db.mark_event_soft_failed(event_id);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]

View file

@ -40,7 +40,12 @@ impl Data {
}
}
pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
pub(super) async fn readreceipt_update(
&self,
user_id: &UserId,
room_id: &RoomId,
event: &ReceiptEvent,
) {
// Remove old entry
let last_possible_key = (room_id, u64::MAX);
self.readreceiptid_readreceipt
@ -57,7 +62,9 @@ impl Data {
}
pub(super) fn readreceipts_since<'a>(
&'a self, room_id: &'a RoomId, since: u64,
&'a self,
room_id: &'a RoomId,
since: u64,
) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
type Key<'a> = (&'a RoomId, u64, &'a UserId);
type KeyVal<'a> = (Key<'a>, CanonicalJsonObject);
@ -87,12 +94,20 @@ impl Data {
self.roomuserid_lastprivatereadupdate.put(key, next_count);
}
pub(super) async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
pub(super) async fn private_read_get_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<u64> {
let key = (room_id, user_id);
self.roomuserid_privateread.qry(&key).await.deserialized()
}
pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
pub(super) async fn last_privateread_update(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> u64 {
let key = (room_id, user_id);
self.roomuserid_lastprivatereadupdate
.qry(&key)

View file

@ -44,7 +44,12 @@ impl crate::Service for Service {
impl Service {
/// Replaces the previous read receipt.
pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
pub async fn readreceipt_update(
&self,
user_id: &UserId,
room_id: &RoomId,
event: &ReceiptEvent,
) {
self.db.readreceipt_update(user_id, room_id, event).await;
self.services
.sending
@ -54,23 +59,21 @@ impl Service {
}
/// Gets the latest private read receipt from the user in the room
pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
let pdu_count = self
.private_read_get_count(room_id, user_id)
.map_err(|e| err!(Database(warn!("No private read receipt was set in {room_id}: {e}"))));
let shortroomid = self
.services
.short
.get_shortroomid(room_id)
.map_err(|e| err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}"))));
pub async fn private_read_get(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
let pdu_count = self.private_read_get_count(room_id, user_id).map_err(|e| {
err!(Database(warn!("No private read receipt was set in {room_id}: {e}")))
});
let shortroomid = self.services.short.get_shortroomid(room_id).map_err(|e| {
err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}")))
});
let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?;
let shorteventid = PduCount::Normal(pdu_count);
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid,
}
.into();
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?;
@ -80,21 +83,17 @@ impl Service {
event_id,
BTreeMap::from_iter([(
ruma::events::receipt::ReceiptType::ReadPrivate,
BTreeMap::from_iter([(
user_id,
ruma::events::receipt::Receipt {
ts: None, // TODO: start storing the timestamp so we can return one
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
},
)]),
BTreeMap::from_iter([(user_id, ruma::events::receipt::Receipt {
ts: None, // TODO: start storing the timestamp so we can return one
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
})]),
)]),
)]);
let receipt_event_content = ReceiptEventContent(content);
let receipt_sync_event = SyncEphemeralRoomEvent {
content: receipt_event_content,
};
let receipt_sync_event = SyncEphemeralRoomEvent { content: receipt_event_content };
let event = serde_json::value::to_raw_value(&receipt_sync_event).expect("receipt created manually");
let event = serde_json::value::to_raw_value(&receipt_sync_event)
.expect("receipt created manually");
Ok(Raw::from_json(event))
}
@ -104,7 +103,9 @@ impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn readreceipts_since<'a>(
&'a self, room_id: &'a RoomId, since: u64,
&'a self,
room_id: &'a RoomId,
since: u64,
) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
self.db.readreceipts_since(room_id, since)
}
@ -119,7 +120,11 @@ impl Service {
/// Returns the private read marker PDU count.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
pub async fn private_read_get_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<u64> {
self.db.private_read_get_count(room_id, user_id).await
}
@ -137,7 +142,9 @@ where
{
let mut json = BTreeMap::new();
for value in receipts {
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get());
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(
value.json().get(),
);
if let Ok(value) = receipt {
for (event, receipt) in value.content {
json.insert(event, receipt);
@ -149,9 +156,7 @@ where
let content = ReceiptEventContent::from_iter(json);
Raw::from_json(
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent {
content,
})
.expect("received valid json"),
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent { content })
.expect("received valid json"),
)
}

View file

@ -49,18 +49,18 @@ pub struct RoomQuery<'a> {
type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
const TOKEN_ID_MAX_LEN: usize = size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
const TOKEN_ID_MAX_LEN: usize =
size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
const WORD_MAX_LEN: usize = 50;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
tokenids: args.db["tokenids"].clone(),
},
db: Data { tokenids: args.db["tokenids"].clone() },
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}))
@ -103,7 +103,8 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b
#[implement(Service)]
pub async fn search_pdus<'a>(
&'a self, query: &'a RoomQuery<'a>,
&'a self,
query: &'a RoomQuery<'a>,
) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> {
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
@ -136,7 +137,10 @@ pub async fn search_pdus<'a>(
// result is modeled as a stream such that callers don't have to be refactored
// though an additional async/wrap still exists for now
#[implement(Service)]
pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
pub async fn search_pdu_ids(
&self,
query: &RoomQuery<'_>,
) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
let shortroomid = self.services.short.get_shortroomid(query.room_id).await?;
let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await;
@ -147,7 +151,11 @@ pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<
}
#[implement(Service)]
async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec<Vec<RawPduId>> {
async fn search_pdu_ids_query_room(
&self,
query: &RoomQuery<'_>,
shortroomid: ShortRoomId,
) -> Vec<Vec<RawPduId>> {
tokenize(&query.criteria.search_term)
.stream()
.wide_then(|word| async move {
@ -162,7 +170,9 @@ async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: Sh
/// Iterate over PduId's containing a word
#[implement(Service)]
fn search_pdu_ids_query_words<'a>(
&'a self, shortroomid: ShortRoomId, word: &'a str,
&'a self,
shortroomid: ShortRoomId,
word: &'a str,
) -> impl Stream<Item = RawPduId> + Send + '_ {
self.search_pdu_ids_query_word(shortroomid, word)
.map(move |key| -> RawPduId {
@ -173,7 +183,11 @@ fn search_pdu_ids_query_words<'a>(
/// Iterate over raw database results for a word
#[implement(Service)]
fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream<Item = Val<'_>> + Send + '_ {
fn search_pdu_ids_query_word(
&self,
shortroomid: ShortRoomId,
word: &str,
) -> impl Stream<Item = Val<'_>> + Send + '_ {
// rustc says const'ing this not yet stable
let end_id: RawPduId = PduId {
shortroomid,

View file

@ -60,7 +60,10 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEvent
}
#[implement(Service)]
pub fn multi_get_or_create_shorteventid<'a, I>(&'a self, event_ids: I) -> impl Stream<Item = ShortEventId> + Send + '_
pub fn multi_get_or_create_shorteventid<'a, I>(
&'a self,
event_ids: I,
) -> impl Stream<Item = ShortEventId> + Send + '_
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
<I as Iterator>::Item: AsRef<[u8]> + Send + Sync + 'a,
@ -72,8 +75,8 @@ where
event_ids.next().map(|event_id| (event_id, result))
})
.map(|(event_id, result)| match result {
Ok(ref short) => utils::u64_from_u8(short),
Err(_) => self.create_shorteventid(event_id),
| Ok(ref short) => utils::u64_from_u8(short),
| Err(_) => self.create_shorteventid(event_id),
})
}
@ -104,7 +107,11 @@ pub async fn get_shorteventid(&self, event_id: &EventId) -> Result<ShortEventId>
}
#[implement(Service)]
pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> ShortStateKey {
pub async fn get_or_create_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> ShortStateKey {
const BUFSIZE: usize = size_of::<ShortStateKey>();
if let Ok(shortstatekey) = self.get_shortstatekey(event_type, state_key).await {
@ -127,7 +134,11 @@ pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, sta
}
#[implement(Service)]
pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<ShortStateKey> {
pub async fn get_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<ShortStateKey> {
let key = (event_type, state_key);
self.db
.statekey_shortstatekey
@ -153,7 +164,10 @@ where
}
#[implement(Service)]
pub fn multi_get_eventid_from_short<'a, Id, I>(&'a self, shorteventid: I) -> impl Stream<Item = Result<Id>> + Send + 'a
pub fn multi_get_eventid_from_short<'a, Id, I>(
&'a self,
shorteventid: I,
) -> impl Stream<Item = Result<Id>> + Send + 'a
where
I: Iterator<Item = &'a ShortEventId> + Send + 'a,
Id: for<'de> Deserialize<'de> + Sized + ToOwned + 'a,
@ -168,7 +182,10 @@ where
}
#[implement(Service)]
pub async fn get_statekey_from_short(&self, shortstatekey: ShortStateKey) -> Result<(StateEventType, String)> {
pub async fn get_statekey_from_short(
&self,
shortstatekey: ShortStateKey,
) -> Result<(StateEventType, String)> {
const BUFSIZE: usize = size_of::<ShortStateKey>();
self.db

View file

@ -125,7 +125,8 @@ enum Identifier<'a> {
pub struct Service {
services: Services,
pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
pub roomid_spacehierarchy_cache:
Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
}
struct Services {
@ -145,11 +146,13 @@ impl crate::Service for Service {
let cache_size = cache_size * config.cache_capacity_modifier;
Ok(Arc::new(Self {
services: Services {
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state: args.depend::<rooms::state::Service>("rooms::state"),
short: args.depend::<rooms::short::Service>("rooms::short"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
event_handler: args
.depend::<rooms::event_handler::Service>("rooms::event_handler"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
sending: args.depend::<sending::Service>("sending"),
},
@ -166,28 +169,37 @@ impl Service {
/// Errors if the room does not exist, so a check if the room exists should
/// be done
pub async fn get_federation_hierarchy(
&self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool,
&self,
room_id: &RoomId,
server_name: &ServerName,
suggested_only: bool,
) -> Result<federation::space::get_hierarchy::v1::Response> {
match self
.get_summary_and_children_local(&room_id.to_owned(), Identifier::ServerName(server_name))
.get_summary_and_children_local(
&room_id.to_owned(),
Identifier::ServerName(server_name),
)
.await?
{
Some(SummaryAccessibility::Accessible(room)) => {
| Some(SummaryAccessibility::Accessible(room)) => {
let mut children = Vec::new();
let mut inaccessible_children = Vec::new();
for (child, _via) in get_parent_children_via(&room, suggested_only) {
match self
.get_summary_and_children_local(&child, Identifier::ServerName(server_name))
.get_summary_and_children_local(
&child,
Identifier::ServerName(server_name),
)
.await?
{
Some(SummaryAccessibility::Accessible(summary)) => {
| Some(SummaryAccessibility::Accessible(summary)) => {
children.push((*summary).into());
},
Some(SummaryAccessibility::Inaccessible) => {
| Some(SummaryAccessibility::Inaccessible) => {
inaccessible_children.push(child);
},
None => (),
| None => (),
}
}
@ -197,16 +209,18 @@ impl Service {
inaccessible_children,
})
},
Some(SummaryAccessibility::Inaccessible) => {
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible"))
},
None => Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")),
| Some(SummaryAccessibility::Inaccessible) =>
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible")),
| None =>
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")),
}
}
/// Gets the summary of a space using solely local information
async fn get_summary_and_children_local(
&self, current_room: &OwnedRoomId, identifier: Identifier<'_>,
&self,
current_room: &OwnedRoomId,
identifier: Identifier<'_>,
) -> Result<Option<SummaryAccessibility>> {
if let Some(cached) = self
.roomid_spacehierarchy_cache
@ -241,9 +255,7 @@ impl Service {
if let Ok(summary) = summary {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceHierarchySummary {
summary: summary.clone(),
}),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
Ok(Some(SummaryAccessibility::Accessible(Box::new(summary))))
@ -258,20 +270,21 @@ impl Service {
/// Gets the summary of a space using solely federation
#[tracing::instrument(skip(self))]
async fn get_summary_and_children_federation(
&self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName],
&self,
current_room: &OwnedRoomId,
suggested_only: bool,
user_id: &UserId,
via: &[OwnedServerName],
) -> Result<Option<SummaryAccessibility>> {
for server in via {
debug_info!("Asking {server} for /hierarchy");
let Ok(response) = self
.services
.sending
.send_federation_request(
server,
federation::space::get_hierarchy::v1::Request {
room_id: current_room.to_owned(),
suggested_only,
},
)
.send_federation_request(server, federation::space::get_hierarchy::v1::Request {
room_id: current_room.to_owned(),
suggested_only,
})
.await
else {
continue;
@ -282,9 +295,7 @@ impl Service {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceHierarchySummary {
summary: summary.clone(),
}),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
for child in response.children {
@ -356,7 +367,11 @@ impl Service {
/// Gets the summary of a space using either local or remote (federation)
/// sources
async fn get_summary_and_children_client(
&self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName],
&self,
current_room: &OwnedRoomId,
suggested_only: bool,
user_id: &UserId,
via: &[OwnedServerName],
) -> Result<Option<SummaryAccessibility>> {
if let Ok(Some(response)) = self
.get_summary_and_children_local(current_room, Identifier::UserId(user_id))
@ -370,7 +385,9 @@ impl Service {
}
async fn get_room_summary(
&self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
&self,
current_room: &OwnedRoomId,
children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> {
let room_id: &RoomId = current_room;
@ -388,12 +405,20 @@ impl Service {
.allowed_room_ids(join_rule.clone());
if !self
.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids)
.is_accessible_child(
current_room,
&join_rule.clone().into(),
identifier,
&allowed_room_ids,
)
.await
{
debug_info!("User is not allowed to see room {room_id}");
// This error will be caught later
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"User is not allowed to see the room",
));
}
Ok(SpaceHierarchyParentSummary {
@ -446,7 +471,12 @@ impl Service {
}
pub async fn get_client_hierarchy(
&self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<ShortRoomId>, max_depth: u64,
&self,
sender_user: &UserId,
room_id: &RoomId,
limit: usize,
short_room_ids: Vec<ShortRoomId>,
max_depth: u64,
suggested_only: bool,
) -> Result<client::space::get_hierarchy::v1::Response> {
let mut parents = VecDeque::new();
@ -454,27 +484,30 @@ impl Service {
// Don't start populating the results if we have to start at a specific room.
let mut populate_results = short_room_ids.is_empty();
let mut stack = vec![vec![(
room_id.to_owned(),
match room_id.server_name() {
Some(server_name) => vec![server_name.into()],
None => vec![],
},
)]];
let mut stack = vec![vec![(room_id.to_owned(), match room_id.server_name() {
| Some(server_name) => vec![server_name.into()],
| None => vec![],
})]];
let mut results = Vec::with_capacity(limit);
while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } {
while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) }
{
if results.len() >= limit {
break;
}
match (
self.get_summary_and_children_client(&current_room, suggested_only, sender_user, &via)
.await?,
self.get_summary_and_children_client(
&current_room,
suggested_only,
sender_user,
&via,
)
.await?,
current_room == room_id,
) {
(Some(SummaryAccessibility::Accessible(summary)), _) => {
| (Some(SummaryAccessibility::Accessible(summary)), _) => {
let mut children: Vec<(OwnedRoomId, Vec<OwnedServerName>)> =
get_parent_children_via(&summary, suggested_only)
.into_iter()
@ -493,7 +526,9 @@ impl Service {
self.services
.short
.get_shortroomid(room)
.map_ok(|short| Some(&short) != short_room_ids.get(parents.len()))
.map_ok(|short| {
Some(&short) != short_room_ids.get(parents.len())
})
.unwrap_or_else(|_| false)
})
.map(Clone::clone)
@ -525,14 +560,20 @@ impl Service {
// Root room in the space hierarchy, we return an error
// if this one fails.
},
(Some(SummaryAccessibility::Inaccessible), true) => {
return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible"));
| (Some(SummaryAccessibility::Inaccessible), true) => {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"The requested room is inaccessible",
));
},
(None, true) => {
return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found"));
| (None, true) => {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"The requested room was not found",
));
},
// Just ignore other unavailable rooms
(None | Some(SummaryAccessibility::Inaccessible), false) => (),
| (None | Some(SummaryAccessibility::Inaccessible), false) => (),
}
}
@ -544,15 +585,19 @@ impl Service {
let short_room_ids: Vec<_> = parents
.iter()
.stream()
.filter_map(|room_id| async move { self.services.short.get_shortroomid(room_id).await.ok() })
.filter_map(|room_id| async move {
self.services.short.get_shortroomid(room_id).await.ok()
})
.collect()
.await;
Some(
PaginationToken {
short_room_ids,
limit: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"),
max_depth: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"),
limit: UInt::new(max_depth)
.expect("When sent in request it must have been valid UInt"),
max_depth: UInt::new(max_depth)
.expect("When sent in request it must have been valid UInt"),
suggested_only,
}
.to_string(),
@ -566,9 +611,12 @@ impl Service {
/// Simply returns the stripped m.space.child events of a room
async fn get_stripped_space_child_events(
&self, room_id: &RoomId,
&self,
room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else {
let Ok(current_shortstatehash) =
self.services.state.get_room_shortstatehash(room_id).await
else {
return Ok(None);
};
@ -581,18 +629,17 @@ impl Service {
let mut children_pdus = Vec::with_capacity(state.len());
for (key, id) in state {
let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?;
let (event_type, state_key) =
self.services.short.get_statekey_from_short(key).await?;
if event_type != StateEventType::SpaceChild {
continue;
}
let pdu = self
.services
.timeline
.get_pdu(&id)
.await
.map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?;
let pdu =
self.services.timeline.get_pdu(&id).await.map_err(|e| {
err!(Database("Event {id:?} in space state not found: {e:?}"))
})?;
if let Ok(content) = pdu.get_content::<SpaceChildEventContent>() {
if content.via.is_empty() {
@ -610,11 +657,14 @@ impl Service {
/// With the given identifier, checks if a room is accessable
async fn is_accessible_child(
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
&self,
current_room: &OwnedRoomId,
join_rule: &SpaceRoomJoinRule,
identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
match identifier {
Identifier::ServerName(server_name) => {
| Identifier::ServerName(server_name) => {
// Checks if ACLs allow for the server to participate
if self
.services
@ -626,7 +676,7 @@ impl Service {
return false;
}
},
Identifier::UserId(user_id) => {
| Identifier::UserId(user_id) => {
if self
.services
.state_cache
@ -642,16 +692,18 @@ impl Service {
},
}
match &join_rule {
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
SpaceRoomJoinRule::Restricted => {
| SpaceRoomJoinRule::Public
| SpaceRoomJoinRule::Knock
| SpaceRoomJoinRule::KnockRestricted => true,
| SpaceRoomJoinRule::Restricted => {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
| Identifier::UserId(user) => {
if self.services.state_cache.is_joined(user, room).await {
return true;
}
},
Identifier::ServerName(server) => {
| Identifier::ServerName(server) => {
if self.services.state_cache.server_in_room(server, room).await {
return true;
}
@ -661,7 +713,7 @@ impl Service {
false
},
// Invite only, Private, or Custom join rule
_ => false,
| _ => false,
}
}
}
@ -737,7 +789,8 @@ fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRooms
/// Returns the children of a SpaceHierarchyParentSummary, making use of the
/// children_state field
fn get_parent_children_via(
parent: &SpaceHierarchyParentSummary, suggested_only: bool,
parent: &SpaceHierarchyParentSummary,
suggested_only: bool,
) -> Vec<(OwnedRoomId, Vec<OwnedServerName>)> {
parent
.children_state
@ -755,7 +808,8 @@ fn get_parent_children_via(
}
fn next_room_to_traverse(
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>,
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>,
parents: &mut VecDeque<OwnedRoomId>,
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> {
while stack.last().is_some_and(Vec::is_empty) {
stack.pop();

View file

@ -69,18 +69,15 @@ fn get_summary_children() {
}
.into();
assert_eq!(
get_parent_children_via(&summary, false),
vec![
(owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")])
]
);
assert_eq!(
get_parent_children_via(&summary, true),
vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])]
);
assert_eq!(get_parent_children_via(&summary, false), vec![
(owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")])
]);
assert_eq!(get_parent_children_via(&summary, true), vec![(
owned_room_id!("!bar:example.org"),
vec![owned_server_name!("example.org")]
)]);
}
#[test]

View file

@ -16,7 +16,9 @@ use conduwuit::{
warn, PduEvent, Result,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use futures::{
future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
};
use ruma::{
events::{
room::{create::RoomCreateEventContent, member::RoomMemberEventContent},
@ -70,8 +72,10 @@ impl crate::Service for Service {
short: args.depend::<rooms::short::Service>("rooms::short"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data {
@ -100,7 +104,8 @@ impl Service {
shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result {
let event_ids = statediffnew
.iter()
@ -120,8 +125,9 @@ impl Service {
};
match pdu.kind {
TimelineEventType::RoomMember => {
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok() else {
| TimelineEventType::RoomMember => {
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok()
else {
continue;
};
@ -131,10 +137,18 @@ impl Service {
self.services
.state_cache
.update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false)
.update_membership(
room_id,
&user_id,
membership_event,
&pdu.sender,
None,
None,
false,
)
.await?;
},
TimelineEventType::SpaceChild => {
| TimelineEventType::SpaceChild => {
self.services
.spaces
.roomid_spacehierarchy_cache
@ -142,7 +156,7 @@ impl Service {
.await
.remove(&pdu.room_id);
},
_ => continue,
| _ => continue,
}
}
@ -159,7 +173,10 @@ impl Service {
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")]
pub async fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
&self,
event_id: &EventId,
room_id: &RoomId,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<ShortStateHash> {
const KEY_LEN: usize = size_of::<ShortEventId>();
const VAL_LEN: usize = size_of::<ShortStateHash>();
@ -190,22 +207,23 @@ impl Service {
Vec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&state_ids_compressed)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(state_ids_compressed, Arc::new(HashSet::new()))
};
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(state_ids_compressed, Arc::new(HashSet::new()))
};
self.services.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
@ -338,7 +356,8 @@ impl Service {
&self,
room_id: &RoomId,
shortstatehash: u64,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) {
const BUFSIZE: usize = size_of::<u64>();
@ -366,7 +385,10 @@ impl Service {
.deserialized()
}
pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &EventId> + Send + '_ {
pub fn get_forward_extremities<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &EventId> + Send + '_ {
let prefix = (room_id, Interfix);
self.db
@ -380,7 +402,8 @@ impl Service {
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
_state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) {
let prefix = (room_id, Interfix);
self.db
@ -399,26 +422,33 @@ impl Service {
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self, content), level = "debug")]
pub async fn get_auth_events(
&self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>,
&self,
room_id: &RoomId,
kind: &TimelineEventType,
sender: &UserId,
state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
return Ok(HashMap::new());
};
let mut sauthevents: HashMap<_, _> = state_res::auth_types_for_event(kind, sender, state_key, content)?
.iter()
.stream()
.broad_filter_map(|(event_type, state_key)| {
self.services
.short
.get_shortstatekey(event_type, state_key)
.map_ok(move |ssk| (ssk, (event_type, state_key)))
.map(Result::ok)
})
.map(|(ssk, (event_type, state_key))| (ssk, (event_type.to_owned(), state_key.to_owned())))
.collect()
.await;
let mut sauthevents: HashMap<_, _> =
state_res::auth_types_for_event(kind, sender, state_key, content)?
.iter()
.stream()
.broad_filter_map(|(event_type, state_key)| {
self.services
.short
.get_shortstatekey(event_type, state_key)
.map_ok(move |ssk| (ssk, (event_type, state_key)))
.map(Result::ok)
})
.map(|(ssk, (event_type, state_key))| {
(ssk, (event_type.to_owned(), state_key.to_owned()))
})
.collect()
.await;
let (state_keys, event_ids): (Vec<_>, Vec<_>) = self
.services

View file

@ -39,14 +39,16 @@ impl Data {
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}
}
pub(super) async fn state_full(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
let state = self
.state_full_pdus(shortstatehash)
@ -58,7 +60,10 @@ impl Data {
Ok(state)
}
pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> {
pub(super) async fn state_full_pdus(
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<PduEvent>> {
let short_ids = self.state_full_shortids(shortstatehash).await?;
let full_pdus = self
@ -66,16 +71,19 @@ impl Data {
.short
.multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1)))
.ready_filter_map(Result::ok)
.broad_filter_map(
|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() },
)
.broad_filter_map(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await.ok()
})
.collect()
.await;
Ok(full_pdus)
}
pub(super) async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
pub(super) async fn state_full_ids<Id>(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Id>>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
@ -96,7 +104,8 @@ impl Data {
}
pub(super) async fn state_full_shortids(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
let shortids = self
.services
@ -118,7 +127,10 @@ impl Data {
/// Returns a single EventId from `room_id` with key
/// (`event_type`,`state_key`).
pub(super) async fn state_get_id<Id>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -155,10 +167,15 @@ impl Data {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn state_get(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await })
.and_then(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await
})
.await
}
@ -179,7 +196,8 @@ impl Data {
/// Returns the full room state.
pub(super) async fn room_state_full(
&self, room_id: &RoomId,
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.services
.state
@ -203,7 +221,10 @@ impl Data {
/// Returns a single EventId from `room_id` with key
/// (`event_type`,`state_key`).
pub(super) async fn room_state_get_id<Id>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -218,7 +239,10 @@ impl Data {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.services
.state

View file

@ -34,8 +34,8 @@ use ruma::{
},
room::RoomType,
space::SpaceRoomJoinRule,
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
ServerName, UserId,
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName,
OwnedUserId, RoomId, ServerName, UserId,
};
use serde::Deserialize;
@ -75,8 +75,12 @@ impl crate::Service for Service {
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
server_visibility_cache_capacity,
)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
user_visibility_cache_capacity,
)?)),
}))
}
@ -102,7 +106,10 @@ impl Service {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
pub async fn state_full_ids<Id>(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Id>>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
@ -112,13 +119,15 @@ impl Service {
#[inline]
pub async fn state_full_shortids(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
self.db.state_full_shortids(shortstatehash).await
}
pub async fn state_full(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.state_full(shortstatehash).await
}
@ -127,7 +136,10 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_id<Id>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -142,7 +154,10 @@ impl Service {
/// `state_key`).
#[inline]
pub async fn state_get(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.db
.state_get(shortstatehash, event_type, state_key)
@ -151,7 +166,10 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn state_get_content<T>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
@ -162,7 +180,11 @@ impl Service {
}
/// Get membership for given user in state
async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState {
async fn user_membership(
&self,
shortstatehash: ShortStateHash,
user_id: &UserId,
) -> MembershipState {
self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str())
.await
.map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership)
@ -185,7 +207,12 @@ impl Service {
/// Whether a server is allowed to see an event through federation, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip_all, level = "trace")]
pub async fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> bool {
pub async fn server_can_see_event(
&self,
origin: &ServerName,
room_id: &RoomId,
event_id: &EventId,
) -> bool {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return true;
};
@ -213,20 +240,20 @@ impl Service {
.ready_filter(|member| member.server_name() == origin);
let visibility = match history_visibility {
HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
HistoryVisibility::Invited => {
| HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
| HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
current_server_members
.any(|member| self.user_was_invited(shortstatehash, member))
.await
},
HistoryVisibility::Joined => {
| HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
current_server_members
.any(|member| self.user_was_joined(shortstatehash, member))
.await
},
_ => {
| _ => {
error!("Unknown history visibility {history_visibility}");
false
},
@ -243,7 +270,12 @@ impl Service {
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip_all, level = "trace")]
pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool {
pub async fn user_can_see_event(
&self,
user_id: &UserId,
room_id: &RoomId,
event_id: &EventId,
) -> bool {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return true;
};
@ -267,17 +299,17 @@ impl Service {
});
let visibility = match history_visibility {
HistoryVisibility::WorldReadable => true,
HistoryVisibility::Shared => currently_member,
HistoryVisibility::Invited => {
| HistoryVisibility::WorldReadable => true,
| HistoryVisibility::Shared => currently_member,
| HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
self.user_was_invited(shortstatehash, user_id).await
},
HistoryVisibility::Joined => {
| HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
self.user_was_joined(shortstatehash, user_id).await
},
_ => {
| _ => {
error!("Unknown history visibility {history_visibility}");
false
},
@ -307,9 +339,10 @@ impl Service {
});
match history_visibility {
HistoryVisibility::Invited => self.services.state_cache.is_invited(user_id, room_id).await,
HistoryVisibility::WorldReadable => true,
_ => false,
| HistoryVisibility::Invited =>
self.services.state_cache.is_invited(user_id, room_id).await,
| HistoryVisibility::WorldReadable => true,
| _ => false,
}
}
@ -320,7 +353,10 @@ impl Service {
/// Returns the full room state.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), PduEvent>> {
pub async fn room_state_full(
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.room_state_full(room_id).await
}
@ -334,7 +370,10 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_get_id<Id>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -349,14 +388,20 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.db.room_state_get(room_id, event_type, state_key).await
}
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn room_state_get_content<T>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
@ -381,18 +426,29 @@ impl Service {
JsOption::from_option(content)
}
pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<RoomMemberEventContent> {
pub async fn get_member(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<RoomMemberEventContent> {
self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
}
pub async fn user_can_invite(
&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard,
&self,
room_id: &RoomId,
sender: &UserId,
target_user: &UserId,
state_lock: &RoomMutexGuard,
) -> bool {
self.services
.timeline
.create_hash_and_sign_event(
PduBuilder::state(target_user.into(), &RoomMemberEventContent::new(MembershipState::Invite)),
PduBuilder::state(
target_user.into(),
&RoomMemberEventContent::new(MembershipState::Invite),
),
sender,
room_id,
state_lock,
@ -405,7 +461,9 @@ impl Service {
pub async fn is_world_readable(&self, room_id: &RoomId) -> bool {
self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "")
.await
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable)
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable
})
.unwrap_or(false)
}
@ -439,7 +497,11 @@ impl Service {
/// If federation is true, it allows redaction events from any user of the
/// same server as the original event sender
pub async fn user_can_redact(
&self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool,
&self,
redacts: &EventId,
sender: &UserId,
room_id: &RoomId,
federation: bool,
) -> Result<bool> {
let redacting_event = self.services.timeline.get_pdu(redacts).await;
@ -451,7 +513,11 @@ impl Service {
}
if let Ok(pl_event_content) = self
.room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "")
.room_state_get_content::<RoomPowerLevelsEventContent>(
room_id,
&StateEventType::RoomPowerLevels,
"",
)
.await
{
let pl_event: RoomPowerLevels = pl_event_content.into();
@ -485,10 +551,15 @@ impl Service {
}
/// Returns the join rule (`SpaceRoomJoinRule`) for a given room
pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
pub async fn get_join_rule(
&self,
room_id: &RoomId,
) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
.await
.map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)))
.map(|c: RoomJoinRulesEventContent| {
(c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))
})
.or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![])))
}
@ -497,10 +568,7 @@ impl Service {
let mut room_ids = Vec::with_capacity(1);
if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule {
for rule in r.allow {
if let AllowRule::RoomMembership(RoomMembership {
room_id: membership,
}) = rule
{
if let AllowRule::RoomMembership(RoomMembership { room_id: membership }) = rule {
room_ids.push(membership.clone());
}
}
@ -520,7 +588,10 @@ impl Service {
/// Gets the room's encryption algorithm if `m.room.encryption` state event
/// is found
pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result<EventEncryptionAlgorithm> {
pub async fn get_room_encryption(
&self,
room_id: &RoomId,
) -> Result<EventEncryptionAlgorithm> {
self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "")
.await
.map(|content: RoomEncryptionEventContent| content.algorithm)

View file

@ -20,7 +20,8 @@ use ruma::{
member::{MembershipState, RoomMemberEventContent},
power_levels::RoomPowerLevelsEventContent,
},
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType, StateEventType,
},
int,
serde::Raw,
@ -68,7 +69,8 @@ impl crate::Service for Service {
services: Services {
account_data: args.depend::<account_data::Service>("account_data"),
globals: args.depend::<globals::Service>("globals"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
users: args.depend::<users::Service>("users"),
},
db: Data {
@ -96,8 +98,13 @@ impl Service {
#[tracing::instrument(skip(self, last_state))]
#[allow(clippy::too_many_arguments)]
pub async fn update_membership(
&self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, invite_via: Option<Vec<OwnedServerName>>,
&self,
room_id: &RoomId,
user_id: &UserId,
membership_event: RoomMemberEventContent,
sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool,
) -> Result<()> {
let membership = membership_event.membership;
@ -138,7 +145,7 @@ impl Service {
}
match &membership {
MembershipState::Join => {
| MembershipState::Join => {
// Check if the user never joined this room
if !self.once_joined(user_id, room_id).await {
// Add the user ID to the join list then
@ -181,12 +188,21 @@ impl Service {
if let Ok(tag_event) = self
.services
.account_data
.get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
.get_room(
&predecessor.room_id,
user_id,
RoomAccountDataEventType::Tag,
)
.await
{
self.services
.account_data
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event)
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&tag_event,
)
.await
.ok();
};
@ -195,7 +211,10 @@ impl Service {
if let Ok(mut direct_event) = self
.services
.account_data
.get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
.get_global::<DirectEvent>(
user_id,
GlobalAccountDataEventType::Direct,
)
.await
{
let mut room_ids_updated = false;
@ -213,7 +232,8 @@ impl Service {
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event).expect("to json always works"),
&serde_json::to_value(&direct_event)
.expect("to json always works"),
)
.await?;
}
@ -223,7 +243,7 @@ impl Service {
self.mark_as_joined(user_id, room_id);
},
MembershipState::Invite => {
| MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
if self.services.users.user_is_ignored(sender, user_id).await {
return Ok(());
@ -232,10 +252,10 @@ impl Service {
self.mark_as_invited(user_id, room_id, last_state, invite_via)
.await;
},
MembershipState::Leave | MembershipState::Ban => {
| MembershipState::Leave | MembershipState::Ban => {
self.mark_as_left(user_id, room_id);
},
_ => {},
| _ => {},
}
if update_joined_count {
@ -246,7 +266,11 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool {
pub async fn appservice_in_room(
&self,
room_id: &RoomId,
appservice: &RegistrationInfo,
) -> bool {
if let Some(cached) = self
.appservice_in_room_cache
.read()
@ -347,7 +371,10 @@ impl Service {
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
pub fn room_servers<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &ServerName> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomserverids
@ -357,7 +384,11 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool {
pub async fn server_in_room<'a>(
&'a self,
server: &'a ServerName,
room_id: &'a RoomId,
) -> bool {
let key = (server, room_id);
self.db.serverroomids.qry(&key).await.is_ok()
}
@ -365,7 +396,10 @@ impl Service {
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream<Item = &RoomId> + Send + 'a {
pub fn server_rooms<'a>(
&'a self,
server: &'a ServerName,
) -> impl Stream<Item = &RoomId> + Send + 'a {
let prefix = (server, Interfix);
self.db
.serverroomids
@ -393,7 +427,9 @@ impl Service {
/// List the rooms common between two users
pub fn get_shared_rooms<'a>(
&'a self, user_a: &'a UserId, user_b: &'a UserId,
&'a self,
user_a: &'a UserId,
user_b: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
use conduwuit::utils::set;
@ -404,7 +440,10 @@ impl Service {
/// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn room_members<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_joined
@ -422,7 +461,10 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests
pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
self.room_members(room_id)
.ready_filter(|user| self.services.globals.user_is_local(user))
}
@ -430,7 +472,10 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn active_local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
self.local_users_in_room(room_id)
.filter(|user| self.services.users.is_active(user))
}
@ -447,7 +492,10 @@ impl Service {
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn room_useroncejoined<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuseroncejoinedids
@ -458,7 +506,10 @@ impl Service {
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn room_members_invited<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_invitecount
@ -485,7 +536,10 @@ impl Service {
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &RoomId> + Send + 'a {
pub fn rooms_joined<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
self.db
.userroomid_joined
.keys_raw_prefix(user_id)
@ -495,7 +549,10 @@ impl Service {
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_invited<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
pub fn rooms_invited<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
@ -510,30 +567,45 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
pub async fn invite_state(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.db
.userroomid_invitestate
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
val.deserialize_as().map_err(Into::into)
})
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
pub async fn left_state(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.db
.userroomid_leftstate
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
val.deserialize_as().map_err(Into::into)
})
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
pub fn rooms_left<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<Raw<AnySyncStateEvent>>>);
type Key<'a> = (&'a UserId, &'a RoomId);
@ -571,7 +643,11 @@ impl Service {
self.db.userroomid_leftstate.qry(&key).await.is_ok()
}
pub async fn user_membership(&self, user_id: &UserId, room_id: &RoomId) -> Option<MembershipState> {
pub async fn user_membership(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Option<MembershipState> {
let states = join4(
self.is_joined(user_id, room_id),
self.is_left(user_id, room_id),
@ -581,16 +657,19 @@ impl Service {
.await;
match states {
(true, ..) => Some(MembershipState::Join),
(_, true, ..) => Some(MembershipState::Leave),
(_, _, true, ..) => Some(MembershipState::Invite),
(false, false, false, true) => Some(MembershipState::Ban),
_ => None,
| (true, ..) => Some(MembershipState::Join),
| (_, true, ..) => Some(MembershipState::Leave),
| (_, _, true, ..) => Some(MembershipState::Invite),
| (false, false, false, true) => Some(MembershipState::Ban),
| _ => None,
}
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
pub fn servers_invite_via<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &ServerName> + Send + 'a {
type KeyVal<'a> = (Ignore, Vec<&'a ServerName>);
self.db
@ -711,7 +790,10 @@ impl Service {
}
pub async fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
&self,
user_id: &UserId,
room_id: &RoomId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) {
let roomuser_id = (room_id, user_id);

View file

@ -69,7 +69,8 @@ pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
let cache_capacity =
f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self {
stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(),
db: Data {
@ -85,17 +86,16 @@ impl crate::Service for Service {
fn memory_usage(&self, out: &mut dyn Write) -> Result {
let (cache_len, ents) = {
let cache = self.stateinfo_cache.lock().expect("locked");
let ents = cache
.iter()
.map(at!(1))
.flat_map(|vec| vec.iter())
.fold(HashMap::new(), |mut ents, ssi| {
let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
HashMap::new(),
|mut ents, ssi| {
for cs in &[&ssi.added, &ssi.removed, &ssi.full_state] {
ents.insert(Arc::as_ptr(cs), compressed_state_size(cs));
}
ents
});
},
);
(cache.len(), ents)
};
@ -117,7 +117,10 @@ impl crate::Service for Service {
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
pub async fn load_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> {
pub async fn load_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
if let Some(r) = self
.stateinfo_cache
.lock()
@ -143,12 +146,11 @@ impl Service {
Ok(stack)
}
async fn new_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> {
let StateDiff {
parent,
added,
removed,
} = self.get_statediff(shortstatehash).await?;
async fn new_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
let Some(parent) = parent else {
return Ok(vec![ShortStateInfo {
@ -180,9 +182,17 @@ impl Service {
Ok(stack)
}
pub fn compress_state_events<'a, I>(&'a self, state: I) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
pub fn compress_state_events<'a, I>(
&'a self,
state: I,
) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
where
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + ExactSizeIterator + Send + 'a,
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)>
+ Clone
+ Debug
+ ExactSizeIterator
+ Send
+ 'a,
{
let event_ids = state.clone().map(at!(1));
@ -195,10 +205,16 @@ impl Service {
.stream()
.map(at!(0))
.zip(short_event_ids)
.map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid))
.map(|(shortstatekey, shorteventid)| {
compress_state_event(*shortstatekey, shorteventid)
})
}
pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent {
pub async fn compress_state_event(
&self,
shortstatekey: ShortStateKey,
event_id: &EventId,
) -> CompressedStateEvent {
let shorteventid = self
.services
.short
@ -227,8 +243,11 @@ impl Service {
/// * `parent_states` - A stack with info on shortstatehash, full state,
/// added diff and removed diff for each parent layer
pub fn save_state_from_diff(
&self, shortstatehash: ShortStateHash, statediffnew: Arc<HashSet<CompressedStateEvent>>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>, diff_to_sibling: usize,
&self,
shortstatehash: ShortStateHash,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>,
diff_to_sibling: usize,
mut parent_states: ParentStatesVec,
) -> Result {
let statediffnew_len = statediffnew.len();
@ -274,14 +293,11 @@ impl Service {
if parent_states.is_empty() {
// There is no parent layer, create a new state
self.save_statediff(
shortstatehash,
&StateDiff {
parent: None,
added: statediffnew,
removed: statediffremoved,
},
);
self.save_statediff(shortstatehash, &StateDiff {
parent: None,
added: statediffnew,
removed: statediffremoved,
});
return Ok(());
};
@ -327,14 +343,11 @@ impl Service {
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
self.save_statediff(
shortstatehash,
&StateDiff {
parent: Some(parent.shortstatehash),
added: statediffnew,
removed: statediffremoved,
},
);
self.save_statediff(shortstatehash, &StateDiff {
parent: Some(parent.shortstatehash),
added: statediffnew,
removed: statediffremoved,
});
}
Ok(())
@ -344,7 +357,9 @@ impl Service {
/// room state
#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
pub async fn save_state(
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
&self,
room_id: &RoomId,
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
@ -353,7 +368,8 @@ impl Service {
.await
.ok();
let state_hash = utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
let state_hash =
utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
let (new_shortstatehash, already_existed) = self
.services
@ -374,22 +390,23 @@ impl Service {
ShortStateInfoVec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(HashSet::new()))
};
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(HashSet::new()))
};
if !already_existed {
self.save_state_from_diff(
@ -418,7 +435,9 @@ impl Service {
.shortstatehash_statediff
.aqry::<BUFSIZE, _>(&shortstatehash)
.await
.map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?;
.map_err(|e| {
err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
})?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
.ok()
@ -484,7 +503,10 @@ impl Service {
#[inline]
#[must_use]
fn compress_state_event(shortstatekey: ShortStateKey, shorteventid: ShortEventId) -> CompressedStateEvent {
fn compress_state_event(
shortstatekey: ShortStateKey,
shorteventid: ShortEventId,
) -> CompressedStateEvent {
const SIZE: usize = size_of::<CompressedStateEvent>();
let mut v = ArrayVec::<u8, SIZE>::new();
@ -497,7 +519,9 @@ fn compress_state_event(shortstatekey: ShortStateKey, shorteventid: ShortEventId
#[inline]
#[must_use]
pub fn parse_compressed_state_event(compressed_event: CompressedStateEvent) -> (ShortStateKey, ShortEventId) {
pub fn parse_compressed_state_event(
compressed_event: CompressedStateEvent,
) -> (ShortStateKey, ShortEventId) {
use utils::u64_from_u8;
let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<ShortStateKey>()]);

View file

@ -11,8 +11,8 @@ use conduwuit::{
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue,
EventId, OwnedUserId, RoomId, UserId,
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint,
CanonicalJsonValue, EventId, OwnedUserId, RoomId, UserId,
};
use serde_json::json;
@ -55,7 +55,9 @@ impl Service {
.timeline
.get_pdu_id(root_event_id)
.await
.map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?;
.map_err(|e| {
err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}")))
})?;
let root_pdu = self
.services
@ -79,8 +81,9 @@ impl Service {
.get("m.relations")
.and_then(|r| r.as_object())
.and_then(|r| r.get("m.thread"))
.and_then(|relations| serde_json::from_value::<BundledThread>(relations.clone().into()).ok())
{
.and_then(|relations| {
serde_json::from_value::<BundledThread>(relations.clone().into()).ok()
}) {
// Thread already existed
relations.count = relations.count.saturating_add(uint!(1));
relations.latest_event = pdu.to_message_like_event();
@ -129,7 +132,11 @@ impl Service {
}
pub async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, shorteventid: PduCount, _inc: &'a IncludeThreads,
&'a self,
user_id: &'a UserId,
room_id: &'a RoomId,
shorteventid: PduCount,
_inc: &'a IncludeThreads,
) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
@ -160,7 +167,11 @@ impl Service {
Ok(stream)
}
pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result {
pub(super) fn update_participants(
&self,
root_id: &RawPduId,
participants: &[OwnedUserId],
) -> Result {
let users = participants
.iter()
.map(|user| user.as_bytes())

View file

@ -13,7 +13,9 @@ use conduwuit::{
};
use database::{Database, Deserialized, Json, KeyVal, Map};
use futures::{future::select_ok, FutureExt, Stream, StreamExt};
use ruma::{api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use ruma::{
api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
};
use tokio::sync::Mutex;
use super::{PduId, RawPduId};
@ -54,15 +56,19 @@ impl Data {
}
}
pub(super) async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> {
pub(super) async fn last_timeline_count(
&self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.await
.entry(room_id.into())
{
hash_map::Entry::Occupied(o) => Ok(*o.get()),
hash_map::Entry::Vacant(v) => Ok(self
| hash_map::Entry::Occupied(o) => Ok(*o.get()),
| hash_map::Entry::Vacant(v) => Ok(self
.pdus_rev(sender_user, room_id, PduCount::max())
.await?
.next()
@ -93,7 +99,10 @@ impl Data {
}
/// Returns the json of a pdu.
pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
pub(super) async fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<CanonicalJsonObject> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.get(&pduid).await.deserialized()
@ -160,12 +169,19 @@ impl Data {
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
pub(super) async fn get_pdu_json_from_id(
&self,
pdu_id: &RawPduId,
) -> Result<CanonicalJsonObject> {
self.pduid_pdu.get(pdu_id).await.deserialized()
}
pub(super) async fn append_pdu(
&self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount,
&self,
pdu_id: &RawPduId,
pdu: &PduEvent,
json: &CanonicalJsonObject,
count: PduCount,
) {
debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal");
@ -179,7 +195,12 @@ impl Data {
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
}
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) {
pub(super) fn prepend_backfill_pdu(
&self,
pdu_id: &RawPduId,
event_id: &EventId,
json: &CanonicalJsonObject,
) {
self.pduid_pdu.raw_put(pdu_id, Json(json));
self.eventid_pduid.insert(event_id, pdu_id);
self.eventid_outlierpdu.remove(event_id);
@ -187,7 +208,10 @@ impl Data {
/// Removes a pdu and creates a new one with the same id.
pub(super) async fn replace_pdu(
&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent,
&self,
pdu_id: &RawPduId,
pdu_json: &CanonicalJsonObject,
_pdu: &PduEvent,
) -> Result {
if self.pduid_pdu.get(pdu_id).await.is_not_found() {
return Err!(Request(NotFound("PDU does not exist.")));
@ -202,7 +226,10 @@ impl Data {
/// happened before the event with id `until` in reverse-chronological
/// order.
pub(super) async fn pdus_rev<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
until: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let current = self
.count_to_id(room_id, until, Direction::Backward)
@ -219,7 +246,10 @@ impl Data {
}
pub(super) async fn pdus<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
from: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + Unpin + 'a> {
let current = self.count_to_id(room_id, from, Direction::Forward).await?;
let prefix = current.shortroomid();
@ -236,8 +266,8 @@ impl Data {
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> PdusIterItem {
let pdu_id: RawPduId = pdu_id.into();
let mut pdu =
serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON");
let mut pdu = serde_json::from_slice::<PduEvent>(pdu)
.expect("PduEvent in pduid_pdu database column is invalid JSON");
if Some(pdu.sender.borrow()) != user_id {
pdu.remove_transaction_id().log_err().ok();
@ -249,7 +279,10 @@ impl Data {
}
pub(super) fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
&self,
room_id: &RoomId,
notifies: Vec<OwnedUserId>,
highlights: Vec<OwnedUserId>,
) {
let _cork = self.db.cork();
@ -268,7 +301,12 @@ impl Data {
}
}
async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount, dir: Direction) -> Result<RawPduId> {
async fn count_to_id(
&self,
room_id: &RoomId,
shorteventid: PduCount,
dir: Direction,
) -> Result<RawPduId> {
let shortroomid: ShortRoomId = self
.services
.short

View file

@ -15,7 +15,9 @@ use conduwuit::{
validated, warn, Err, Error, Result, Server,
};
pub use conduwuit::{PduId, RawPduId};
use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use futures::{
future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
};
use ruma::{
api::federation,
canonical_json::to_canonical_value,
@ -32,8 +34,8 @@ use ruma::{
},
push::{Action, Ruleset, Tweak},
state_res::{self, Event, RoomVersion},
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName,
RoomId, RoomVersionId, ServerName, UserId,
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId,
OwnedServerName, RoomId, RoomVersionId, ServerName, UserId,
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
@ -116,7 +118,8 @@ impl crate::Service for Service {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
sending: args.depend::<sending::Service>("sending"),
@ -127,7 +130,8 @@ impl crate::Service for Service {
threads: args.depend::<rooms::threads::Service>("rooms::threads"),
search: args.depend::<rooms::search::Service>("rooms::search"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
event_handler: args
.depend::<rooms::event_handler::Service>("rooms::event_handler"),
},
db: Data::new(&args),
mutex_insert: RoomMutexMap::new(),
@ -185,12 +189,18 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> {
pub async fn last_timeline_count(
&self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
self.db.last_timeline_count(sender_user, room_id).await
}
/// Returns the `count` of this pdu's id.
pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> { self.db.get_pdu_count(event_id).await }
pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
self.db.get_pdu_count(event_id).await
}
// TODO Is this the same as the function above?
/*
@ -222,13 +232,18 @@ impl Service {
/// Returns the json of a pdu.
#[inline]
pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
pub async fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<CanonicalJsonObject> {
self.db.get_non_outlier_pdu_json(event_id).await
}
/// Returns the pdu's id.
#[inline]
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> { self.db.get_pdu_id(event_id).await }
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> {
self.db.get_pdu_id(event_id).await
}
/// Returns the pdu.
///
@ -241,19 +256,26 @@ impl Service {
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> { self.db.get_pdu(event_id).await }
pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
self.db.get_pdu(event_id).await
}
/// Checks if pdu exists
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future<Output = bool> + Send + 'a {
pub fn pdu_exists<'a>(
&'a self,
event_id: &'a EventId,
) -> impl Future<Output = bool> + Send + 'a {
self.db.pdu_exists(event_id)
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> { self.db.get_pdu_from_id(pdu_id).await }
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> {
self.db.get_pdu_from_id(pdu_id).await
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
@ -262,7 +284,12 @@ impl Service {
/// Removes a pdu and creates a new one with the same id.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
pub async fn replace_pdu(
&self,
pdu_id: &RawPduId,
pdu_json: &CanonicalJsonObject,
pdu: &PduEvent,
) -> Result<()> {
self.db.replace_pdu(pdu_id, pdu_json, pdu).await
}
@ -278,7 +305,8 @@ impl Service {
pdu: &PduEvent,
mut pdu_json: CanonicalJsonObject,
leaves: Vec<OwnedEventId>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result<RawPduId> {
// Coalesce database writes for the remainder of this scope.
let _cork = self.db.db.cork_and_flush();
@ -313,10 +341,16 @@ impl Service {
unsigned.insert(
"prev_content".to_owned(),
CanonicalJsonValue::Object(
utils::to_canonical_object(prev_state.content.clone()).map_err(|e| {
error!("Failed to convert prev_state to canonical JSON: {e}");
Error::bad_database("Failed to convert prev_state to canonical JSON.")
})?,
utils::to_canonical_object(prev_state.content.clone()).map_err(
|e| {
error!(
"Failed to convert prev_state to canonical JSON: {e}"
);
Error::bad_database(
"Failed to convert prev_state to canonical JSON.",
)
},
)?,
),
);
unsigned.insert(
@ -357,11 +391,7 @@ impl Service {
.reset_notification_counts(&pdu.sender, &pdu.room_id);
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap());
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid: count2,
}
.into();
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
// Insert pdu
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await;
@ -408,7 +438,10 @@ impl Service {
.account_data
.get_global(user, GlobalAccountDataEventType::PushRules)
.await
.map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global);
.map_or_else(
|_| Ruleset::server_default(user),
|ev: PushRulesEvent| ev.content.global,
);
let mut highlight = false;
let mut notify = false;
@ -420,11 +453,11 @@ impl Service {
.await
{
match action {
Action::Notify => notify = true,
Action::SetTweak(Tweak::Highlight(true)) => {
| Action::Notify => notify = true,
| Action::SetTweak(Tweak::Highlight(true)) => {
highlight = true;
},
_ => {},
| _ => {},
};
// Break early if both conditions are true
@ -457,12 +490,12 @@ impl Service {
.increment_notification_counts(&pdu.room_id, notifies, highlights);
match pdu.kind {
TimelineEventType::RoomRedaction => {
| TimelineEventType::RoomRedaction => {
use RoomVersionId::*;
let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?;
match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if self
.services
@ -474,7 +507,7 @@ impl Service {
}
}
},
_ => {
| _ => {
let content: RoomRedactionEventContent = pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
if self
@ -489,7 +522,7 @@ impl Service {
},
};
},
TimelineEventType::SpaceChild => {
| TimelineEventType::SpaceChild =>
if let Some(_state_key) = &pdu.state_key {
self.services
.spaces
@ -497,18 +530,18 @@ impl Service {
.lock()
.await
.remove(&pdu.room_id);
}
},
TimelineEventType::RoomMember => {
},
| TimelineEventType::RoomMember => {
if let Some(state_key) = &pdu.state_key {
// if the state_key fails
let target_user_id =
UserId::parse(state_key.clone()).expect("This state_key was previously validated");
let target_user_id = UserId::parse(state_key.clone())
.expect("This state_key was previously validated");
let content: RoomMemberEventContent = pdu.get_content()?;
let invite_state = match content.membership {
MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(),
_ => None,
| MembershipState::Invite =>
self.services.state.summary_stripped(pdu).await.into(),
| _ => None,
};
// Update our membership info, we do this here incase a user is invited
@ -527,7 +560,7 @@ impl Service {
.await?;
}
},
TimelineEventType::RoomMessage => {
| TimelineEventType::RoomMessage => {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
@ -539,7 +572,7 @@ impl Service {
}
}
},
_ => {},
| _ => {},
}
if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() {
@ -552,24 +585,23 @@ impl Service {
if let Ok(content) = pdu.get_content::<ExtractRelatesTo>() {
match content.relates_to {
Relation::Reply {
in_reply_to,
} => {
| Relation::Reply { in_reply_to } => {
// We need to do it again here, because replies don't have
// event_id as a top level field
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await {
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await
{
self.services
.pdu_metadata
.add_relation(count2, related_pducount);
}
},
Relation::Thread(thread) => {
| Relation::Thread(thread) => {
self.services
.threads
.add_to_thread(&thread.event_id, pdu)
.await?;
},
_ => {}, // TODO: Aggregate other types
| _ => {}, // TODO: Aggregate other types
}
}
@ -637,7 +669,8 @@ impl Service {
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) -> Result<(PduEvent, CanonicalJsonObject)> {
let PduBuilder {
event_type,
@ -707,7 +740,8 @@ impl Service {
unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value());
unsigned.insert(
"prev_sender".to_owned(),
serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"),
serde_json::to_value(&prev_pdu.sender)
.expect("UserId::to_value always works"),
);
unsigned.insert(
"replaces_state".to_owned(),
@ -744,9 +778,7 @@ impl Service {
} else {
Some(to_raw_value(&unsigned).expect("to_raw_value always works"))
},
hashes: EventHash {
sha256: "aaa".to_owned(),
},
hashes: EventHash { sha256: "aaa".to_owned() },
signatures: None,
};
@ -769,13 +801,14 @@ impl Service {
}
// Hash and sign
let mut pdu_json = utils::to_canonical_object(&pdu)
.map_err(|e| err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}")))))?;
let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| {
err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}"))))
})?;
// room v3 and above removed the "event_id" field from remote PDU format
match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => {
| RoomVersionId::V1 | RoomVersionId::V2 => {},
| _ => {
pdu_json.remove("event_id");
},
};
@ -783,7 +816,8 @@ impl Service {
// Add origin because synapse likes that (and it's required in the spec)
pdu_json.insert(
"origin".to_owned(),
to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
to_canonical_value(self.services.globals.server_name())
.expect("server name is a valid CanonicalJsonValue"),
);
if let Err(e) = self
@ -792,17 +826,18 @@ impl Service {
.hash_and_sign_event(&mut pdu_json, &room_version_id)
{
return match e {
Error::Signatures(ruma::signatures::Error::PduSize) => {
| Error::Signatures(ruma::signatures::Error::PduSize) => {
Err!(Request(TooLarge("Message/PDU is too long (exceeds 65535 bytes)")))
},
_ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
| _ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
};
}
// Generate event id
pdu.event_id = EventId::parse_arc(format!(
"${}",
ruma::signatures::reference_hash(&pdu_json, &room_version_id).expect("ruma can calculate reference hashes")
ruma::signatures::reference_hash(&pdu_json, &room_version_id)
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are valid event ids");
@ -830,7 +865,8 @@ impl Service {
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result<Arc<EventId>> {
let (pdu, pdu_json) = self
.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
@ -844,7 +880,7 @@ impl Service {
if pdu.kind == TimelineEventType::RoomRedaction {
use RoomVersionId::*;
match self.services.state.get_room_version(&pdu.room_id).await? {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if !self
.services
@ -856,7 +892,7 @@ impl Service {
}
};
},
_ => {
| _ => {
let content: RoomRedactionEventContent = pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
if !self
@ -937,7 +973,8 @@ impl Service {
new_room_leaves: Vec<OwnedEventId>,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
soft_fail: bool,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result<Option<RawPduId>> {
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
@ -971,7 +1008,9 @@ impl Service {
/// items.
#[inline]
pub fn all_pdus<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId,
&'a self,
user_id: &'a UserId,
room_id: &'a RoomId,
) -> impl Stream<Item = PdusIterItem> + Send + Unpin + 'a {
self.pdus(Some(user_id), room_id, None)
.map_ok(|stream| stream.map(Ok))
@ -983,7 +1022,10 @@ impl Service {
/// Reverse iteration starting at from.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn pdus_rev<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: Option<PduCount>,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
until: Option<PduCount>,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
self.db
.pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max))
@ -993,7 +1035,10 @@ impl Service {
/// Forward iteration starting at from.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn pdus<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: Option<PduCount>,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
from: Option<PduCount>,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
self.db
.pdus(user_id, room_id, from.unwrap_or_else(PduCount::min))
@ -1002,17 +1047,21 @@ impl Service {
/// Replace a PDU with the redacted form.
#[tracing::instrument(skip(self, reason))]
pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result {
pub async fn redact_pdu(
&self,
event_id: &EventId,
reason: &PduEvent,
shortroomid: ShortRoomId,
) -> Result {
// TODO: Don't reserialize, keep original json
let Ok(pdu_id) = self.get_pdu_id(event_id).await else {
// If event does not exist, just noop
return Ok(());
};
let mut pdu = self
.get_pdu_from_id(&pdu_id)
.await
.map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?;
let mut pdu = self.get_pdu_from_id(&pdu_id).await.map_err(|e| {
err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU.")))
})?;
if let Ok(content) = pdu.get_content::<ExtractBody>() {
if let Some(body) = content.body {
@ -1026,8 +1075,9 @@ impl Service {
pdu.redact(&room_version_id, reason)?;
let obj = utils::to_canonical_object(&pdu)
.map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?;
let obj = utils::to_canonical_object(&pdu).map_err(|e| {
err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON")))
})?;
self.replace_pdu(&pdu_id, &obj, &pdu).await
}
@ -1069,7 +1119,9 @@ impl Service {
.unwrap_or_default();
let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| {
if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) {
if level > &power_levels.users_default
&& !self.services.globals.user_is_local(user_id)
{
Some(user_id.server_name())
} else {
None
@ -1124,7 +1176,7 @@ impl Service {
)
.await;
match response {
Ok(response) => {
| Ok(response) => {
for pdu in response.pdus {
if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await {
debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}");
@ -1132,7 +1184,7 @@ impl Service {
}
return Ok(());
},
Err(e) => {
| Err(e) => {
warn!("{backfill_server} failed to provide backfill for room {room_id}: {e}");
},
}
@ -1144,7 +1196,8 @@ impl Service {
#[tracing::instrument(skip(self, pdu))]
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?;
let (event_id, value, room_id) =
self.services.event_handler.parse_incoming_pdu(&pdu).await?;
// Lock so we cannot backfill the same pdu twice at the same time
let mutex_lock = self
@ -1210,10 +1263,10 @@ impl Service {
#[tracing::instrument(skip_all, level = "debug")]
async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Result<()> {
match pdu.event_type() {
TimelineEventType::RoomEncryption => {
| TimelineEventType::RoomEncryption => {
return Err!(Request(Forbidden(error!("Encryption not supported in admins room."))));
},
TimelineEventType::RoomMember => {
| TimelineEventType::RoomMember => {
let target = pdu
.state_key()
.filter(|v| v.starts_with('@'))
@ -1223,9 +1276,11 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
let content: RoomMemberEventContent = pdu.get_content()?;
match content.membership {
MembershipState::Leave => {
| MembershipState::Leave => {
if target == server_user {
return Err!(Request(Forbidden(error!("Server user cannot leave the admins room."))));
return Err!(Request(Forbidden(error!(
"Server user cannot leave the admins room."
))));
}
let count = self
@ -1239,13 +1294,17 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
.await;
if count < 2 {
return Err!(Request(Forbidden(error!("Last admin cannot leave the admins room."))));
return Err!(Request(Forbidden(error!(
"Last admin cannot leave the admins room."
))));
}
},
MembershipState::Ban if pdu.state_key().is_some() => {
| MembershipState::Ban if pdu.state_key().is_some() => {
if target == server_user {
return Err!(Request(Forbidden(error!("Server cannot be banned from admins room."))));
return Err!(Request(Forbidden(error!(
"Server cannot be banned from admins room."
))));
}
let count = self
@ -1259,13 +1318,15 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
.await;
if count < 2 {
return Err!(Request(Forbidden(error!("Last admin cannot be banned from admins room."))));
return Err!(Request(Forbidden(error!(
"Last admin cannot be banned from admins room."
))));
}
},
_ => {},
| _ => {},
};
},
_ => {},
| _ => {},
};
Ok(())

View file

@ -52,7 +52,12 @@ impl crate::Service for Service {
impl Service {
/// Sets a user as typing until the timeout timestamp is reached or
/// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
pub async fn typing_add(
&self,
user_id: &UserId,
room_id: &RoomId,
timeout: u64,
) -> Result<()> {
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
// update clients
self.typing
@ -177,15 +182,15 @@ impl Service {
/// Returns a new typing EDU.
pub async fn typings_all(
&self, room_id: &RoomId, sender_user: &UserId,
&self,
room_id: &RoomId,
sender_user: &UserId,
) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> {
let room_typing_indicators = self.typing.read().await.get(room_id).cloned();
let Some(typing_indicators) = room_typing_indicators else {
return Ok(SyncEphemeralRoomEvent {
content: ruma::events::typing::TypingEventContent {
user_ids: Vec::new(),
},
content: ruma::events::typing::TypingEventContent { user_ids: Vec::new() },
});
};
@ -204,13 +209,16 @@ impl Service {
.await;
Ok(SyncEphemeralRoomEvent {
content: ruma::events::typing::TypingEventContent {
user_ids,
},
content: ruma::events::typing::TypingEventContent { user_ids },
})
}
async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
async fn federation_send(
&self,
room_id: &RoomId,
user_id: &UserId,
typing: bool,
) -> Result<()> {
debug_assert!(
self.services.globals.user_is_local(user_id),
"tried to broadcast typing status of remote user",

View file

@ -92,7 +92,12 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -
}
#[implement(Service)]
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) {
pub async fn associate_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
shortstatehash: ShortStateHash,
) {
let shortroomid = self
.services
.short
@ -108,7 +113,11 @@ pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64,
}
#[implement(Service)]
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<ShortStateHash> {
pub async fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<ShortStateHash> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let key: &[u64] = &[shortroomid, token];

View file

@ -3,14 +3,18 @@ use std::{fmt::Debug, mem};
use bytes::BytesMut;
use conduwuit::{debug_error, err, trace, utils, warn, Err, Result};
use reqwest::Client;
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use ruma::api::{
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
};
/// Sends a request to an appservice
///
/// Only returns Ok(None) if there is no url specified in the appservice
/// registration file
pub(crate) async fn send_request<T>(
client: &Client, registration: Registration, request: T,
client: &Client,
registration: Registration,
request: T,
) -> Result<Option<T::IncomingResponse>>
where
T: OutgoingRequest + Debug + Send,
@ -25,17 +29,17 @@ where
let hs_token = registration.hs_token.as_str();
let mut http_request = request
.try_into_http_request::<BytesMut>(&dest, SendAccessToken::IfRequired(hs_token), &VERSIONS)
.try_into_http_request::<BytesMut>(
&dest,
SendAccessToken::IfRequired(hs_token),
&VERSIONS,
)
.map_err(|e| err!(BadServerResponse(warn!("Failed to find destination {dest}: {e}"))))?
.map(BytesMut::freeze);
let mut parts = http_request.uri().clone().into_parts();
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
let symbol = if old_path_and_query.contains('?') {
"&"
} else {
"?"
};
let symbol = if old_path_and_query.contains('?') { "&" } else { "?" };
parts.path_and_query = Some(
(old_path_and_query + symbol + "access_token=" + hs_token)

View file

@ -43,7 +43,9 @@ impl Data {
}
}
pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); }
pub(super) fn delete_active_request(&self, key: &[u8]) {
self.servercurrentevent_data.remove(key);
}
pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) {
let prefix = destination.get_prefix();
@ -76,11 +78,7 @@ impl Data {
events
.filter(|(key, _)| !key.is_empty())
.for_each(|(key, val)| {
let val = if let SendingEvent::Edu(val) = &val {
&**val
} else {
&[]
};
let val = if let SendingEvent::Edu(val) = &val { &**val } else { &[] };
self.servercurrentevent_data.insert(key, val);
self.servernameevent_data.remove(key);
@ -93,21 +91,26 @@ impl Data {
.raw_stream()
.ignore_err()
.map(|(key, val)| {
let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
let (dest, event) =
parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
(key.to_vec(), event, dest)
})
}
#[inline]
pub fn active_requests_for(&self, destination: &Destination) -> impl Stream<Item = SendingItem> + Send + '_ {
pub fn active_requests_for(
&self,
destination: &Destination,
) -> impl Stream<Item = SendingItem> + Send + '_ {
let prefix = destination.get_prefix();
self.servercurrentevent_data
.raw_stream_from(&prefix)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(|(key, val)| {
let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
let (_, event) =
parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
(key.to_vec(), event)
})
@ -150,14 +153,18 @@ impl Data {
keys
}
pub fn queued_requests(&self, destination: &Destination) -> impl Stream<Item = QueueItem> + Send + '_ {
pub fn queued_requests(
&self,
destination: &Destination,
) -> impl Stream<Item = QueueItem> + Send + '_ {
let prefix = destination.get_prefix();
self.servernameevent_data
.raw_stream_from(&prefix)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(|(key, val)| {
let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
let (_, event) =
parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
(key.to_vec(), event)
})
@ -186,8 +193,9 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
let server = utils::string_from_bytes(server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction")
})?;
(
Destination::Appservice(server),
@ -203,8 +211,8 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id =
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let user_id = UserId::parse(user_string)
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts
.next()
@ -233,14 +241,14 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
let server = utils::string_from_bytes(server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction")
})?;
(
Destination::Normal(
ServerName::parse(server)
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
),
Destination::Normal(ServerName::parse(server).map_err(|_| {
Error::bad_database("Invalid server string in server_currenttransaction")
})?),
if value.is_empty() {
SendingEvent::Pdu(event.into())
} else {

View file

@ -14,7 +14,7 @@ pub enum Destination {
#[must_use]
pub(super) fn get_prefix(&self) -> Vec<u8> {
match self {
Self::Normal(server) => {
| Self::Normal(server) => {
let len = server.as_bytes().len().saturating_add(1);
let mut p = Vec::with_capacity(len);
@ -22,7 +22,7 @@ pub(super) fn get_prefix(&self) -> Vec<u8> {
p.push(0xFF);
p
},
Self::Appservice(server) => {
| Self::Appservice(server) => {
let sigil = b"+";
let len = sigil.len().saturating_add(server.len()).saturating_add(1);
@ -32,7 +32,7 @@ pub(super) fn get_prefix(&self) -> Vec<u8> {
p.push(0xFF);
p
},
Self::Push(user, pushkey) => {
| Self::Push(user, pushkey) => {
let sigil = b"$";
let len = sigil
.len()

View file

@ -25,8 +25,8 @@ pub use self::{
sender::{EDU_LIMIT, PDU_LIMIT},
};
use crate::{
account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId, server_keys, users,
Dep,
account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId,
server_keys, users, Dep,
};
pub struct Service {
@ -156,18 +156,16 @@ impl Service {
{
let _cork = self.db.db.cork();
let requests = servers
.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned())))
.map(|server| {
(Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
})
.collect::<Vec<_>>()
.await;
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
self.dispatch(Msg {
dest,
event,
queue_id,
})?;
self.dispatch(Msg { dest, event, queue_id })?;
}
Ok(())
@ -204,18 +202,16 @@ impl Service {
{
let _cork = self.db.db.cork();
let requests = servers
.map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone())))
.map(|server| {
(Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))
})
.collect::<Vec<_>>()
.await;
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
self.dispatch(Msg {
dest,
event,
queue_id,
})?;
self.dispatch(Msg { dest, event, queue_id })?;
}
Ok(())
@ -253,7 +249,11 @@ impl Service {
/// Sends a request to a federation server
#[tracing::instrument(skip_all, name = "request")]
pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
pub async fn send_federation_request<T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
@ -263,7 +263,11 @@ impl Service {
/// Like send_federation_request() but with a very large timeout
#[tracing::instrument(skip_all, name = "synapse")]
pub async fn send_synapse_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
pub async fn send_synapse_request<T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
@ -276,7 +280,9 @@ impl Service {
/// Only returns None if there is no url specified in the appservice
/// registration file
pub async fn send_appservice_request<T>(
&self, registration: Registration, request: T,
&self,
registration: Registration,
request: T,
) -> Result<Option<T::IncomingResponse>>
where
T: OutgoingRequest + Debug + Send,
@ -291,24 +297,30 @@ impl Service {
/// key
#[tracing::instrument(skip(self), level = "debug")]
pub async fn cleanup_events(
&self, appservice_id: Option<&str>, user_id: Option<&UserId>, push_key: Option<&str>,
&self,
appservice_id: Option<&str>,
user_id: Option<&UserId>,
push_key: Option<&str>,
) -> Result {
match (appservice_id, user_id, push_key) {
(None, Some(user_id), Some(push_key)) => {
| (None, Some(user_id), Some(push_key)) => {
self.db
.delete_all_requests_for(&Destination::Push(user_id.to_owned(), push_key.to_owned()))
.delete_all_requests_for(&Destination::Push(
user_id.to_owned(),
push_key.to_owned(),
))
.await;
Ok(())
},
(Some(appservice_id), None, None) => {
| (Some(appservice_id), None, None) => {
self.db
.delete_all_requests_for(&Destination::Appservice(appservice_id.to_owned()))
.await;
Ok(())
},
_ => {
| _ => {
debug_warn!("cleanup_events called with too many or too few arguments");
Ok(())
},

View file

@ -2,16 +2,16 @@ use std::mem;
use bytes::Bytes;
use conduwuit::{
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error,
Result,
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace,
utils::string::EMPTY, Err, Error, Result,
};
use http::{header::AUTHORIZATION, HeaderValue};
use ipaddress::IPAddress;
use reqwest::{Client, Method, Request, Response, Url};
use ruma::{
api::{
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
SendAccessToken,
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion,
OutgoingRequest, SendAccessToken,
},
serde::Base64,
server_util::authorization::XMatrix,
@ -25,7 +25,12 @@ use crate::{
impl super::Service {
#[tracing::instrument(skip_all, level = "debug")]
pub async fn send<T>(&self, client: &Client, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
pub async fn send<T>(
&self,
client: &Client,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
{
@ -39,7 +44,9 @@ impl super::Service {
.forbidden_remote_server_names
.contains(dest)
{
return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
return Err!(Request(Forbidden(debug_warn!(
"Federation with {dest} is not allowed."
))));
}
let actual = self.services.resolver.get_actual_dest(dest).await?;
@ -49,7 +56,11 @@ impl super::Service {
}
async fn execute<T>(
&self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client,
&self,
dest: &ServerName,
actual: &ActualDest,
request: Request,
client: &Client,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
@ -59,8 +70,18 @@ impl super::Service {
debug!(?method, ?url, "Sending request");
match client.execute(request).await {
Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await,
Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
| Ok(response) =>
handle_response::<T>(
&self.services.resolver,
dest,
actual,
&method,
&url,
response,
)
.await,
| Err(error) =>
Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
}
}
@ -86,7 +107,11 @@ impl super::Service {
}
async fn handle_response<T>(
resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url,
resolver: &resolver::Service,
dest: &ServerName,
actual: &ActualDest,
method: &Method,
url: &Url,
response: Response,
) -> Result<T::IncomingResponse>
where
@ -96,21 +121,22 @@ where
let result = T::IncomingResponse::try_from_http_response(response);
if result.is_ok() && !actual.cached {
resolver.set_cached_destination(
dest.to_owned(),
CachedDest {
dest: actual.dest.clone(),
host: actual.host.clone(),
expire: CachedDest::default_expire(),
},
);
resolver.set_cached_destination(dest.to_owned(), CachedDest {
dest: actual.dest.clone(),
host: actual.host.clone(),
expire: CachedDest::default_expire(),
});
}
result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
}
async fn into_http_response(
dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response,
dest: &ServerName,
actual: &ActualDest,
method: &Method,
url: &Url,
mut response: Response,
) -> Result<http::Response<Bytes>> {
let status = response.status();
trace!(
@ -146,13 +172,21 @@ async fn into_http_response(
debug!("Got {status:?} for {method} {url}");
if !status.is_success() {
return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response)));
return Err(Error::Federation(
dest.to_owned(),
RumaError::from_http_response(http_response),
));
}
Ok(http_response)
}
fn handle_error(actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error) -> Result {
fn handle_error(
actual: &ActualDest,
method: &Method,
url: &Url,
mut e: reqwest::Error,
) -> Result {
if e.is_timeout() || e.is_connect() {
e = e.without_url();
debug_warn!("{e:?}");
@ -186,7 +220,8 @@ fn sign_request(&self, http_request: &mut http::Request<Vec<u8>>, dest: &ServerN
.expect("http::Request missing path_and_query");
let mut req: Object = if !body.is_empty() {
let content: CanonicalJsonValue = serde_json::from_slice(body).expect("failed to serialize body");
let content: CanonicalJsonValue =
serde_json::from_slice(body).expect("failed to serialize body");
let authorization: [Member; 5] = [
("content".into(), content),

View file

@ -24,15 +24,19 @@ use ruma::{
appservice::event::push_events::v1::Edu as RumaEdu,
federation::transactions::{
edu::{
DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap,
DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent,
ReceiptData, ReceiptMap,
},
send_transaction_message,
},
},
device_id,
events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType},
push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
RoomVersionId, ServerName, UInt,
events::{
push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent,
GlobalAccountDataEventType,
},
push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName,
OwnedUserId, RoomId, RoomVersionId, ServerName, UInt,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
@ -86,11 +90,14 @@ impl Service {
}
async fn handle_response<'a>(
&'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
&'a self,
response: SendingResult,
futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus,
) {
match response {
Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
| Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
| Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
};
}
@ -98,16 +105,22 @@ impl Service {
debug!(dest = ?dest, "{e:?}");
statuses.entry(dest).and_modify(|e| {
*e = match e {
TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()),
TransactionStatus::Retrying(ref n) => TransactionStatus::Failed(n.saturating_add(1), Instant::now()),
TransactionStatus::Failed(..) => panic!("Request that was not even running failed?!"),
| TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()),
| TransactionStatus::Retrying(ref n) =>
TransactionStatus::Failed(n.saturating_add(1), Instant::now()),
| TransactionStatus::Failed(..) => {
panic!("Request that was not even running failed?!")
},
}
});
}
#[allow(clippy::needless_pass_by_ref_mut)]
async fn handle_response_ok<'a>(
&'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
&'a self,
dest: &Destination,
futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus,
) {
let _cork = self.db.db.cork();
self.db.delete_all_active_requests_for(dest).await;
@ -133,7 +146,10 @@ impl Service {
#[allow(clippy::needless_pass_by_ref_mut)]
async fn handle_request<'a>(
&'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
&'a self,
msg: Msg,
futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus,
) {
let iv = vec![(msg.queue_id, msg.event)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
@ -168,8 +184,13 @@ impl Service {
}
#[allow(clippy::needless_pass_by_ref_mut)]
async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
async fn initial_requests<'a>(
&'a self,
futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus,
) {
let keep =
usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new();
let mut active = self.db.active_requests().boxed();
@ -240,7 +261,11 @@ impl Service {
}
#[tracing::instrument(skip_all, level = "debug")]
fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> {
fn select_events_current(
&self,
dest: Destination,
statuses: &mut CurTransactionStatus,
) -> Result<(bool, bool)> {
let (mut allow, mut retry) = (true, false);
statuses
.entry(dest.clone()) // TODO: can we avoid cloning?
@ -278,7 +303,8 @@ impl Service {
let events_len = AtomicUsize::default();
let max_edu_count = AtomicU64::new(since);
let device_changes = self.select_edus_device_changes(server_name, batch, &max_edu_count, &events_len);
let device_changes =
self.select_edus_device_changes(server_name, batch, &max_edu_count, &events_len);
let receipts: OptionFuture<_> = self
.server
@ -305,7 +331,11 @@ impl Service {
/// Look for presence
async fn select_edus_device_changes(
&self, server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64, events_len: &AtomicUsize,
&self,
server_name: &ServerName,
since: (u64, u64),
max_edu_count: &AtomicU64,
events_len: &AtomicUsize,
) -> Vec<Vec<u8>> {
let mut events = Vec::new();
let server_rooms = self.services.state_cache.server_rooms(server_name);
@ -342,7 +372,8 @@ impl Service {
keys: None,
});
let edu = serde_json::to_vec(&edu).expect("failed to serialize device list update to JSON");
let edu = serde_json::to_vec(&edu)
.expect("failed to serialize device list update to JSON");
events.push(edu);
if events_len.fetch_add(1, Ordering::Relaxed) >= SELECT_EDU_LIMIT - 1 {
@ -356,7 +387,10 @@ impl Service {
/// Look for read receipts in this room
async fn select_edus_receipts(
&self, server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64,
&self,
server_name: &ServerName,
since: (u64, u64),
max_edu_count: &AtomicU64,
) -> Option<Vec<u8>> {
let server_rooms = self.services.state_cache.server_rooms(server_name);
@ -377,19 +411,21 @@ impl Service {
return None;
}
let receipt_content = Edu::Receipt(ReceiptContent {
receipts,
});
let receipt_content = Edu::Receipt(ReceiptContent { receipts });
let receipt_content =
serde_json::to_vec(&receipt_content).expect("Failed to serialize Receipt EDU to JSON vec");
let receipt_content = serde_json::to_vec(&receipt_content)
.expect("Failed to serialize Receipt EDU to JSON vec");
Some(receipt_content)
}
/// Look for read receipts in this room
async fn select_edus_receipts_room(
&self, room_id: &RoomId, since: (u64, u64), max_edu_count: &AtomicU64, num: &mut usize,
&self,
room_id: &RoomId,
since: (u64, u64),
max_edu_count: &AtomicU64,
num: &mut usize,
) -> ReceiptMap {
let receipts = self
.services
@ -444,14 +480,15 @@ impl Service {
}
}
ReceiptMap {
read,
}
ReceiptMap { read }
}
/// Look for presence
async fn select_edus_presence(
&self, server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64,
&self,
server_name: &ServerName,
since: (u64, u64),
max_edu_count: &AtomicU64,
) -> Option<Vec<u8>> {
let presence_since = self.services.presence.presence_since(since.0);
@ -511,7 +548,8 @@ impl Service {
push: presence_updates.into_values().collect(),
});
let presence_content = serde_json::to_vec(&presence_content).expect("failed to serialize Presence EDU to JSON");
let presence_content = serde_json::to_vec(&presence_content)
.expect("failed to serialize Presence EDU to JSON");
Some(presence_content)
}
@ -519,21 +557,28 @@ impl Service {
async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
//debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
Destination::Normal(ref server) => self.send_events_dest_normal(&dest, server, events).await,
Destination::Appservice(ref id) => self.send_events_dest_appservice(&dest, id, events).await,
Destination::Push(ref userid, ref pushkey) => {
| Destination::Normal(ref server) =>
self.send_events_dest_normal(&dest, server, events).await,
| Destination::Appservice(ref id) =>
self.send_events_dest_appservice(&dest, id, events).await,
| Destination::Push(ref userid, ref pushkey) =>
self.send_events_dest_push(&dest, userid, pushkey, events)
.await
},
.await,
}
}
#[tracing::instrument(skip(self, dest, events), name = "appservice")]
async fn send_events_dest_appservice(
&self, dest: &Destination, id: &str, events: Vec<SendingEvent>,
&self,
dest: &Destination,
id: &str,
events: Vec<SendingEvent>,
) -> SendingResult {
let Some(appservice) = self.services.appservice.get_registration(id).await else {
return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration")))));
return Err((
dest.clone(),
err!(Database(warn!(?id, "Missing appservice registration"))),
));
};
let mut pdu_jsons = Vec::with_capacity(
@ -550,12 +595,12 @@ impl Service {
);
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
| SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
pdu_jsons.push(pdu.to_room_event());
}
},
SendingEvent::Edu(edu) => {
| SendingEvent::Edu(edu) => {
if appservice
.receive_ephemeral
.is_some_and(|receive_edus| receive_edus)
@ -565,14 +610,14 @@ impl Service {
}
}
},
SendingEvent::Flush => {}, // flush only; no new content
| SendingEvent::Flush => {}, // flush only; no new content
}
}
let txn_hash = calculate_hash(events.iter().filter_map(|e| match e {
SendingEvent::Edu(b) => Some(&**b),
SendingEvent::Pdu(b) => Some(b.as_ref()),
SendingEvent::Flush => None,
| SendingEvent::Edu(b) => Some(&**b),
| SendingEvent::Pdu(b) => Some(b.as_ref()),
| SendingEvent::Flush => None,
}));
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash);
@ -592,28 +637,35 @@ impl Service {
)
.await
{
Ok(_) => Ok(dest.clone()),
Err(e) => Err((dest.clone(), e)),
| Ok(_) => Ok(dest.clone()),
| Err(e) => Err((dest.clone(), e)),
}
}
#[tracing::instrument(skip(self, dest, events), name = "push")]
async fn send_events_dest_push(
&self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
&self,
dest: &Destination,
userid: &OwnedUserId,
pushkey: &str,
events: Vec<SendingEvent>,
) -> SendingResult {
let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else {
return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher")))));
return Err((
dest.clone(),
err!(Database(error!(?userid, ?pushkey, "Missing pusher"))),
));
};
let mut pdus = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
| SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
pdus.push(pdu);
}
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
| SendingEvent::Edu(_) | SendingEvent::Flush => {
// Push gateways don't need EDUs (?) and flush only;
// no new content
},
@ -657,7 +709,10 @@ impl Service {
#[tracing::instrument(skip(self, dest, events), name = "", level = "debug")]
async fn send_events_dest_normal(
&self, dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>,
&self,
dest: &Destination,
server: &OwnedServerName,
events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity(
events
@ -675,17 +730,16 @@ impl Service {
for event in &events {
match event {
// TODO: check room version and remove event_id if needed
SendingEvent::Pdu(pdu_id) => {
| SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await {
pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await);
}
},
SendingEvent::Edu(edu) => {
| SendingEvent::Edu(edu) =>
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
}
},
SendingEvent::Flush => {}, // flush only; no new content
},
| SendingEvent::Flush => {}, // flush only; no new content
}
}
@ -693,9 +747,9 @@ impl Service {
// transaction");
let txn_hash = calculate_hash(events.iter().filter_map(|e| match e {
SendingEvent::Edu(b) => Some(&**b),
SendingEvent::Pdu(b) => Some(b.as_ref()),
SendingEvent::Flush => None,
| SendingEvent::Edu(b) => Some(&**b),
| SendingEvent::Pdu(b) => Some(b.as_ref()),
| SendingEvent::Flush => None,
}));
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash);
@ -725,7 +779,10 @@ impl Service {
}
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
pub async fn convert_to_outgoing_federation_event(
&self,
mut pdu_json: CanonicalJsonObject,
) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
@ -739,11 +796,11 @@ impl Service {
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match self.services.state.get_room_version(&room_id).await {
Ok(room_version_id) => match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => _ = pdu_json.remove("event_id"),
| Ok(room_version_id) => match room_version_id {
| RoomVersionId::V1 | RoomVersionId::V2 => {},
| _ => _ = pdu_json.remove("event_id"),
},
Err(_) => _ = pdu_json.remove("event_id"),
| Err(_) => _ = pdu_json.remove("event_id"),
}
} else {
pdu_json.remove("event_id");

View file

@ -4,11 +4,13 @@ use std::{
time::Duration,
};
use conduwuit::{debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn};
use conduwuit::{
debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn,
};
use futures::{stream::FuturesUnordered, StreamExt};
use ruma::{
api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName,
OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject,
OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::time::{timeout_at, Instant};
@ -79,7 +81,9 @@ where
return;
}
warn!("missing {missing_keys} keys for {missing_servers} servers from all notaries first");
warn!(
"missing {missing_keys} keys for {missing_servers} servers from all notaries first"
);
}
if !notary_only {
@ -101,13 +105,15 @@ where
return;
}
debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries.");
debug_warn!(
"still missing {missing_keys} keys for {missing_servers} servers from all notaries."
);
}
if missing_keys > 0 {
warn!(
"did not obtain {missing_keys} keys for {missing_servers} servers out of {requested_keys} total keys for \
{requested_servers} total servers."
"did not obtain {missing_keys} keys for {missing_servers} servers out of \
{requested_keys} total keys for {requested_servers} total servers."
);
}
@ -162,12 +168,15 @@ where
#[implement(super::Service)]
async fn acquire_origin(
&self, origin: OwnedServerName, mut key_ids: Vec<OwnedServerSigningKeyId>, timeout: Instant,
&self,
origin: OwnedServerName,
mut key_ids: Vec<OwnedServerSigningKeyId>,
timeout: Instant,
) -> (OwnedServerName, Vec<OwnedServerSigningKeyId>) {
match timeout_at(timeout, self.server_request(&origin)).await {
Err(e) => debug_warn!(?origin, "timed out: {e}"),
Ok(Err(e)) => debug_error!(?origin, "{e}"),
Ok(Ok(server_keys)) => {
| Err(e) => debug_warn!(?origin, "timed out: {e}"),
| Ok(Err(e)) => debug_error!(?origin, "{e}"),
| Ok(Ok(server_keys)) => {
trace!(
%origin,
?key_ids,
@ -192,19 +201,21 @@ where
for notary in self.services.globals.trusted_servers() {
let missing_keys = keys_count(&missing);
let missing_servers = missing.len();
debug!("Asking notary {notary} for {missing_keys} missing keys from {missing_servers} servers");
debug!(
"Asking notary {notary} for {missing_keys} missing keys from {missing_servers} \
servers"
);
let batch = missing
.iter()
.map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));
match self.batch_notary_request(notary, batch).await {
Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
Ok(results) => {
| Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
| Ok(results) =>
for server_keys in results {
self.acquire_notary_result(&mut missing, server_keys).await;
}
},
},
}
}
@ -224,4 +235,6 @@ async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSi
}
}
fn keys_count(batch: &Batch) -> usize { batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count() }
fn keys_count(batch: &Batch) -> usize {
batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count()
}

View file

@ -1,17 +1,25 @@
use std::borrow::Borrow;
use conduwuit::{implement, Err, Result};
use ruma::{api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName, ServerSigningKeyId};
use ruma::{
api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName,
ServerSigningKeyId,
};
use super::{extract_key, PubKeyMap, PubKeys};
#[implement(super::Service)]
pub async fn get_event_keys(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> Result<PubKeyMap> {
pub async fn get_event_keys(
&self,
object: &CanonicalJsonObject,
version: &RoomVersionId,
) -> Result<PubKeyMap> {
use ruma::signatures::required_keys;
let required = match required_keys(object, version) {
Ok(required) => required,
Err(e) => return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")),
| Ok(required) => required,
| Err(e) =>
return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")),
};
let batch = required
@ -52,7 +60,11 @@ where
}
#[implement(super::Service)]
pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
pub async fn get_verify_key(
&self,
origin: &ServerName,
key_id: &ServerSigningKeyId,
) -> Result<VerifyKey> {
let notary_first = self.services.server.config.query_trusted_key_servers_first;
let notary_only = self.services.server.config.only_query_trusted_key_servers;
@ -86,7 +98,11 @@ pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKe
}
#[implement(super::Service)]
async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
async fn get_verify_key_from_notaries(
&self,
origin: &ServerName,
key_id: &ServerSigningKeyId,
) -> Result<VerifyKey> {
for notary in self.services.globals.trusted_servers() {
if let Ok(server_keys) = self.notary_request(notary, origin).await {
for server_key in server_keys.clone() {
@ -105,7 +121,11 @@ async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &Serve
}
#[implement(super::Service)]
async fn get_verify_key_from_origin(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
async fn get_verify_key_from_origin(
&self,
origin: &ServerName,
key_id: &ServerSigningKeyId,
) -> Result<VerifyKey> {
if let Ok(server_key) = self.server_request(origin).await {
self.add_signing_keys(server_key.clone()).await;
if let Some(result) = extract_key(server_key, key_id) {

View file

@ -39,14 +39,15 @@ fn load(db: &Arc<Database>) -> Result<Box<Ed25519KeyPair>> {
create(db)
})?;
let key =
Ed25519KeyPair::from_der(&key, version).map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?;
let key = Ed25519KeyPair::from_der(&key, version)
.map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?;
Ok(Box::new(key))
}
fn create(db: &Arc<Database>) -> Result<(String, Vec<u8>)> {
let keypair = Ed25519KeyPair::generate().map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?;
let keypair = Ed25519KeyPair::generate()
.map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?;
let id = utils::rand::string(8);
debug_info!("Generated new Ed25519 keypair: {id:?}");

View file

@ -18,8 +18,8 @@ use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
serde::Raw,
signatures::{Ed25519KeyPair, PublicKeyMap, PublicKeySet},
CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, RoomVersionId, ServerName,
ServerSigningKeyId,
CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, RoomVersionId,
ServerName, ServerSigningKeyId,
};
use serde_json::value::RawValue as RawJsonValue;
@ -113,7 +113,11 @@ async fn add_signing_keys(&self, new_keys: ServerSigningKeys) {
}
#[implement(Service)]
pub async fn required_keys_exist(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> bool {
pub async fn required_keys_exist(
&self,
object: &CanonicalJsonObject,
version: &RoomVersionId,
) -> bool {
use ruma::signatures::required_keys;
let Ok(required_keys) = required_keys(object, version) else {
@ -179,7 +183,8 @@ pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSignin
#[implement(Service)]
fn minimum_valid_ts(&self) -> MilliSecondsSinceUnixEpoch {
let timepoint = timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow");
let timepoint =
timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow");
MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow")
}

View file

@ -12,7 +12,9 @@ use ruma::{
#[implement(super::Service)]
pub(super) async fn batch_notary_request<'a, S, K>(
&self, notary: &ServerName, batch: S,
&self,
notary: &ServerName,
batch: S,
) -> Result<Vec<ServerSigningKeys>>
where
S: Iterator<Item = (&'a ServerName, K)> + Send,
@ -74,7 +76,9 @@ where
#[implement(super::Service)]
pub async fn notary_request(
&self, notary: &ServerName, target: &ServerName,
&self,
notary: &ServerName,
target: &ServerName,
) -> Result<impl Iterator<Item = ServerSigningKeys> + Clone + Debug + Send> {
use get_remote_server_keys::v2::Request;

View file

@ -10,7 +10,11 @@ pub fn sign_json(&self, object: &mut CanonicalJsonObject) -> Result {
}
#[implement(super::Service)]
pub fn hash_and_sign_event(&self, object: &mut CanonicalJsonObject, room_version: &RoomVersionId) -> Result {
pub fn hash_and_sign_event(
&self,
object: &mut CanonicalJsonObject,
room_version: &RoomVersionId,
) -> Result {
use ruma::signatures::hash_and_sign_event;
let server_name = self.services.globals.server_name().as_str();

View file

@ -1,14 +1,20 @@
use conduwuit::{implement, pdu::gen_event_id_canonical_json, Err, Result};
use ruma::{signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId};
use ruma::{
signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId,
};
use serde_json::value::RawValue as RawJsonValue;
#[implement(super::Service)]
pub async fn validate_and_add_event_id(
&self, pdu: &RawJsonValue, room_version: &RoomVersionId,
&self,
pdu: &RawJsonValue,
room_version: &RoomVersionId,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
if let Err(e) = self.verify_event(&value, Some(room_version)).await {
return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}")));
return Err!(BadServerResponse(debug_error!(
"Event {event_id} failed verification: {e:?}"
)));
}
value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into()));
@ -18,7 +24,9 @@ pub async fn validate_and_add_event_id(
#[implement(super::Service)]
pub async fn validate_and_add_event_id_no_fetch(
&self, pdu: &RawJsonValue, room_version: &RoomVersionId,
&self,
pdu: &RawJsonValue,
room_version: &RoomVersionId,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
if !self.required_keys_exist(&value, room_version).await {
@ -28,7 +36,9 @@ pub async fn validate_and_add_event_id_no_fetch(
}
if let Err(e) = self.verify_event(&value, Some(room_version)).await {
return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}")));
return Err!(BadServerResponse(debug_error!(
"Event {event_id} failed verification: {e:?}"
)));
}
value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into()));
@ -38,7 +48,9 @@ pub async fn validate_and_add_event_id_no_fetch(
#[implement(super::Service)]
pub async fn verify_event(
&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>,
&self,
event: &CanonicalJsonObject,
room_version: Option<&RoomVersionId>,
) -> Result<Verified> {
let room_version = room_version.unwrap_or(&RoomVersionId::V11);
let keys = self.get_event_keys(event, room_version).await?;
@ -46,7 +58,11 @@ pub async fn verify_event(
}
#[implement(super::Service)]
pub async fn verify_json(&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>) -> Result {
pub async fn verify_json(
&self,
event: &CanonicalJsonObject,
room_version: Option<&RoomVersionId>,
) -> Result {
let room_version = room_version.unwrap_or(&RoomVersionId::V11);
let keys = self.get_event_keys(event, room_version).await?;
ruma::signatures::verify_json(&keys, event.clone()).map_err(Into::into)

View file

@ -114,7 +114,9 @@ impl<'a> Args<'a> {
/// Create a reference immediately to a service when constructing another
/// Service. The other service must be constructed.
#[inline]
pub(crate) fn require<T: Service>(&'a self, name: &str) -> Arc<T> { require::<T>(self.service, name) }
pub(crate) fn require<T: Service>(&'a self, name: &str) -> Arc<T> {
require::<T>(self.service, name)
}
}
/// Reference a Service by name. Panics if the Service does not exist or was

View file

@ -47,7 +47,8 @@ struct Services {
struct SlidingSyncCache {
lists: BTreeMap<String, SyncRequestList>,
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>, // For every room, the roomsince number
known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>, /* For every room, the
* roomsince number */
extensions: ExtensionsConfig,
}
@ -85,14 +86,24 @@ impl crate::Service for Service {
}
impl Service {
pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool {
pub fn remembered(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
conn_id: String,
) -> bool {
self.connections
.lock()
.unwrap()
.contains_key(&(user_id, device_id, conn_id))
}
pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) {
pub fn forget_sync_request_connection(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
conn_id: String,
) {
self.connections
.lock()
.expect("locked")
@ -100,25 +111,26 @@ impl Service {
}
pub fn update_sync_request_with_cache(
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request,
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
request: &mut sync_events::v4::Request,
) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
let Some(conn_id) = request.conn_id.clone() else {
return BTreeMap::new();
};
let mut cache = self.connections.lock().expect("locked");
let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}),
);
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(
|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
},
));
let cached = &mut cached.lock().expect("locked");
drop(cache);
@ -141,13 +153,15 @@ impl Service {
.clone()
.or_else(|| cached_list.include_old_rooms.clone());
match (&mut list.filters, cached_list.filters.clone()) {
(Some(list_filters), Some(cached_filters)) => {
| (Some(list_filters), Some(cached_filters)) => {
list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm);
if list_filters.spaces.is_empty() {
list_filters.spaces = cached_filters.spaces;
}
list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted);
list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite);
list_filters.is_encrypted =
list_filters.is_encrypted.or(cached_filters.is_encrypted);
list_filters.is_invite =
list_filters.is_invite.or(cached_filters.is_invite);
if list_filters.room_types.is_empty() {
list_filters.room_types = cached_filters.room_types;
}
@ -165,9 +179,9 @@ impl Service {
list_filters.not_tags = cached_filters.not_tags;
}
},
(_, Some(cached_filters)) => list.filters = Some(cached_filters),
(Some(list_filters), _) => list.filters = Some(list_filters.clone()),
(..) => {},
| (_, Some(cached_filters)) => list.filters = Some(cached_filters),
| (Some(list_filters), _) => list.filters = Some(list_filters.clone()),
| (..) => {},
}
if list.bump_event_types.is_empty() {
list.bump_event_types
@ -220,22 +234,23 @@ impl Service {
}
pub fn update_sync_subscriptions(
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String,
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
conn_id: String,
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
) {
let mut cache = self.connections.lock().expect("locked");
let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}),
);
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(
|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
},
));
let cached = &mut cached.lock().expect("locked");
drop(cache);
@ -243,22 +258,25 @@ impl Service {
}
pub fn update_sync_known_rooms(
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String,
new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64,
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
conn_id: String,
list_id: String,
new_cached_rooms: BTreeSet<OwnedRoomId>,
globalsince: u64,
) {
let mut cache = self.connections.lock().expect("locked");
let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}),
);
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(
|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
},
));
let cached = &mut cached.lock().expect("locked");
drop(cache);

View file

@ -25,7 +25,13 @@ impl crate::Service for Service {
}
#[implement(Service)]
pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) {
pub fn add_txnid(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
data: &[u8],
) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
@ -38,7 +44,10 @@ pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id:
// If there's no entry, this is a new transaction
#[implement(Service)]
pub async fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
) -> Result<Handle<'_>> {
let key = (user_id, device_id, txn_id);
self.db.userdevicetxnid_response.qry(&key).await

View file

@ -58,7 +58,13 @@ impl crate::Service for Service {
/// Creates a new Uiaa session. Make sure the session token is unique.
#[implement(Service)]
pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) {
pub fn create(
&self,
user_id: &UserId,
device_id: &DeviceId,
uiaainfo: &UiaaInfo,
json_body: &CanonicalJsonValue,
) {
// TODO: better session error handling (why is uiaainfo.session optional in
// ruma?)
self.set_uiaa_request(
@ -78,7 +84,11 @@ pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo
#[implement(Service)]
pub async fn try_auth(
&self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo,
&self,
user_id: &UserId,
device_id: &DeviceId,
auth: &AuthData,
uiaainfo: &UiaaInfo,
) -> Result<(bool, UiaaInfo)> {
let mut uiaainfo = if let Some(session) = auth.session() {
self.get_uiaa_session(user_id, device_id, session).await?
@ -92,7 +102,7 @@ pub async fn try_auth(
match auth {
// Find out what the user completed
AuthData::Password(Password {
| AuthData::Password(Password {
identifier,
password,
#[cfg(feature = "element_hacks")]
@ -105,17 +115,26 @@ pub async fn try_auth(
} else if let Some(username) = user {
username
} else {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
));
};
#[cfg(not(feature = "element_hacks"))]
let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier
else {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
));
};
let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
let user_id = UserId::parse_with_server_name(
username.clone(),
self.services.globals.server_name(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
// Check if password is correct
if let Ok(hash) = self.services.users.password_hash(&user_id).await {
@ -132,7 +151,7 @@ pub async fn try_auth(
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push(AuthType::Password);
},
AuthData::RegistrationToken(t) => {
| AuthData::RegistrationToken(t) => {
if self
.services
.globals
@ -149,10 +168,10 @@ pub async fn try_auth(
return Ok((false, uiaainfo));
}
},
AuthData::Dummy(_) => {
| AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy);
},
k => error!("type not supported: {:?}", k),
| k => error!("type not supported: {:?}", k),
}
// Check if a flow now succeeds
@ -190,7 +209,13 @@ pub async fn try_auth(
}
#[implement(Service)]
fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) {
fn set_uiaa_request(
&self,
user_id: &UserId,
device_id: &DeviceId,
session: &str,
request: &CanonicalJsonValue,
) {
let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
self.userdevicesessionid_uiaarequest
.write()
@ -200,7 +225,10 @@ fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str
#[implement(Service)]
pub fn get_uiaa_request(
&self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str,
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
session: &str,
) -> Option<CanonicalJsonValue> {
let key = (
user_id.to_owned(),
@ -216,7 +244,13 @@ pub fn get_uiaa_request(
}
#[implement(Service)]
fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) {
fn update_uiaa_session(
&self,
user_id: &UserId,
device_id: &DeviceId,
session: &str,
uiaainfo: Option<&UiaaInfo>,
) {
let key = (user_id, device_id, session);
if let Some(uiaainfo) = uiaainfo {
@ -229,7 +263,12 @@ fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &
}
#[implement(Service)]
async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
async fn get_uiaa_session(
&self,
user_id: &UserId,
device_id: &DeviceId,
session: &str,
) -> Result<UiaaInfo> {
let key = (user_id, device_id, session);
self.db
.userdevicesessionid_uiaainfo

View file

@ -119,7 +119,8 @@ impl Service {
self.services
.admin
.send_message(RoomMessageEventContent::text_markdown(format!(
"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",
"### the following is a message from the conduwuit puppy\n\nit was sent on \
`{}`:\n\n@room: {}",
update.date, update.message
)))
.await
@ -127,7 +128,9 @@ impl Service {
}
#[inline]
pub fn update_check_for_updates_id(&self, id: u64) { self.db.raw_put(LAST_CHECK_FOR_UPDATES_COUNT, id); }
pub fn update_check_for_updates_id(&self, id: u64) {
self.db.raw_put(LAST_CHECK_FOR_UPDATES_COUNT, id);
}
pub async fn last_check_for_updates_id(&self) -> u64 {
self.db

View file

@ -10,10 +10,12 @@ use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType},
events::{
ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType,
},
serde::Raw,
DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedDeviceId,
OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId,
DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId,
OneTimeKeyName, OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId,
};
use serde_json::json;
@ -65,7 +67,8 @@ impl crate::Service for Service {
account_data: args.depend::<account_data::Service>("account_data"),
admin: args.depend::<admin::Service>("admin"),
globals: args.depend::<globals::Service>("globals"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
db: Data {
@ -114,7 +117,9 @@ impl Service {
/// Check if a user is an admin
#[inline]
pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await }
pub async fn is_admin(&self, user_id: &UserId) -> bool {
self.services.admin.user_is_admin(user_id).await
}
/// Create a new user account on this homeserver.
#[inline]
@ -141,7 +146,9 @@ impl Service {
/// Check if a user has an account on this homeserver.
#[inline]
pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.get(user_id).await.is_ok() }
pub async fn exists(&self, user_id: &UserId) -> bool {
self.db.userid_password.get(user_id).await.is_ok()
}
/// Check if account is deactivated
pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
@ -154,7 +161,9 @@ impl Service {
}
/// Check if account is active, infallible
pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) }
pub async fn is_active(&self, user_id: &UserId) -> bool {
!self.is_deactivated(user_id).await.unwrap_or(true)
}
/// Check if account is active, infallible
pub async fn is_active_local(&self, user_id: &UserId) -> bool {
@ -173,10 +182,14 @@ impl Service {
/// Returns an iterator over all users on this homeserver (offered for
/// compatibility)
#[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)]
pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ { self.stream().map(ToOwned::to_owned) }
pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ {
self.stream().map(ToOwned::to_owned)
}
/// Returns an iterator over all users on this homeserver.
pub fn stream(&self) -> impl Stream<Item = &UserId> + Send { self.db.userid_password.keys().ignore_err() }
pub fn stream(&self) -> impl Stream<Item = &UserId> + Send {
self.db.userid_password.keys().ignore_err()
}
/// Returns a list of local users as list of usernames.
///
@ -200,7 +213,9 @@ impl Service {
password
.map(utils::hash::password)
.transpose()
.map_err(|e| err!(Request(InvalidParam("Password does not meet the requirements: {e}"))))?
.map_err(|e| {
err!(Request(InvalidParam("Password does not meet the requirements: {e}")))
})?
.map_or_else(
|| self.db.userid_password.insert(user_id, b""),
|hash| self.db.userid_password.insert(user_id, hash),
@ -254,13 +269,19 @@ impl Service {
/// Adds a new device to a user.
pub async fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
&self,
user_id: &UserId,
device_id: &DeviceId,
token: &str,
initial_device_display_name: Option<String>,
client_ip: Option<String>,
) -> Result<()> {
// This method should never be called for nonexistent users. We shouldn't assert
// though...
if !self.exists(user_id).await {
return Err!(Request(InvalidParam(error!("Called create_device for non-existent {user_id}"))));
return Err!(Request(InvalidParam(error!(
"Called create_device for non-existent {user_id}"
))));
}
let key = (user_id, device_id);
@ -304,7 +325,10 @@ impl Service {
}
/// Returns an iterator over all device ids of this user.
pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &DeviceId> + Send + 'a {
pub fn all_device_ids<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = &DeviceId> + Send + 'a {
let prefix = (user_id, Interfix);
self.db
.userdeviceid_metadata
@ -319,7 +343,12 @@ impl Service {
}
/// Replaces the access token of one device.
pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
pub async fn set_token(
&self,
user_id: &UserId,
device_id: &DeviceId,
token: &str,
) -> Result<()> {
let key = (user_id, device_id);
// should not be None, but we shouldn't assert either lol...
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
@ -344,7 +373,10 @@ impl Service {
}
pub async fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
&self,
user_id: &UserId,
device_id: &DeviceId,
one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
one_time_key_value: &Raw<OneTimeKey>,
) -> Result {
// All devices have metadata
@ -391,7 +423,10 @@ impl Service {
}
pub async fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &OneTimeKeyAlgorithm,
&self,
user_id: &UserId,
device_id: &DeviceId,
key_algorithm: &OneTimeKeyAlgorithm,
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
let count = self.services.globals.next_count()?.to_be_bytes();
self.db.userid_lastonetimekeyupdate.insert(user_id, count);
@ -435,7 +470,9 @@ impl Service {
}
pub async fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId,
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
@ -462,7 +499,12 @@ impl Service {
algorithm_counts
}
pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
pub async fn add_device_keys(
&self,
user_id: &UserId,
device_id: &DeviceId,
device_keys: &Raw<DeviceKeys>,
) {
let key = (user_id, device_id);
self.db.keyid_key.put(key, Json(device_keys));
@ -470,8 +512,12 @@ impl Service {
}
pub async fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
&self,
user_id: &UserId,
master_key: &Raw<CrossSigningKey>,
self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>,
notify: bool,
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
@ -495,9 +541,10 @@ impl Service {
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained no key.",
))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
@ -531,7 +578,9 @@ impl Service {
.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
if user_signing_key_ids.next().is_some() {
return Err!(Request(InvalidParam("User signing key contained more than one key.")));
return Err!(Request(InvalidParam(
"User signing key contained more than one key."
)));
}
let mut user_signing_key_key = prefix;
@ -554,7 +603,11 @@ impl Service {
}
pub async fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
&self,
target_id: &UserId,
key_id: &str,
signature: (String, String),
sender_id: &UserId,
) -> Result<()> {
let key = (target_id, key_id);
@ -590,7 +643,10 @@ impl Service {
#[inline]
pub fn keys_changed<'a>(
&'a self, user_id: &'a UserId, from: u64, to: Option<u64>,
&'a self,
user_id: &'a UserId,
from: u64,
to: Option<u64>,
) -> impl Stream<Item = &UserId> + Send + 'a {
self.keys_changed_user_or_room(user_id.as_str(), from, to)
.map(|(user_id, ..)| user_id)
@ -598,13 +654,19 @@ impl Service {
#[inline]
pub fn room_keys_changed<'a>(
&'a self, room_id: &'a RoomId, from: u64, to: Option<u64>,
&'a self,
room_id: &'a RoomId,
from: u64,
to: Option<u64>,
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
self.keys_changed_user_or_room(room_id.as_str(), from, to)
}
fn keys_changed_user_or_room<'a>(
&'a self, user_or_room_id: &'a str, from: u64, to: Option<u64>,
&'a self,
user_or_room_id: &'a str,
from: u64,
to: Option<u64>,
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
type KeyVal<'a> = ((&'a str, u64), &'a UserId);
@ -614,7 +676,9 @@ impl Service {
.keychangeid_userid
.stream_from(&start)
.ignore_err()
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to)
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
*prefix == user_or_room_id && *count <= to
})
.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
}
@ -636,13 +700,21 @@ impl Service {
self.db.keychangeid_userid.put_raw(key, user_id);
}
pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
pub async fn get_device_keys<'a>(
&'a self,
user_id: &'a UserId,
device_id: &DeviceId,
) -> Result<Raw<DeviceKeys>> {
let key_id = (user_id, device_id);
self.db.keyid_key.qry(&key_id).await.deserialized()
}
pub async fn get_key<F>(
&self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
&self,
key_id: &[u8],
sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
F: Fn(&UserId) -> bool + Send + Sync,
@ -655,7 +727,10 @@ impl Service {
}
pub async fn get_master_key<F>(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
&self,
sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
F: Fn(&UserId) -> bool + Send + Sync,
@ -667,7 +742,10 @@ impl Service {
}
pub async fn get_self_signing_key<F>(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
&self,
sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
F: Fn(&UserId) -> bool + Send + Sync,
@ -688,7 +766,11 @@ impl Service {
}
pub async fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
&self,
sender: &UserId,
target_user_id: &UserId,
target_device_id: &DeviceId,
event_type: &str,
content: serde_json::Value,
) {
let count = self.services.globals.next_count().unwrap();
@ -705,7 +787,9 @@ impl Service {
}
pub fn get_to_device_events<'a>(
&'a self, user_id: &'a UserId, device_id: &'a DeviceId,
&'a self,
user_id: &'a UserId,
device_id: &'a DeviceId,
) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
let prefix = (user_id, device_id, Interfix);
self.db
@ -715,7 +799,12 @@ impl Service {
.map(|(_, val): (Ignore, Raw<AnyToDeviceEvent>)| val)
}
pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) {
pub async fn remove_to_device_events(
&self,
user_id: &UserId,
device_id: &DeviceId,
until: u64,
) {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
@ -742,7 +831,12 @@ impl Service {
.await;
}
pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
pub async fn update_device_metadata(
&self,
user_id: &UserId,
device_id: &DeviceId,
device: &Device,
) -> Result<()> {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
let key = (user_id, device_id);
@ -752,7 +846,11 @@ impl Service {
}
/// Get device metadata.
pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Device> {
pub async fn get_device_metadata(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Device> {
self.db
.userdeviceid_metadata
.qry(&(user_id, device_id))
@ -768,7 +866,10 @@ impl Service {
.deserialized()
}
pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = Device> + Send + 'a {
pub fn all_devices_metadata<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = Device> + Send + 'a {
let key = (user_id, Interfix);
self.db
.userdeviceid_metadata
@ -787,7 +888,11 @@ impl Service {
filter_id
}
pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
pub async fn get_filter(
&self,
user_id: &UserId,
filter_id: &str,
) -> Result<FilterDefinition> {
let key = (user_id, filter_id);
self.db.userfilterid_filter.qry(&key).await.deserialized()
}
@ -817,11 +922,10 @@ impl Service {
};
let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
let expires_at = u64::from_be_bytes(
expires_at_bytes
.try_into()
.map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?,
);
let expires_at =
u64::from_be_bytes(expires_at_bytes.try_into().map_err(|e| {
err!(Database("expires_at in openid_userid is invalid u64. {e}"))
})?);
if expires_at < utils::millis_since_unix_epoch() {
debug_warn!("OpenID token is expired, removing");
@ -833,11 +937,16 @@ impl Service {
let user_string = utils::string_from_bytes(user_bytes)
.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
UserId::parse(user_string)
.map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
}
/// Gets a specific user profile key
pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<serde_json::Value> {
pub async fn profile_key(
&self,
user_id: &UserId,
profile_key: &str,
) -> Result<serde_json::Value> {
let key = (user_id, profile_key);
self.db
.useridprofilekey_value
@ -848,7 +957,8 @@ impl Service {
/// Gets all the user's profile keys and values in an iterator
pub fn all_profile_keys<'a>(
&'a self, user_id: &'a UserId,
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
type KeyVal = ((Ignore, String), serde_json::Value);
@ -861,7 +971,12 @@ impl Service {
}
/// Sets a new profile key value, removes the key if value is None
pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
pub fn set_profile_key(
&self,
user_id: &UserId,
profile_key: &str,
profile_key_value: Option<serde_json::Value>,
) {
// TODO: insert to the stable MSC4175 key when it's stable
let key = (user_id, profile_key);
@ -901,7 +1016,10 @@ impl Service {
}
}
pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) -> Result<(Vec<u8>, CrossSigningKey)> {
pub fn parse_master_key(
user_id: &UserId,
master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
@ -925,7 +1043,10 @@ pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) ->
/// Ensure that a user only sees signatures from themselves and the target user
fn clean_signatures<F>(
mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
mut cross_signing_key: serde_json::Value,
sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: &F,
) -> Result<serde_json::Value>
where
F: Fn(&UserId) -> bool + Send + Sync,
@ -937,9 +1058,11 @@ where
// Don't allocate for the full size of the current signatures, but require
// at most one resize if nothing is dropped
let new_capacity = signatures.len() / 2;
for (user, signature) in mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) {
let sid =
<&UserId>::try_from(user.as_str()).map_err(|_| Error::bad_database("Invalid user ID in database."))?;
for (user, signature) in
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
{
let sid = <&UserId>::try_from(user.as_str())
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
signatures.insert(user, signature);
}