mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2025-09-11 00:52:49 +02:00
apply new rustfmt.toml changes, fix some clippy lints
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
0317cc8cc5
commit
77e0b76408
296 changed files with 7147 additions and 4300 deletions
|
@ -9,8 +9,8 @@ use database::{Deserialized, Handle, Interfix, Json, Map};
|
|||
use futures::{Stream, StreamExt, TryFutureExt};
|
||||
use ruma::{
|
||||
events::{
|
||||
AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, GlobalAccountDataEventType,
|
||||
RoomAccountDataEventType,
|
||||
AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent,
|
||||
GlobalAccountDataEventType, RoomAccountDataEventType,
|
||||
},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
|
@ -54,7 +54,11 @@ impl crate::Service for Service {
|
|||
#[allow(clippy::needless_pass_by_value)]
|
||||
#[implement(Service)]
|
||||
pub async fn update(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value,
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
if data.get("type").is_none() || data.get("content").is_none() {
|
||||
return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
|
||||
|
@ -91,7 +95,12 @@ where
|
|||
|
||||
/// Searches the global account data for a specific kind.
|
||||
#[implement(Service)]
|
||||
pub async fn get_room<T>(&self, room_id: &RoomId, user_id: &UserId, kind: RoomAccountDataEventType) -> Result<T>
|
||||
pub async fn get_room<T>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<T>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
|
@ -101,7 +110,12 @@ where
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_raw(&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &str) -> Result<Handle<'_>> {
|
||||
pub async fn get_raw(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: &str,
|
||||
) -> Result<Handle<'_>> {
|
||||
let key = (room_id, user_id, kind.to_owned());
|
||||
self.db
|
||||
.roomusertype_roomuserdataid
|
||||
|
@ -113,7 +127,10 @@ pub async fn get_raw(&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &s
|
|||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[implement(Service)]
|
||||
pub fn changes_since<'a>(
|
||||
&'a self, room_id: Option<&'a RoomId>, user_id: &'a UserId, since: u64,
|
||||
&'a self,
|
||||
room_id: Option<&'a RoomId>,
|
||||
user_id: &'a UserId,
|
||||
since: u64,
|
||||
) -> impl Stream<Item = AnyRawAccountDataEvent> + Send + 'a {
|
||||
let prefix = (room_id, user_id, Interfix);
|
||||
let prefix = database::serialize_key(prefix).expect("failed to serialize prefix");
|
||||
|
@ -128,8 +145,10 @@ pub fn changes_since<'a>(
|
|||
.ready_take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(_, v)| {
|
||||
match room_id {
|
||||
Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(v).map(AnyRawAccountDataEvent::Room),
|
||||
None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(v).map(AnyRawAccountDataEvent::Global),
|
||||
| Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(v)
|
||||
.map(AnyRawAccountDataEvent::Room),
|
||||
| None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(v)
|
||||
.map(AnyRawAccountDataEvent::Global),
|
||||
}
|
||||
.map_err(|e| err!(Database("Database contains invalid account data: {e}")))
|
||||
.log_err()
|
||||
|
|
|
@ -94,15 +94,16 @@ impl Console {
|
|||
debug!("session starting");
|
||||
while self.server.running() {
|
||||
match self.readline().await {
|
||||
Ok(event) => match event {
|
||||
ReadlineEvent::Line(string) => self.clone().handle(string).await,
|
||||
ReadlineEvent::Interrupted => continue,
|
||||
ReadlineEvent::Eof => break,
|
||||
ReadlineEvent::Quit => self.server.shutdown().unwrap_or_else(error::default_log),
|
||||
| Ok(event) => match event {
|
||||
| ReadlineEvent::Line(string) => self.clone().handle(string).await,
|
||||
| ReadlineEvent::Interrupted => continue,
|
||||
| ReadlineEvent::Eof => break,
|
||||
| ReadlineEvent::Quit =>
|
||||
self.server.shutdown().unwrap_or_else(error::default_log),
|
||||
},
|
||||
Err(error) => match error {
|
||||
ReadlineError::Closed => break,
|
||||
ReadlineError::IO(error) => {
|
||||
| Err(error) => match error {
|
||||
| ReadlineError::Closed => break,
|
||||
| ReadlineError::IO(error) => {
|
||||
error!("console I/O: {error:?}");
|
||||
break;
|
||||
},
|
||||
|
@ -158,9 +159,9 @@ impl Console {
|
|||
|
||||
async fn process(self: Arc<Self>, line: String) {
|
||||
match self.admin.command_in_place(line, None).await {
|
||||
Ok(Some(ref content)) => self.output(content),
|
||||
Err(ref content) => self.output_err(content),
|
||||
_ => unreachable!(),
|
||||
| Ok(Some(ref content)) => self.output(content),
|
||||
| Err(ref content) => self.output_err(content),
|
||||
| _ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,8 +42,9 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
let create_content = {
|
||||
use RoomVersionId::*;
|
||||
match room_version {
|
||||
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(server_user.clone()),
|
||||
_ => RoomCreateEventContent::new_v11(),
|
||||
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 =>
|
||||
RoomCreateEventContent::new_v1(server_user.clone()),
|
||||
| _ => RoomCreateEventContent::new_v11(),
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -52,15 +53,12 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomCreateEventContent {
|
||||
federate: true,
|
||||
predecessor: None,
|
||||
room_version: room_version.clone(),
|
||||
..create_content
|
||||
},
|
||||
),
|
||||
PduBuilder::state(String::new(), &RoomCreateEventContent {
|
||||
federate: true,
|
||||
predecessor: None,
|
||||
room_version: room_version.clone(),
|
||||
..create_content
|
||||
}),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -72,7 +70,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(server_user.to_string(), &RoomMemberEventContent::new(MembershipState::Join)),
|
||||
PduBuilder::state(
|
||||
server_user.to_string(),
|
||||
&RoomMemberEventContent::new(MembershipState::Join),
|
||||
),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -86,13 +87,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomPowerLevelsEventContent {
|
||||
users,
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
PduBuilder::state(String::new(), &RoomPowerLevelsEventContent {
|
||||
users,
|
||||
..Default::default()
|
||||
}),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -131,7 +129,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(String::new(), &RoomGuestAccessEventContent::new(GuestAccess::Forbidden)),
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomGuestAccessEventContent::new(GuestAccess::Forbidden),
|
||||
),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -155,12 +156,9 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomTopicEventContent {
|
||||
topic: format!("Manage {}", services.globals.server_name()),
|
||||
},
|
||||
),
|
||||
PduBuilder::state(String::new(), &RoomTopicEventContent {
|
||||
topic: format!("Manage {}", services.globals.server_name()),
|
||||
}),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -174,13 +172,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomCanonicalAliasEventContent {
|
||||
alias: Some(alias.clone()),
|
||||
alt_aliases: Vec::new(),
|
||||
},
|
||||
),
|
||||
PduBuilder::state(String::new(), &RoomCanonicalAliasEventContent {
|
||||
alias: Some(alias.clone()),
|
||||
alt_aliases: Vec::new(),
|
||||
}),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -197,12 +192,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> {
|
|||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomPreviewUrlsEventContent {
|
||||
disabled: true,
|
||||
},
|
||||
),
|
||||
PduBuilder::state(String::new(), &RoomPreviewUrlsEventContent { disabled: true }),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
|
|
@ -34,7 +34,10 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
|
|||
self.services
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Invite)),
|
||||
PduBuilder::state(
|
||||
user_id.to_string(),
|
||||
&RoomMemberEventContent::new(MembershipState::Invite),
|
||||
),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -43,7 +46,10 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
|
|||
self.services
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Join)),
|
||||
PduBuilder::state(
|
||||
user_id.to_string(),
|
||||
&RoomMemberEventContent::new(MembershipState::Join),
|
||||
),
|
||||
user_id,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -51,18 +57,18 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
|
|||
.await?;
|
||||
|
||||
// Set power level
|
||||
let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]);
|
||||
let users = BTreeMap::from_iter([
|
||||
(server_user.clone(), 100.into()),
|
||||
(user_id.to_owned(), 100.into()),
|
||||
]);
|
||||
|
||||
self.services
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder::state(
|
||||
String::new(),
|
||||
&RoomPowerLevelsEventContent {
|
||||
users,
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
PduBuilder::state(String::new(), &RoomPowerLevelsEventContent {
|
||||
users,
|
||||
..Default::default()
|
||||
}),
|
||||
server_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
|
@ -103,9 +109,7 @@ async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> R
|
|||
.get_room(room_id, user_id, RoomAccountDataEventType::Tag)
|
||||
.await
|
||||
.unwrap_or_else(|_| TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
content: TagEventContent { tags: BTreeMap::new() },
|
||||
});
|
||||
|
||||
event
|
||||
|
|
|
@ -10,7 +10,9 @@ use std::{
|
|||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use conduwuit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server};
|
||||
use conduwuit::{
|
||||
debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server,
|
||||
};
|
||||
pub use create::create_admin_room;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use loole::{Receiver, Sender};
|
||||
|
@ -158,21 +160,19 @@ impl Service {
|
|||
/// the queue is full.
|
||||
pub fn command(&self, command: String, reply_id: Option<OwnedEventId>) -> Result<()> {
|
||||
self.sender
|
||||
.send(CommandInput {
|
||||
command,
|
||||
reply_id,
|
||||
})
|
||||
.send(CommandInput { command, reply_id })
|
||||
.map_err(|e| err!("Failed to enqueue admin command: {e:?}"))
|
||||
}
|
||||
|
||||
/// Dispatches a comamnd to the processor on the current task and waits for
|
||||
/// completion.
|
||||
pub async fn command_in_place(&self, command: String, reply_id: Option<OwnedEventId>) -> ProcessorResult {
|
||||
self.process_command(CommandInput {
|
||||
command,
|
||||
reply_id,
|
||||
})
|
||||
.await
|
||||
pub async fn command_in_place(
|
||||
&self,
|
||||
command: String,
|
||||
reply_id: Option<OwnedEventId>,
|
||||
) -> ProcessorResult {
|
||||
self.process_command(CommandInput { command, reply_id })
|
||||
.await
|
||||
}
|
||||
|
||||
/// Invokes the tab-completer to complete the command. When unavailable,
|
||||
|
@ -191,8 +191,8 @@ impl Service {
|
|||
|
||||
async fn handle_command(&self, command: CommandInput) {
|
||||
match self.process_command(command).await {
|
||||
Ok(None) => debug!("Command successful with no response"),
|
||||
Ok(Some(output)) | Err(output) => self
|
||||
| Ok(None) => debug!("Command successful with no response"),
|
||||
| Ok(Some(output)) | Err(output) => self
|
||||
.handle_response(output)
|
||||
.await
|
||||
.unwrap_or_else(default_log),
|
||||
|
@ -250,10 +250,7 @@ impl Service {
|
|||
}
|
||||
|
||||
async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> {
|
||||
let Some(Relation::Reply {
|
||||
in_reply_to,
|
||||
}) = content.relates_to.as_ref()
|
||||
else {
|
||||
let Some(Relation::Reply { in_reply_to }) = content.relates_to.as_ref() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
|
@ -277,7 +274,10 @@ impl Service {
|
|||
}
|
||||
|
||||
async fn respond_to_room(
|
||||
&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId,
|
||||
&self,
|
||||
content: RoomMessageEventContent,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<()> {
|
||||
assert!(self.user_is_admin(user_id).await, "sender is not admin");
|
||||
|
||||
|
@ -298,12 +298,16 @@ impl Service {
|
|||
}
|
||||
|
||||
async fn handle_response_error(
|
||||
&self, e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard,
|
||||
&self,
|
||||
e: Error,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
state_lock: &RoomMutexGuard,
|
||||
) -> Result<()> {
|
||||
error!("Failed to build and append admin room response PDU: \"{e}\"");
|
||||
let content = RoomMessageEventContent::text_plain(format!(
|
||||
"Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \
|
||||
successfully, but we could not return the output."
|
||||
"Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command \
|
||||
may have finished successfully, but we could not return the output."
|
||||
));
|
||||
|
||||
self.services
|
||||
|
@ -321,7 +325,8 @@ impl Service {
|
|||
|
||||
// Admin command with public echo (in admin room)
|
||||
let server_user = &self.services.globals.server_user;
|
||||
let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str());
|
||||
let is_public_prefix =
|
||||
body.starts_with("!admin") || body.starts_with(server_user.as_str());
|
||||
|
||||
// Expected backward branch
|
||||
if !is_public_escape && !is_public_prefix {
|
||||
|
|
|
@ -65,9 +65,9 @@ async fn startup_execute_command(&self, i: usize, command: String) -> Result<()>
|
|||
debug!("Startup command #{i}: executing {command:?}");
|
||||
|
||||
match self.command_in_place(command, None).await {
|
||||
Ok(Some(output)) => Self::startup_command_output(i, &output),
|
||||
Err(output) => Self::startup_command_error(i, &output),
|
||||
Ok(None) => {
|
||||
| Ok(Some(output)) => Self::startup_command_output(i, &output),
|
||||
| Err(output) => Self::startup_command_error(i, &output),
|
||||
| Ok(None) => {
|
||||
info!("Startup command #{i} completed (no output).");
|
||||
Ok(())
|
||||
},
|
||||
|
|
|
@ -61,7 +61,11 @@ impl crate::Service for Service {
|
|||
|
||||
impl Service {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub async fn register_appservice(&self, registration: &Registration, appservice_config_body: &str) -> Result {
|
||||
pub async fn register_appservice(
|
||||
&self,
|
||||
registration: &Registration,
|
||||
appservice_config_body: &str,
|
||||
) -> Result {
|
||||
//TODO: Check for collisions between exclusive appservice namespaces
|
||||
self.registration_info
|
||||
.write()
|
||||
|
@ -152,7 +156,10 @@ impl Service {
|
|||
.any(|info| info.rooms.is_exclusive_match(room_id.as_str()))
|
||||
}
|
||||
|
||||
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
|
||||
pub fn read(
|
||||
&self,
|
||||
) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>>
|
||||
{
|
||||
self.registration_info.read()
|
||||
}
|
||||
|
||||
|
|
|
@ -15,13 +15,15 @@ pub struct RegistrationInfo {
|
|||
impl RegistrationInfo {
|
||||
#[must_use]
|
||||
pub fn is_user_match(&self, user_id: &UserId) -> bool {
|
||||
self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart()
|
||||
self.users.is_match(user_id.as_str())
|
||||
|| self.registration.sender_localpart == user_id.localpart()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool {
|
||||
self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart()
|
||||
self.users.is_exclusive_match(user_id.as_str())
|
||||
|| self.registration.sender_localpart == user_id.localpart()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,9 @@ impl crate::Service for Service {
|
|||
.build()?,
|
||||
|
||||
url_preview: base(config)
|
||||
.and_then(|builder| builder_interface(builder, url_preview_bind_iface.as_deref()))?
|
||||
.and_then(|builder| {
|
||||
builder_interface(builder, url_preview_bind_iface.as_deref())
|
||||
})?
|
||||
.local_address(url_preview_bind_addr)
|
||||
.dns_resolver(resolver.resolver.clone())
|
||||
.redirect(redirect::Policy::limited(3))
|
||||
|
@ -178,7 +180,10 @@ fn base(config: &Config) -> Result<reqwest::ClientBuilder> {
|
|||
}
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
fn builder_interface(builder: reqwest::ClientBuilder, config: Option<&str>) -> Result<reqwest::ClientBuilder> {
|
||||
fn builder_interface(
|
||||
builder: reqwest::ClientBuilder,
|
||||
config: Option<&str>,
|
||||
) -> Result<reqwest::ClientBuilder> {
|
||||
if let Some(iface) = config {
|
||||
Ok(builder.interface(iface))
|
||||
} else {
|
||||
|
@ -187,7 +192,10 @@ fn builder_interface(builder: reqwest::ClientBuilder, config: Option<&str>) -> R
|
|||
}
|
||||
|
||||
#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
|
||||
fn builder_interface(builder: reqwest::ClientBuilder, config: Option<&str>) -> Result<reqwest::ClientBuilder> {
|
||||
fn builder_interface(
|
||||
builder: reqwest::ClientBuilder,
|
||||
config: Option<&str>,
|
||||
) -> Result<reqwest::ClientBuilder> {
|
||||
use conduwuit::Err;
|
||||
|
||||
if let Some(iface) = config {
|
||||
|
|
|
@ -3,7 +3,9 @@ use std::sync::Arc;
|
|||
use async_trait::async_trait;
|
||||
use conduwuit::{error, warn, Result};
|
||||
use ruma::{
|
||||
events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType},
|
||||
events::{
|
||||
push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType,
|
||||
},
|
||||
push::Ruleset,
|
||||
};
|
||||
|
||||
|
@ -31,16 +33,14 @@ impl crate::Service for Service {
|
|||
}))
|
||||
}
|
||||
|
||||
async fn worker(self: Arc<Self>) -> Result<()> {
|
||||
async fn worker(self: Arc<Self>) -> Result {
|
||||
if self.services.globals.is_read_only() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.set_emergency_access()
|
||||
.await
|
||||
.inspect_err(|e| error!("Could not set the configured emergency password for the server user: {e}"))?;
|
||||
|
||||
Ok(())
|
||||
self.set_emergency_access().await.inspect_err(|e| {
|
||||
error!("Could not set the configured emergency password for the server user: {e}");
|
||||
})
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
|
@ -49,7 +49,7 @@ impl crate::Service for Service {
|
|||
impl Service {
|
||||
/// Sets the emergency password and push rules for the server user account
|
||||
/// in case emergency password is set
|
||||
async fn set_emergency_access(&self) -> Result<bool> {
|
||||
async fn set_emergency_access(&self) -> Result {
|
||||
let server_user = &self.services.globals.server_user;
|
||||
|
||||
self.services
|
||||
|
@ -57,8 +57,8 @@ impl Service {
|
|||
.set_password(server_user, self.services.globals.emergency_password().as_deref())?;
|
||||
|
||||
let (ruleset, pwd_set) = match self.services.globals.emergency_password() {
|
||||
Some(_) => (Ruleset::server_default(server_user), true),
|
||||
None => (Ruleset::new(), false),
|
||||
| Some(_) => (Ruleset::server_default(server_user), true),
|
||||
| None => (Ruleset::new(), false),
|
||||
};
|
||||
|
||||
self.services
|
||||
|
@ -68,9 +68,7 @@ impl Service {
|
|||
server_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(&GlobalAccountDataEvent {
|
||||
content: PushRulesEventContent {
|
||||
global: ruleset,
|
||||
},
|
||||
content: PushRulesEventContent { global: ruleset },
|
||||
})
|
||||
.expect("to json value always works"),
|
||||
)
|
||||
|
@ -78,14 +76,14 @@ impl Service {
|
|||
|
||||
if pwd_set {
|
||||
warn!(
|
||||
"The server account emergency password is set! Please unset it as soon as you finish admin account \
|
||||
recovery! You will be logged out of the server service account when you finish."
|
||||
"The server account emergency password is set! Please unset it as soon as you \
|
||||
finish admin account recovery! You will be logged out of the server service \
|
||||
account when you finish."
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
// logs out any users still in the server service account and removes sessions
|
||||
self.services.users.deactivate_account(server_user).await?;
|
||||
self.services.users.deactivate_account(server_user).await
|
||||
}
|
||||
|
||||
Ok(pwd_set)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,9 @@ impl Data {
|
|||
let db = &args.db;
|
||||
Self {
|
||||
global: db["global"].clone(),
|
||||
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
|
||||
counter: RwLock::new(
|
||||
Self::stored_count(&db["global"]).expect("initialized global counter"),
|
||||
),
|
||||
db: args.db.clone(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,9 @@ use std::{
|
|||
use conduwuit::{error, Config, Result};
|
||||
use data::Data;
|
||||
use regex::RegexSet;
|
||||
use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, ServerName, UserId};
|
||||
use ruma::{
|
||||
OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, ServerName, UserId,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::service;
|
||||
|
@ -40,31 +42,31 @@ impl crate::Service for Service {
|
|||
.as_ref()
|
||||
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
|
||||
|
||||
let turn_secret = config
|
||||
.turn_secret_file
|
||||
.as_ref()
|
||||
.map_or(config.turn_secret.clone(), |path| {
|
||||
std::fs::read_to_string(path).unwrap_or_else(|e| {
|
||||
error!("Failed to read the TURN secret file: {e}");
|
||||
|
||||
config.turn_secret.clone()
|
||||
})
|
||||
});
|
||||
|
||||
let registration_token =
|
||||
let turn_secret =
|
||||
config
|
||||
.registration_token_file
|
||||
.turn_secret_file
|
||||
.as_ref()
|
||||
.map_or(config.registration_token.clone(), |path| {
|
||||
let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| {
|
||||
error!("Failed to read the registration token file: {e}");
|
||||
}) else {
|
||||
return config.registration_token.clone();
|
||||
};
|
||||
.map_or(config.turn_secret.clone(), |path| {
|
||||
std::fs::read_to_string(path).unwrap_or_else(|e| {
|
||||
error!("Failed to read the TURN secret file: {e}");
|
||||
|
||||
Some(token)
|
||||
config.turn_secret.clone()
|
||||
})
|
||||
});
|
||||
|
||||
let registration_token = config.registration_token_file.as_ref().map_or(
|
||||
config.registration_token.clone(),
|
||||
|path| {
|
||||
let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| {
|
||||
error!("Failed to read the registration token file: {e}");
|
||||
}) else {
|
||||
return config.registration_token.clone();
|
||||
};
|
||||
|
||||
Some(token)
|
||||
},
|
||||
);
|
||||
|
||||
let mut s = Self {
|
||||
db,
|
||||
config: config.clone(),
|
||||
|
@ -73,8 +75,11 @@ impl crate::Service for Service {
|
|||
stateres_mutex: Arc::new(Mutex::new(())),
|
||||
admin_alias: RoomAliasId::parse(format!("#admins:{}", &config.server_name))
|
||||
.expect("#admins:server_name is valid alias name"),
|
||||
server_user: UserId::parse_with_server_name(String::from("conduit"), &config.server_name)
|
||||
.expect("@conduit:server_name is valid"),
|
||||
server_user: UserId::parse_with_server_name(
|
||||
String::from("conduit"),
|
||||
&config.server_name,
|
||||
)
|
||||
.expect("@conduit:server_name is valid"),
|
||||
turn_secret,
|
||||
registration_token,
|
||||
};
|
||||
|
@ -125,7 +130,9 @@ impl Service {
|
|||
|
||||
pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration }
|
||||
|
||||
pub fn allow_guests_auto_join_rooms(&self) -> bool { self.config.allow_guests_auto_join_rooms }
|
||||
pub fn allow_guests_auto_join_rooms(&self) -> bool {
|
||||
self.config.allow_guests_auto_join_rooms
|
||||
}
|
||||
|
||||
pub fn log_guest_registrations(&self) -> bool { self.config.log_guest_registrations }
|
||||
|
||||
|
@ -137,17 +144,23 @@ impl Service {
|
|||
self.config.allow_public_room_directory_over_federation
|
||||
}
|
||||
|
||||
pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation }
|
||||
pub fn allow_device_name_federation(&self) -> bool {
|
||||
self.config.allow_device_name_federation
|
||||
}
|
||||
|
||||
pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation }
|
||||
|
||||
pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix }
|
||||
pub fn new_user_displayname_suffix(&self) -> &String {
|
||||
&self.config.new_user_displayname_suffix
|
||||
}
|
||||
|
||||
pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates }
|
||||
|
||||
pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
|
||||
|
||||
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
|
||||
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> {
|
||||
self.jwt_decoding_key.as_ref()
|
||||
}
|
||||
|
||||
pub fn turn_password(&self) -> &String { &self.config.turn_password }
|
||||
|
||||
|
@ -173,11 +186,15 @@ impl Service {
|
|||
&self.config.url_preview_domain_explicit_denylist
|
||||
}
|
||||
|
||||
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> { &self.config.url_preview_url_contains_allowlist }
|
||||
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> {
|
||||
&self.config.url_preview_url_contains_allowlist
|
||||
}
|
||||
|
||||
pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size }
|
||||
|
||||
pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain }
|
||||
pub fn url_preview_check_root_domain(&self) -> bool {
|
||||
self.config.url_preview_check_root_domain
|
||||
}
|
||||
|
||||
pub fn forbidden_alias_names(&self) -> &RegexSet { &self.config.forbidden_alias_names }
|
||||
|
||||
|
@ -189,18 +206,26 @@ impl Service {
|
|||
|
||||
pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence }
|
||||
|
||||
pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts }
|
||||
pub fn allow_incoming_read_receipts(&self) -> bool {
|
||||
self.config.allow_incoming_read_receipts
|
||||
}
|
||||
|
||||
pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts }
|
||||
pub fn allow_outgoing_read_receipts(&self) -> bool {
|
||||
self.config.allow_outgoing_read_receipts
|
||||
}
|
||||
|
||||
pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites }
|
||||
|
||||
/// checks if `user_id` is local to us via server_name comparison
|
||||
#[inline]
|
||||
pub fn user_is_local(&self, user_id: &UserId) -> bool { self.server_is_ours(user_id.server_name()) }
|
||||
pub fn user_is_local(&self, user_id: &UserId) -> bool {
|
||||
self.server_is_ours(user_id.server_name())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn server_is_ours(&self, server_name: &ServerName) -> bool { server_name == self.config.server_name }
|
||||
pub fn server_is_ours(&self, server_name: &ServerName) -> bool {
|
||||
server_name == self.config.server_name
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() }
|
||||
|
|
|
@ -48,7 +48,11 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
pub fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
let version = self.services.globals.next_count()?.to_string();
|
||||
let count = self.services.globals.next_count()?;
|
||||
|
||||
|
@ -71,13 +75,18 @@ pub async fn delete_backup(&self, user_id: &UserId, version: &str) {
|
|||
.backupkeyid_backup
|
||||
.keys_prefix_raw(&key)
|
||||
.ignore_err()
|
||||
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
|
||||
.ready_for_each(|outdated_key| {
|
||||
self.db.backupkeyid_backup.remove(outdated_key);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn update_backup<'a>(
|
||||
&self, user_id: &UserId, version: &'a str, backup_metadata: &Raw<BackupAlgorithm>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &'a str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<&'a str> {
|
||||
let key = (user_id, version);
|
||||
if self.db.backupid_algorithm.qry(&key).await.is_err() {
|
||||
|
@ -110,7 +119,10 @@ pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw<BackupAlgorithm>)> {
|
||||
pub async fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<(String, Raw<BackupAlgorithm>)> {
|
||||
type Key<'a> = (&'a UserId, &'a str);
|
||||
type KeyVal<'a> = (Key<'a>, Raw<BackupAlgorithm>);
|
||||
|
||||
|
@ -134,7 +146,12 @@ pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<Ba
|
|||
|
||||
#[implement(Service)]
|
||||
pub async fn add_key(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
let key = (user_id, version);
|
||||
if self.db.backupid_algorithm.qry(&key).await.is_err() {
|
||||
|
@ -176,14 +193,16 @@ pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String {
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
|
||||
pub async fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
|
||||
type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str);
|
||||
type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>);
|
||||
|
||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||
let default = || RoomKeyBackup {
|
||||
sessions: BTreeMap::new(),
|
||||
};
|
||||
let default = || RoomKeyBackup { sessions: BTreeMap::new() };
|
||||
|
||||
let prefix = (user_id, version, Interfix);
|
||||
self.db
|
||||
|
@ -204,7 +223,10 @@ pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap<OwnedRo
|
|||
|
||||
#[implement(Service)]
|
||||
pub async fn get_room(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> BTreeMap<String, Raw<KeyBackupData>> {
|
||||
type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw<KeyBackupData>);
|
||||
|
||||
|
@ -213,14 +235,20 @@ pub async fn get_room(
|
|||
.backupkeyid_backup
|
||||
.stream_prefix(&prefix)
|
||||
.ignore_err()
|
||||
.map(|((.., session_id), key_backup_data): KeyVal<'_>| (session_id.to_owned(), key_backup_data))
|
||||
.map(|((.., session_id), key_backup_data): KeyVal<'_>| {
|
||||
(session_id.to_owned(), key_backup_data)
|
||||
})
|
||||
.collect()
|
||||
.await
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_session(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Raw<KeyBackupData>> {
|
||||
let key = (user_id, version, room_id, session_id);
|
||||
|
||||
|
@ -245,17 +273,27 @@ pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &
|
|||
.backupkeyid_backup
|
||||
.keys_prefix_raw(&key)
|
||||
.ignore_err()
|
||||
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
|
||||
.ready_for_each(|outdated_key| {
|
||||
self.db.backupkeyid_backup.remove(outdated_key);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) {
|
||||
pub async fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) {
|
||||
let key = (user_id, version, room_id, session_id);
|
||||
self.db
|
||||
.backupkeyid_backup
|
||||
.keys_prefix_raw(&key)
|
||||
.ignore_err()
|
||||
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
|
||||
.ready_for_each(|outdated_key| {
|
||||
self.db.backupkeyid_backup.remove(outdated_key);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
|
|
@ -102,21 +102,32 @@ impl Manager {
|
|||
unimplemented!("unexpected worker task abort {error:?}");
|
||||
}
|
||||
|
||||
async fn handle_result(&self, workers: &mut WorkersLocked<'_>, result: WorkerResult) -> Result<()> {
|
||||
async fn handle_result(
|
||||
&self,
|
||||
workers: &mut WorkersLocked<'_>,
|
||||
result: WorkerResult,
|
||||
) -> Result<()> {
|
||||
let (service, result) = result;
|
||||
match result {
|
||||
Ok(()) => self.handle_finished(workers, &service).await,
|
||||
Err(error) => self.handle_error(workers, &service, error).await,
|
||||
| Ok(()) => self.handle_finished(workers, &service).await,
|
||||
| Err(error) => self.handle_error(workers, &service, error).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_finished(&self, _workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>) -> Result<()> {
|
||||
async fn handle_finished(
|
||||
&self,
|
||||
_workers: &mut WorkersLocked<'_>,
|
||||
service: &Arc<dyn Service>,
|
||||
) -> Result<()> {
|
||||
debug!("service {:?} worker finished", service.name());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_error(
|
||||
&self, workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>, error: Error,
|
||||
&self,
|
||||
workers: &mut WorkersLocked<'_>,
|
||||
service: &Arc<dyn Service>,
|
||||
error: Error,
|
||||
) -> Result<()> {
|
||||
let name = service.name();
|
||||
error!("service {name:?} aborted: {error}");
|
||||
|
@ -138,9 +149,16 @@ impl Manager {
|
|||
}
|
||||
|
||||
/// Start the worker in a task for the service.
|
||||
async fn start_worker(&self, workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>) -> Result<()> {
|
||||
async fn start_worker(
|
||||
&self,
|
||||
workers: &mut WorkersLocked<'_>,
|
||||
service: &Arc<dyn Service>,
|
||||
) -> Result<()> {
|
||||
if !self.server.running() {
|
||||
return Err!("Service {:?} worker not starting during server shutdown.", service.name());
|
||||
return Err!(
|
||||
"Service {:?} worker not starting during server shutdown.",
|
||||
service.name()
|
||||
);
|
||||
}
|
||||
|
||||
debug!("Service {:?} worker starting...", service.name());
|
||||
|
|
|
@ -34,7 +34,11 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) fn create_file_metadata(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, dim: &Dim, content_disposition: Option<&ContentDisposition>,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
dim: &Dim,
|
||||
content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let dim: &[u32] = &[dim.width, dim.height];
|
||||
|
@ -63,7 +67,10 @@ impl Data {
|
|||
.stream_prefix_raw(&prefix)
|
||||
.ignore_err()
|
||||
.ready_for_each(|(key, val)| {
|
||||
debug_assert!(key.starts_with(mxc.to_string().as_bytes()), "key should start with the mxc");
|
||||
debug_assert!(
|
||||
key.starts_with(mxc.to_string().as_bytes()),
|
||||
"key should start with the mxc"
|
||||
);
|
||||
|
||||
let user = str_from_bytes(val).unwrap_or_default();
|
||||
debug_info!("Deleting key {key:?} which was uploaded by user {user}");
|
||||
|
@ -95,7 +102,11 @@ impl Data {
|
|||
Ok(keys)
|
||||
}
|
||||
|
||||
pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Metadata> {
|
||||
pub(super) async fn search_file_metadata(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
dim: &Dim,
|
||||
) -> Result<Metadata> {
|
||||
let dim: &[u32] = &[dim.width, dim.height];
|
||||
let prefix = (mxc, dim, Interfix);
|
||||
|
||||
|
@ -113,8 +124,9 @@ impl Data {
|
|||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
|
||||
string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
|
@ -127,16 +139,16 @@ impl Data {
|
|||
} else {
|
||||
Some(
|
||||
string_from_bytes(content_disposition_bytes)
|
||||
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?
|
||||
.map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Content Disposition in mediaid_file is invalid unicode.",
|
||||
)
|
||||
})?
|
||||
.parse()?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Metadata {
|
||||
content_disposition,
|
||||
content_type,
|
||||
key,
|
||||
})
|
||||
Ok(Metadata { content_disposition, content_type, key })
|
||||
}
|
||||
|
||||
/// Gets all the MXCs associated with a user
|
||||
|
@ -144,7 +156,9 @@ impl Data {
|
|||
self.mediaid_user
|
||||
.stream()
|
||||
.ignore_err()
|
||||
.ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into()))
|
||||
.ready_filter_map(|(key, user): (&str, &UserId)| {
|
||||
(user == user_id).then(|| key.into())
|
||||
})
|
||||
.collect()
|
||||
.await
|
||||
}
|
||||
|
@ -166,7 +180,12 @@ impl Data {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: Duration) -> Result<()> {
|
||||
pub(super) fn set_url_preview(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &UrlPreviewData,
|
||||
timestamp: Duration,
|
||||
) -> Result<()> {
|
||||
let mut value = Vec::<u8>::new();
|
||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||
value.push(0xFF);
|
||||
|
@ -218,43 +237,43 @@ impl Data {
|
|||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
| Some(s) if s.is_empty() => None,
|
||||
| x => x,
|
||||
};
|
||||
let description = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
| Some(s) if s.is_empty() => None,
|
||||
| x => x,
|
||||
};
|
||||
let image = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
| Some(s) if s.is_empty() => None,
|
||||
| x => x,
|
||||
};
|
||||
let image_size = match values
|
||||
.next()
|
||||
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
| Some(0) => None,
|
||||
| x => x,
|
||||
};
|
||||
let image_width = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
| Some(0) => None,
|
||||
| x => x,
|
||||
};
|
||||
let image_height = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
| Some(0) => None,
|
||||
| x => x,
|
||||
};
|
||||
|
||||
Ok(UrlPreviewData {
|
||||
|
|
|
@ -83,7 +83,8 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> {
|
|||
for key in media.db.get_all_media_keys().await {
|
||||
let new_path = media.get_media_file_sha256(&key).into_os_string();
|
||||
let old_path = media.get_media_file_b64(&key).into_os_string();
|
||||
if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await {
|
||||
if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await
|
||||
{
|
||||
error!(
|
||||
media_id = ?encode_key(&key), ?new_path, ?old_path,
|
||||
"Failed to resolve media check failure: {e}"
|
||||
|
@ -100,8 +101,12 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> {
|
|||
}
|
||||
|
||||
async fn handle_media_check(
|
||||
dbs: &(&Arc<database::Map>, &Arc<database::Map>), config: &Config, files: &HashSet<OsString>, key: &[u8],
|
||||
new_path: &OsStr, old_path: &OsStr,
|
||||
dbs: &(&Arc<database::Map>, &Arc<database::Map>),
|
||||
config: &Config,
|
||||
files: &HashSet<OsString>,
|
||||
key: &[u8],
|
||||
new_path: &OsStr,
|
||||
old_path: &OsStr,
|
||||
) -> Result<()> {
|
||||
use crate::media::encode_key;
|
||||
|
||||
|
|
|
@ -80,13 +80,21 @@ impl crate::Service for Service {
|
|||
impl Service {
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>, file: &[u8],
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
// Width, Height = 0 if it's not a thumbnail
|
||||
let key = self
|
||||
.db
|
||||
.create_file_metadata(mxc, user, &Dim::default(), content_disposition, content_type)?;
|
||||
let key = self.db.create_file_metadata(
|
||||
mxc,
|
||||
user,
|
||||
&Dim::default(),
|
||||
content_disposition,
|
||||
content_type,
|
||||
)?;
|
||||
|
||||
//TODO: Dangling metadata in database if creation fails
|
||||
let mut f = self.create_media_file(&key).await?;
|
||||
|
@ -132,10 +140,10 @@ impl Service {
|
|||
|
||||
debug_info!(%deletion_count, "Deleting MXC {mxc} by user {user} from database and filesystem");
|
||||
match self.delete(&mxc).await {
|
||||
Ok(()) => {
|
||||
| Ok(()) => {
|
||||
deletion_count = deletion_count.saturating_add(1);
|
||||
},
|
||||
Err(e) => {
|
||||
| Err(e) => {
|
||||
debug_error!(%deletion_count, "Failed to delete {mxc} from user {user}, ignoring error: {e}");
|
||||
},
|
||||
}
|
||||
|
@ -146,11 +154,8 @@ impl Service {
|
|||
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, mxc: &Mxc<'_>) -> Result<Option<FileMeta>> {
|
||||
if let Ok(Metadata {
|
||||
content_disposition,
|
||||
content_type,
|
||||
key,
|
||||
}) = self.db.search_file_metadata(mxc, &Dim::default()).await
|
||||
if let Ok(Metadata { content_disposition, content_type, key }) =
|
||||
self.db.search_file_metadata(mxc, &Dim::default()).await
|
||||
{
|
||||
let mut content = Vec::new();
|
||||
let path = self.get_media_file(&key);
|
||||
|
@ -181,13 +186,19 @@ impl Service {
|
|||
let mxc = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|e| err!(Database(error!("Failed to parse MXC unicode bytes from our database: {e}"))))
|
||||
utils::string_from_bytes(bytes).map_err(|e| {
|
||||
err!(Database(error!(
|
||||
"Failed to parse MXC unicode bytes from our database: {e}"
|
||||
)))
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let Some(mxc_s) = mxc else {
|
||||
debug_warn!(?mxc, "Parsed MXC URL unicode bytes from database but is still invalid");
|
||||
debug_warn!(
|
||||
?mxc,
|
||||
"Parsed MXC URL unicode bytes from database but is still invalid"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
|
@ -207,7 +218,11 @@ impl Service {
|
|||
/// Deletes all remote only media files in the given at or after
|
||||
/// time/duration. Returns a usize with the amount of media files deleted.
|
||||
pub async fn delete_all_remote_media_at_after_time(
|
||||
&self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool,
|
||||
&self,
|
||||
time: SystemTime,
|
||||
before: bool,
|
||||
after: bool,
|
||||
yes_i_want_to_delete_local_media: bool,
|
||||
) -> Result<usize> {
|
||||
let all_keys = self.db.get_all_media_keys().await;
|
||||
let mut remote_mxcs = Vec::with_capacity(all_keys.len());
|
||||
|
@ -218,19 +233,26 @@ impl Service {
|
|||
let mxc = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|e| err!(Database(error!("Failed to parse MXC unicode bytes from our database: {e}"))))
|
||||
utils::string_from_bytes(bytes).map_err(|e| {
|
||||
err!(Database(error!(
|
||||
"Failed to parse MXC unicode bytes from our database: {e}"
|
||||
)))
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let Some(mxc_s) = mxc else {
|
||||
debug_warn!(?mxc, "Parsed MXC URL unicode bytes from database but is still invalid");
|
||||
debug_warn!(
|
||||
?mxc,
|
||||
"Parsed MXC URL unicode bytes from database but is still invalid"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
trace!("Parsed MXC key to URL: {mxc_s}");
|
||||
let mxc = OwnedMxcUri::from(mxc_s);
|
||||
if (mxc.server_name() == Ok(self.services.globals.server_name()) && !yes_i_want_to_delete_local_media)
|
||||
if (mxc.server_name() == Ok(self.services.globals.server_name())
|
||||
&& !yes_i_want_to_delete_local_media)
|
||||
|| !mxc.is_valid()
|
||||
{
|
||||
debug!("Ignoring local or broken media MXC: {mxc}");
|
||||
|
@ -240,9 +262,12 @@ impl Service {
|
|||
let path = self.get_media_file(&key);
|
||||
|
||||
let file_metadata = match fs::metadata(path.clone()).await {
|
||||
Ok(file_metadata) => file_metadata,
|
||||
Err(e) => {
|
||||
error!("Failed to obtain file metadata for MXC {mxc} at file path \"{path:?}\", skipping: {e}");
|
||||
| Ok(file_metadata) => file_metadata,
|
||||
| Err(e) => {
|
||||
error!(
|
||||
"Failed to obtain file metadata for MXC {mxc} at file path \
|
||||
\"{path:?}\", skipping: {e}"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
@ -250,12 +275,12 @@ impl Service {
|
|||
trace!(%mxc, ?path, "File metadata: {file_metadata:?}");
|
||||
|
||||
let file_created_at = match file_metadata.created() {
|
||||
Ok(value) => value,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::Unsupported => {
|
||||
| Ok(value) => value,
|
||||
| Err(err) if err.kind() == std::io::ErrorKind::Unsupported => {
|
||||
debug!("btime is unsupported, using mtime instead");
|
||||
file_metadata.modified()?
|
||||
},
|
||||
Err(err) => {
|
||||
| Err(err) => {
|
||||
error!("Could not delete MXC {mxc} at path {path:?}: {err:?}. Skipping...");
|
||||
continue;
|
||||
},
|
||||
|
@ -264,10 +289,16 @@ impl Service {
|
|||
debug!("File created at: {file_created_at:?}");
|
||||
|
||||
if file_created_at >= time && before {
|
||||
debug!("File is within (before) user duration, pushing to list of file paths and keys to delete.");
|
||||
debug!(
|
||||
"File is within (before) user duration, pushing to list of file paths and \
|
||||
keys to delete."
|
||||
);
|
||||
remote_mxcs.push(mxc.to_string());
|
||||
} else if file_created_at <= time && after {
|
||||
debug!("File is not within (after) user duration, pushing to list of file paths and keys to delete.");
|
||||
debug!(
|
||||
"File is not within (after) user duration, pushing to list of file paths \
|
||||
and keys to delete."
|
||||
);
|
||||
remote_mxcs.push(mxc.to_string());
|
||||
}
|
||||
}
|
||||
|
@ -289,10 +320,10 @@ impl Service {
|
|||
debug_info!("Deleting MXC {mxc} from database and filesystem");
|
||||
|
||||
match self.delete(&mxc).await {
|
||||
Ok(()) => {
|
||||
| Ok(()) => {
|
||||
deletion_count = deletion_count.saturating_add(1);
|
||||
},
|
||||
Err(e) => {
|
||||
| Err(e) => {
|
||||
warn!("Failed to delete {mxc}, ignoring error and skipping: {e}");
|
||||
continue;
|
||||
},
|
||||
|
|
|
@ -53,10 +53,10 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
|
|||
self.create(&mxc, None, None, None, &image).await?;
|
||||
|
||||
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
|
||||
Err(_) => (None, None),
|
||||
Ok(reader) => match reader.into_dimensions() {
|
||||
Err(_) => (None, None),
|
||||
Ok((width, height)) => (Some(width), Some(height)),
|
||||
| Err(_) => (None, None),
|
||||
| Ok(reader) => match reader.into_dimensions() {
|
||||
| Err(_) => (None, None),
|
||||
| Ok((width, height)) => (Some(width), Some(height)),
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -79,8 +79,8 @@ pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
|
|||
let _request_lock = self.url_preview_mutex.lock(url.as_str()).await;
|
||||
|
||||
match self.db.get_url_preview(url.as_str()).await {
|
||||
Ok(preview) => Ok(preview),
|
||||
Err(_) => self.request_url_preview(url).await,
|
||||
| Ok(preview) => Ok(preview),
|
||||
| Err(_) => self.request_url_preview(url).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,9 +111,9 @@ async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
|
|||
return Err!(Request(Unknown("Unknown Content-Type")));
|
||||
};
|
||||
let data = match content_type {
|
||||
html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
|
||||
img if img.starts_with("image/") => self.download_image(url.as_str()).await?,
|
||||
_ => return Err!(Request(Unknown("Unsupported Content-Type"))),
|
||||
| html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
|
||||
| img if img.starts_with("image/") => self.download_image(url.as_str()).await?,
|
||||
| _ => return Err!(Request(Unknown("Unsupported Content-Type"))),
|
||||
};
|
||||
|
||||
self.set_url_preview(url.as_str(), &data).await?;
|
||||
|
@ -131,8 +131,9 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
|
|||
bytes.extend_from_slice(&chunk);
|
||||
if bytes.len() > self.services.globals.url_preview_max_spider_size() {
|
||||
debug!(
|
||||
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
|
||||
response body and assuming our necessary data is in this range.",
|
||||
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not \
|
||||
processing the rest of the response body and assuming our necessary data is in \
|
||||
this range.",
|
||||
url,
|
||||
self.services.globals.url_preview_max_spider_size()
|
||||
);
|
||||
|
@ -145,8 +146,8 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
|
|||
};
|
||||
|
||||
let mut data = match html.opengraph.images.first() {
|
||||
None => UrlPreviewData::default(),
|
||||
Some(obj) => self.download_image(&obj.url).await?,
|
||||
| None => UrlPreviewData::default(),
|
||||
| Some(obj) => self.download_image(&obj.url).await?,
|
||||
};
|
||||
|
||||
let props = html.opengraph.properties;
|
||||
|
@ -169,11 +170,11 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
|||
}
|
||||
|
||||
let host = match url.host_str() {
|
||||
None => {
|
||||
| None => {
|
||||
debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
|
||||
return false;
|
||||
},
|
||||
Some(h) => h.to_owned(),
|
||||
| Some(h) => h.to_owned(),
|
||||
};
|
||||
|
||||
let allowlist_domain_contains = self
|
||||
|
@ -205,7 +206,10 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
|||
}
|
||||
|
||||
if allowlist_domain_explicit.contains(&host) {
|
||||
debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)", &host);
|
||||
debug!(
|
||||
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)",
|
||||
&host
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -213,7 +217,10 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
|||
.iter()
|
||||
.any(|domain_s| domain_s.contains(&host.clone()))
|
||||
{
|
||||
debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)", &host);
|
||||
debug!(
|
||||
"Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)",
|
||||
&host
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -229,11 +236,12 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
|||
if self.services.globals.url_preview_check_root_domain() {
|
||||
debug!("Checking root domain");
|
||||
match host.split_once('.') {
|
||||
None => return false,
|
||||
Some((_, root_domain)) => {
|
||||
| None => return false,
|
||||
| Some((_, root_domain)) => {
|
||||
if denylist_domain_explicit.contains(&root_domain.to_owned()) {
|
||||
debug!(
|
||||
"Root domain {} is not allowed by url_preview_domain_explicit_denylist (check 1/3)",
|
||||
"Root domain {} is not allowed by \
|
||||
url_preview_domain_explicit_denylist (check 1/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
|
@ -241,7 +249,8 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
|||
|
||||
if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 2/3)",
|
||||
"Root domain {} is allowed by url_preview_domain_explicit_allowlist \
|
||||
(check 2/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
|
@ -252,7 +261,8 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
|||
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
|
||||
{
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 3/3)",
|
||||
"Root domain {} is allowed by url_preview_domain_contains_allowlist \
|
||||
(check 3/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use std::{fmt::Debug, time::Duration};
|
||||
|
||||
use conduwuit::{debug_warn, err, implement, utils::content_disposition::make_content_disposition, Err, Error, Result};
|
||||
use conduwuit::{
|
||||
debug_warn, err, implement, utils::content_disposition::make_content_disposition, Err, Error,
|
||||
Result,
|
||||
};
|
||||
use http::header::{HeaderValue, CONTENT_DISPOSITION, CONTENT_TYPE};
|
||||
use ruma::{
|
||||
api::{
|
||||
|
@ -19,7 +22,12 @@ use super::{Dim, FileMeta};
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub async fn fetch_remote_thumbnail(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration, dim: &Dim,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
timeout_ms: Duration,
|
||||
dim: &Dim,
|
||||
) -> Result<FileMeta> {
|
||||
self.check_fetch_authorized(mxc)?;
|
||||
|
||||
|
@ -38,7 +46,11 @@ pub async fn fetch_remote_thumbnail(
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub async fn fetch_remote_content(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
timeout_ms: Duration,
|
||||
) -> Result<FileMeta> {
|
||||
self.check_fetch_authorized(mxc)?;
|
||||
|
||||
|
@ -57,7 +69,12 @@ pub async fn fetch_remote_content(
|
|||
|
||||
#[implement(super::Service)]
|
||||
async fn fetch_thumbnail_authenticated(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration, dim: &Dim,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
timeout_ms: Duration,
|
||||
dim: &Dim,
|
||||
) -> Result<FileMeta> {
|
||||
use federation::authenticated_media::get_content_thumbnail::v1::{Request, Response};
|
||||
|
||||
|
@ -70,20 +87,22 @@ async fn fetch_thumbnail_authenticated(
|
|||
timeout_ms,
|
||||
};
|
||||
|
||||
let Response {
|
||||
content,
|
||||
..
|
||||
} = self.federation_request(mxc, user, server, request).await?;
|
||||
let Response { content, .. } = self.federation_request(mxc, user, server, request).await?;
|
||||
|
||||
match content {
|
||||
FileOrLocation::File(content) => self.handle_thumbnail_file(mxc, user, dim, content).await,
|
||||
FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
|
||||
| FileOrLocation::File(content) =>
|
||||
self.handle_thumbnail_file(mxc, user, dim, content).await,
|
||||
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn fetch_content_authenticated(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
timeout_ms: Duration,
|
||||
) -> Result<FileMeta> {
|
||||
use federation::authenticated_media::get_content::v1::{Request, Response};
|
||||
|
||||
|
@ -92,21 +111,23 @@ async fn fetch_content_authenticated(
|
|||
timeout_ms,
|
||||
};
|
||||
|
||||
let Response {
|
||||
content,
|
||||
..
|
||||
} = self.federation_request(mxc, user, server, request).await?;
|
||||
let Response { content, .. } = self.federation_request(mxc, user, server, request).await?;
|
||||
|
||||
match content {
|
||||
FileOrLocation::File(content) => self.handle_content_file(mxc, user, content).await,
|
||||
FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
|
||||
| FileOrLocation::File(content) => self.handle_content_file(mxc, user, content).await,
|
||||
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
#[implement(super::Service)]
|
||||
async fn fetch_thumbnail_unauthenticated(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration, dim: &Dim,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
timeout_ms: Duration,
|
||||
dim: &Dim,
|
||||
) -> Result<FileMeta> {
|
||||
use media::get_content_thumbnail::v3::{Request, Response};
|
||||
|
||||
|
@ -123,17 +144,10 @@ async fn fetch_thumbnail_unauthenticated(
|
|||
};
|
||||
|
||||
let Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
..
|
||||
file, content_type, content_disposition, ..
|
||||
} = self.federation_request(mxc, user, server, request).await?;
|
||||
|
||||
let content = Content {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
};
|
||||
let content = Content { file, content_type, content_disposition };
|
||||
|
||||
self.handle_thumbnail_file(mxc, user, dim, content).await
|
||||
}
|
||||
|
@ -141,7 +155,11 @@ async fn fetch_thumbnail_unauthenticated(
|
|||
#[allow(deprecated)]
|
||||
#[implement(super::Service)]
|
||||
async fn fetch_content_unauthenticated(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, timeout_ms: Duration,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
timeout_ms: Duration,
|
||||
) -> Result<FileMeta> {
|
||||
use media::get_content::v3::{Request, Response};
|
||||
|
||||
|
@ -154,27 +172,27 @@ async fn fetch_content_unauthenticated(
|
|||
};
|
||||
|
||||
let Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
..
|
||||
file, content_type, content_disposition, ..
|
||||
} = self.federation_request(mxc, user, server, request).await?;
|
||||
|
||||
let content = Content {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
};
|
||||
let content = Content { file, content_type, content_disposition };
|
||||
|
||||
self.handle_content_file(mxc, user, content).await
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn handle_thumbnail_file(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, dim: &Dim, content: Content,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
dim: &Dim,
|
||||
content: Content,
|
||||
) -> Result<FileMeta> {
|
||||
let content_disposition =
|
||||
make_content_disposition(content.content_disposition.as_ref(), content.content_type.as_deref(), None);
|
||||
let content_disposition = make_content_disposition(
|
||||
content.content_disposition.as_ref(),
|
||||
content.content_type.as_deref(),
|
||||
None,
|
||||
);
|
||||
|
||||
self.upload_thumbnail(
|
||||
mxc,
|
||||
|
@ -193,9 +211,17 @@ async fn handle_thumbnail_file(
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn handle_content_file(&self, mxc: &Mxc<'_>, user: Option<&UserId>, content: Content) -> Result<FileMeta> {
|
||||
let content_disposition =
|
||||
make_content_disposition(content.content_disposition.as_ref(), content.content_type.as_deref(), None);
|
||||
async fn handle_content_file(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
content: Content,
|
||||
) -> Result<FileMeta> {
|
||||
let content_disposition = make_content_disposition(
|
||||
content.content_disposition.as_ref(),
|
||||
content.content_type.as_deref(),
|
||||
None,
|
||||
);
|
||||
|
||||
self.create(
|
||||
mxc,
|
||||
|
@ -213,7 +239,12 @@ async fn handle_content_file(&self, mxc: &Mxc<'_>, user: Option<&UserId>, conten
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn handle_location(&self, mxc: &Mxc<'_>, user: Option<&UserId>, location: &str) -> Result<FileMeta> {
|
||||
async fn handle_location(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
location: &str,
|
||||
) -> Result<FileMeta> {
|
||||
self.location_request(location).await.map_err(|error| {
|
||||
err!(Request(NotFound(
|
||||
debug_warn!(%mxc, ?user, ?location, ?error, "Fetching media from location failed")
|
||||
|
@ -263,7 +294,11 @@ async fn location_request(&self, location: &str) -> Result<FileMeta> {
|
|||
|
||||
#[implement(super::Service)]
|
||||
async fn federation_request<Request>(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, request: Request,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
request: Request,
|
||||
) -> Result<Request::IncomingResponse>
|
||||
where
|
||||
Request: OutgoingRequest + Send + Debug,
|
||||
|
@ -277,7 +312,12 @@ where
|
|||
|
||||
// Handles and adjusts the error for the caller to determine if they should
|
||||
// request the fallback endpoint or give up.
|
||||
fn handle_federation_error(mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<&ServerName>, error: Error) -> Error {
|
||||
fn handle_federation_error(
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
server: Option<&ServerName>,
|
||||
error: Error,
|
||||
) -> Error {
|
||||
let fallback = || {
|
||||
err!(Request(NotFound(
|
||||
debug_error!(%mxc, ?user, ?server, ?error, "Remote media not found")
|
||||
|
@ -303,7 +343,8 @@ fn handle_federation_error(mxc: &Mxc<'_>, user: Option<&UserId>, server: Option<
|
|||
#[implement(super::Service)]
|
||||
#[allow(deprecated)]
|
||||
pub async fn fetch_remote_thumbnail_legacy(
|
||||
&self, body: &media::get_content_thumbnail::v3::Request,
|
||||
&self,
|
||||
body: &media::get_content_thumbnail::v3::Request,
|
||||
) -> Result<media::get_content_thumbnail::v3::Response> {
|
||||
let mxc = Mxc {
|
||||
server_name: &body.server_name,
|
||||
|
@ -315,20 +356,17 @@ pub async fn fetch_remote_thumbnail_legacy(
|
|||
let reponse = self
|
||||
.services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
mxc.server_name,
|
||||
media::get_content_thumbnail::v3::Request {
|
||||
allow_remote: body.allow_remote,
|
||||
height: body.height,
|
||||
width: body.width,
|
||||
method: body.method.clone(),
|
||||
server_name: body.server_name.clone(),
|
||||
media_id: body.media_id.clone(),
|
||||
timeout_ms: body.timeout_ms,
|
||||
allow_redirect: body.allow_redirect,
|
||||
animated: body.animated,
|
||||
},
|
||||
)
|
||||
.send_federation_request(mxc.server_name, media::get_content_thumbnail::v3::Request {
|
||||
allow_remote: body.allow_remote,
|
||||
height: body.height,
|
||||
width: body.width,
|
||||
method: body.method.clone(),
|
||||
server_name: body.server_name.clone(),
|
||||
media_id: body.media_id.clone(),
|
||||
timeout_ms: body.timeout_ms,
|
||||
allow_redirect: body.allow_redirect,
|
||||
animated: body.animated,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?;
|
||||
|
@ -341,27 +379,30 @@ pub async fn fetch_remote_thumbnail_legacy(
|
|||
#[implement(super::Service)]
|
||||
#[allow(deprecated)]
|
||||
pub async fn fetch_remote_content_legacy(
|
||||
&self, mxc: &Mxc<'_>, allow_redirect: bool, timeout_ms: Duration,
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
allow_redirect: bool,
|
||||
timeout_ms: Duration,
|
||||
) -> Result<media::get_content::v3::Response, Error> {
|
||||
self.check_legacy_freeze()?;
|
||||
self.check_fetch_authorized(mxc)?;
|
||||
let response = self
|
||||
.services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
mxc.server_name,
|
||||
media::get_content::v3::Request {
|
||||
allow_remote: true,
|
||||
server_name: mxc.server_name.into(),
|
||||
media_id: mxc.media_id.into(),
|
||||
timeout_ms,
|
||||
allow_redirect,
|
||||
},
|
||||
)
|
||||
.send_federation_request(mxc.server_name, media::get_content::v3::Request {
|
||||
allow_remote: true,
|
||||
server_name: mxc.server_name.into(),
|
||||
media_id: mxc.media_id.into(),
|
||||
timeout_ms,
|
||||
allow_redirect,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let content_disposition =
|
||||
make_content_disposition(response.content_disposition.as_ref(), response.content_type.as_deref(), None);
|
||||
let content_disposition = make_content_disposition(
|
||||
response.content_disposition.as_ref(),
|
||||
response.content_type.as_deref(),
|
||||
None,
|
||||
);
|
||||
|
||||
self.create(
|
||||
mxc,
|
||||
|
|
|
@ -13,7 +13,12 @@ async fn long_file_names_works() {
|
|||
|
||||
impl Data for MockedKVDatabase {
|
||||
fn create_file_metadata(
|
||||
&self, _sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
|
||||
&self,
|
||||
_sender_user: Option<&str>,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
// copied from src/database/key_value/media.rs
|
||||
|
@ -46,14 +51,22 @@ async fn long_file_names_works() {
|
|||
fn get_all_media_keys(&self) -> Vec<Vec<u8>> { todo!() }
|
||||
|
||||
fn search_file_metadata(
|
||||
&self, _mxc: String, _width: u32, _height: u32,
|
||||
&self,
|
||||
_mxc: String,
|
||||
_width: u32,
|
||||
_height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() }
|
||||
|
||||
fn set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> {
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
_url: &str,
|
||||
_data: &UrlPreviewData,
|
||||
_timestamp: std::time::Duration,
|
||||
) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
|
@ -64,11 +77,18 @@ async fn long_file_names_works() {
|
|||
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
|
||||
let width = 100;
|
||||
let height = 100;
|
||||
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \
|
||||
characters like äöüß and even emoji like 🦀.png\"";
|
||||
let content_disposition = "attachment; filename=\"this is a very long file name with spaces \
|
||||
and special characters like äöüß and even emoji like 🦀.png\"";
|
||||
let content_type = "image/png";
|
||||
let key = db
|
||||
.create_file_metadata(None, mxc, width, height, Some(content_disposition), Some(content_type))
|
||||
.create_file_metadata(
|
||||
None,
|
||||
mxc,
|
||||
width,
|
||||
height,
|
||||
Some(content_disposition),
|
||||
Some(content_type),
|
||||
)
|
||||
.unwrap();
|
||||
let mut r = PathBuf::from("/tmp/media");
|
||||
// r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
|
||||
|
|
|
@ -22,12 +22,17 @@ impl super::Service {
|
|||
/// Uploads or replaces a file thumbnail.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn upload_thumbnail(
|
||||
&self, mxc: &Mxc<'_>, user: Option<&UserId>, content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>, dim: &Dim, file: &[u8],
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: Option<&UserId>,
|
||||
content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>,
|
||||
dim: &Dim,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let key = self
|
||||
.db
|
||||
.create_file_metadata(mxc, user, dim, content_disposition, content_type)?;
|
||||
let key =
|
||||
self.db
|
||||
.create_file_metadata(mxc, user, dim, content_disposition, content_type)?;
|
||||
|
||||
//TODO: Dangling metadata in database if creation fails
|
||||
let mut f = self.create_media_file(&key).await?;
|
||||
|
@ -78,7 +83,12 @@ impl super::Service {
|
|||
|
||||
/// Generate a thumbnail
|
||||
#[tracing::instrument(skip(self), name = "generate", level = "debug")]
|
||||
async fn get_thumbnail_generate(&self, mxc: &Mxc<'_>, dim: &Dim, data: Metadata) -> Result<Option<FileMeta>> {
|
||||
async fn get_thumbnail_generate(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
dim: &Dim,
|
||||
data: Metadata,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let mut content = Vec::new();
|
||||
let path = self.get_media_file(&data.key);
|
||||
fs::File::open(path)
|
||||
|
@ -117,11 +127,7 @@ impl super::Service {
|
|||
|
||||
fn thumbnail_generate(image: &DynamicImage, requested: &Dim) -> Result<DynamicImage> {
|
||||
let thumbnail = if !requested.crop() {
|
||||
let Dim {
|
||||
width,
|
||||
height,
|
||||
..
|
||||
} = requested.scaled(&Dim {
|
||||
let Dim { width, height, .. } = requested.scaled(&Dim {
|
||||
width: image.width(),
|
||||
height: image.height(),
|
||||
..Dim::default()
|
||||
|
@ -202,12 +208,12 @@ impl Dim {
|
|||
#[must_use]
|
||||
pub fn normalized(&self) -> Self {
|
||||
match (self.width, self.height) {
|
||||
(0..=32, 0..=32) => Self::new(32, 32, Some(Method::Crop)),
|
||||
(0..=96, 0..=96) => Self::new(96, 96, Some(Method::Crop)),
|
||||
(0..=320, 0..=240) => Self::new(320, 240, Some(Method::Scale)),
|
||||
(0..=640, 0..=480) => Self::new(640, 480, Some(Method::Scale)),
|
||||
(0..=800, 0..=600) => Self::new(800, 600, Some(Method::Scale)),
|
||||
_ => Self::default(),
|
||||
| (0..=32, 0..=32) => Self::new(32, 32, Some(Method::Crop)),
|
||||
| (0..=96, 0..=96) => Self::new(96, 96, Some(Method::Crop)),
|
||||
| (0..=320, 0..=240) => Self::new(320, 240, Some(Method::Scale)),
|
||||
| (0..=640, 0..=480) => Self::new(640, 480, Some(Method::Scale)),
|
||||
| (0..=800, 0..=600) => Self::new(800, 600, Some(Method::Scale)),
|
||||
| _ => Self::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,9 @@ use conduwuit::{
|
|||
use futures::{FutureExt, StreamExt};
|
||||
use itertools::Itertools;
|
||||
use ruma::{
|
||||
events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType},
|
||||
events::{
|
||||
push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType,
|
||||
},
|
||||
push::Ruleset,
|
||||
OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
|
@ -45,7 +47,8 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> {
|
|||
if !services.users.exists(server_user).await {
|
||||
error!("The {server_user} server user does not exist, and the database is not new.");
|
||||
return Err!(Database(
|
||||
"Cannot reuse an existing database after changing the server name, please delete the old one first.",
|
||||
"Cannot reuse an existing database after changing the server name, please \
|
||||
delete the old one first.",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -144,7 +147,8 @@ async fn migrate(services: &Services) -> Result<()> {
|
|||
|
||||
assert!(
|
||||
version_match,
|
||||
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
|
||||
"Failed asserting local database version {} is equal to known latest conduwuit database \
|
||||
version {}",
|
||||
services.globals.db.database_version().await,
|
||||
DATABASE_VERSION,
|
||||
);
|
||||
|
@ -192,7 +196,8 @@ async fn migrate(services: &Services) -> Result<()> {
|
|||
let matches = patterns.matches(room_alias.alias());
|
||||
if matches.matched_any() {
|
||||
warn!(
|
||||
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
|
||||
"Room with alias {} ({}) matches the following forbidden room \
|
||||
name patterns: {}",
|
||||
room_alias,
|
||||
&room_id,
|
||||
matches
|
||||
|
@ -223,8 +228,8 @@ async fn db_lt_12(services: &Services) -> Result<()> {
|
|||
.await
|
||||
{
|
||||
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
| Ok(u) => u,
|
||||
| Err(e) => {
|
||||
warn!("Invalid username {username}: {e}");
|
||||
continue;
|
||||
},
|
||||
|
@ -240,7 +245,8 @@ async fn db_lt_12(services: &Services) -> Result<()> {
|
|||
|
||||
//content rule
|
||||
{
|
||||
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
|
||||
let content_rule_transformation =
|
||||
[".m.rules.contains_user_name", ".m.rule.contains_user_name"];
|
||||
|
||||
let rule = rules_list.content.get(content_rule_transformation[0]);
|
||||
if rule.is_some() {
|
||||
|
@ -301,8 +307,8 @@ async fn db_lt_13(services: &Services) -> Result<()> {
|
|||
.await
|
||||
{
|
||||
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
| Ok(u) => u,
|
||||
| Err(e) => {
|
||||
warn!("Invalid username {username}: {e}");
|
||||
continue;
|
||||
},
|
||||
|
@ -413,7 +419,9 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
|
|||
.rooms
|
||||
.state_accessor
|
||||
.get_member(room_id, user_id)
|
||||
.map(|member| member.is_ok_and(|member| member.membership == MembershipState::Join))
|
||||
.map(|member| {
|
||||
member.is_ok_and(|member| member.membership == MembershipState::Join)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
@ -426,7 +434,9 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
|
|||
.rooms
|
||||
.state_accessor
|
||||
.get_member(room_id, user_id)
|
||||
.map(|member| member.is_ok_and(|member| member.membership == MembershipState::Join))
|
||||
.map(|member| {
|
||||
member.is_ok_and(|member| member.membership == MembershipState::Join)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
@ -444,7 +454,8 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
|
|||
|
||||
for room_id in &room_ids {
|
||||
debug_info!(
|
||||
"Updating joined count for room {room_id} to fix servers in room after correcting membership states"
|
||||
"Updating joined count for room {room_id} to fix servers in room after correcting \
|
||||
membership states"
|
||||
);
|
||||
|
||||
services
|
||||
|
|
|
@ -53,18 +53,22 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) async fn set_presence(
|
||||
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
presence_state: &PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let last_presence = self.get_presence(user_id).await;
|
||||
let state_changed = match last_presence {
|
||||
Err(_) => true,
|
||||
Ok(ref presence) => presence.1.content.presence != *presence_state,
|
||||
| Err(_) => true,
|
||||
| Ok(ref presence) => presence.1.content.presence != *presence_state,
|
||||
};
|
||||
|
||||
let status_msg_changed = match last_presence {
|
||||
Err(_) => true,
|
||||
Ok(ref last_presence) => {
|
||||
| Err(_) => true,
|
||||
| Ok(ref last_presence) => {
|
||||
let old_msg = last_presence
|
||||
.1
|
||||
.content
|
||||
|
@ -80,18 +84,22 @@ impl Data {
|
|||
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_last_active_ts = match last_presence {
|
||||
Err(_) => 0,
|
||||
Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
|
||||
| Err(_) => 0,
|
||||
| Ok((_, ref presence)) =>
|
||||
now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
|
||||
};
|
||||
|
||||
let last_active_ts = match last_active_ago {
|
||||
None => now,
|
||||
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
|
||||
| None => now,
|
||||
| Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
|
||||
};
|
||||
|
||||
// TODO: tighten for state flicker?
|
||||
if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts {
|
||||
debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",);
|
||||
debug_warn!(
|
||||
"presence spam {user_id:?} last_active_ts:{last_active_ts:?} < \
|
||||
{last_last_active_ts:?}",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -138,7 +146,10 @@ impl Data {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn presence_since(&self, since: u64) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
|
||||
pub(super) fn presence_since(
|
||||
&self,
|
||||
since: u64,
|
||||
) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
|
||||
self.presenceid_presence
|
||||
.raw_stream()
|
||||
.ignore_err()
|
||||
|
|
|
@ -99,13 +99,14 @@ impl Service {
|
|||
|
||||
let last_presence = self.db.get_presence(user_id).await;
|
||||
let state_changed = match last_presence {
|
||||
Err(_) => true,
|
||||
Ok((_, ref presence)) => presence.content.presence != *new_state,
|
||||
| Err(_) => true,
|
||||
| Ok((_, ref presence)) => presence.content.presence != *new_state,
|
||||
};
|
||||
|
||||
let last_last_active_ago = match last_presence {
|
||||
Err(_) => 0_u64,
|
||||
Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(),
|
||||
| Err(_) => 0_u64,
|
||||
| Ok((_, ref presence)) =>
|
||||
presence.content.last_active_ago.unwrap_or_default().into(),
|
||||
};
|
||||
|
||||
if !state_changed && last_last_active_ago < REFRESH_TIMEOUT {
|
||||
|
@ -113,8 +114,8 @@ impl Service {
|
|||
}
|
||||
|
||||
let status_msg = match last_presence {
|
||||
Ok((_, ref presence)) => presence.content.status_msg.clone(),
|
||||
Err(_) => Some(String::new()),
|
||||
| Ok((_, ref presence)) => presence.content.status_msg.clone(),
|
||||
| Err(_) => Some(String::new()),
|
||||
};
|
||||
|
||||
let last_active_ago = UInt::new(0);
|
||||
|
@ -125,12 +126,16 @@ impl Service {
|
|||
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
pub async fn set_presence(
|
||||
&self, user_id: &UserId, state: &PresenceState, currently_active: Option<bool>, last_active_ago: Option<UInt>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
state: &PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let presence_state = match state.as_str() {
|
||||
"" => &PresenceState::Offline, // default an empty string to 'offline'
|
||||
&_ => state,
|
||||
| "" => &PresenceState::Offline, // default an empty string to 'offline'
|
||||
| &_ => state,
|
||||
};
|
||||
|
||||
self.db
|
||||
|
@ -141,8 +146,8 @@ impl Service {
|
|||
&& user_id != self.services.globals.server_user
|
||||
{
|
||||
let timeout = match presence_state {
|
||||
PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
|
||||
_ => self.services.server.config.presence_offline_timeout_s,
|
||||
| PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
|
||||
| _ => self.services.server.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.timer_sender
|
||||
|
@ -160,16 +165,25 @@ impl Service {
|
|||
///
|
||||
/// TODO: Why is this not used?
|
||||
#[allow(dead_code)]
|
||||
pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await }
|
||||
pub async fn remove_presence(&self, user_id: &UserId) {
|
||||
self.db.remove_presence(user_id).await;
|
||||
}
|
||||
|
||||
/// Returns the most recent presence updates that happened after the event
|
||||
/// with id `since`.
|
||||
pub fn presence_since(&self, since: u64) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
|
||||
pub fn presence_since(
|
||||
&self,
|
||||
since: u64,
|
||||
) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ {
|
||||
self.db.presence_since(since)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
|
||||
pub async fn from_json_bytes_to_event(
|
||||
&self,
|
||||
bytes: &[u8],
|
||||
user_id: &UserId,
|
||||
) -> Result<PresenceEvent> {
|
||||
let presence = Presence::from_json_bytes(bytes)?;
|
||||
let event = presence
|
||||
.to_presence_event(user_id, &self.services.users)
|
||||
|
@ -192,13 +206,16 @@ impl Service {
|
|||
}
|
||||
|
||||
let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
|
||||
(PresenceState::Online, Some(ago)) if ago >= self.idle_timeout => Some(PresenceState::Unavailable),
|
||||
(PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout => Some(PresenceState::Offline),
|
||||
_ => None,
|
||||
| (PresenceState::Online, Some(ago)) if ago >= self.idle_timeout =>
|
||||
Some(PresenceState::Unavailable),
|
||||
| (PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout =>
|
||||
Some(PresenceState::Offline),
|
||||
| _ => None,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"
|
||||
"Processed presence timer for user '{user_id}': Old state = {presence_state}, New \
|
||||
state = {new_state:?}"
|
||||
);
|
||||
|
||||
if let Some(new_state) = new_state {
|
||||
|
|
|
@ -21,7 +21,10 @@ pub(super) struct Presence {
|
|||
impl Presence {
|
||||
#[must_use]
|
||||
pub(super) fn new(
|
||||
state: PresenceState, currently_active: bool, last_active_ts: u64, status_msg: Option<String>,
|
||||
state: PresenceState,
|
||||
currently_active: bool,
|
||||
last_active_ts: u64,
|
||||
status_msg: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state,
|
||||
|
@ -32,11 +35,16 @@ impl Presence {
|
|||
}
|
||||
|
||||
pub(super) fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
|
||||
serde_json::from_slice(bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid presence data in database"))
|
||||
}
|
||||
|
||||
/// Creates a PresenceEvent from available data.
|
||||
pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent {
|
||||
pub(super) async fn to_presence_event(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
users: &users::Service,
|
||||
) -> PresenceEvent {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ago = if self.currently_active {
|
||||
None
|
||||
|
|
|
@ -19,9 +19,12 @@ use ruma::{
|
|||
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
events::{
|
||||
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType,
|
||||
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType,
|
||||
TimelineEventType,
|
||||
},
|
||||
push::{
|
||||
Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak,
|
||||
},
|
||||
push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
serde::Raw,
|
||||
uint, RoomId, UInt, UserId,
|
||||
};
|
||||
|
@ -55,7 +58,8 @@ impl crate::Service for Service {
|
|||
services: Services {
|
||||
globals: args.depend::<globals::Service>("globals"),
|
||||
client: args.depend::<client::Service>("client"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
||||
users: args.depend::<users::Service>("users"),
|
||||
sending: args.depend::<sending::Service>("sending"),
|
||||
|
@ -67,23 +71,31 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
impl Service {
|
||||
pub async fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result {
|
||||
pub async fn set_pusher(
|
||||
&self,
|
||||
sender: &UserId,
|
||||
pusher: &set_pusher::v3::PusherAction,
|
||||
) -> Result {
|
||||
match pusher {
|
||||
set_pusher::v3::PusherAction::Post(data) => {
|
||||
| set_pusher::v3::PusherAction::Post(data) => {
|
||||
let pushkey = data.pusher.ids.pushkey.as_str();
|
||||
|
||||
if pushkey.len() > 512 {
|
||||
return Err!(Request(InvalidParam("Push key length cannot be greater than 512 bytes.")));
|
||||
return Err!(Request(InvalidParam(
|
||||
"Push key length cannot be greater than 512 bytes."
|
||||
)));
|
||||
}
|
||||
|
||||
if data.pusher.ids.app_id.as_str().len() > 64 {
|
||||
return Err!(Request(InvalidParam("App ID length cannot be greater than 64 bytes.")));
|
||||
return Err!(Request(InvalidParam(
|
||||
"App ID length cannot be greater than 64 bytes."
|
||||
)));
|
||||
}
|
||||
|
||||
let key = (sender, data.pusher.ids.pushkey.as_str());
|
||||
self.db.senderkey_pusher.put(key, Json(pusher));
|
||||
},
|
||||
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
| set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
let key = (sender, ids.pushkey.as_str());
|
||||
self.db.senderkey_pusher.del(key);
|
||||
|
||||
|
@ -118,7 +130,10 @@ impl Service {
|
|||
.await
|
||||
}
|
||||
|
||||
pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream<Item = &str> + Send + 'a {
|
||||
pub fn get_pushkeys<'a>(
|
||||
&'a self,
|
||||
sender: &'a UserId,
|
||||
) -> impl Stream<Item = &str> + Send + 'a {
|
||||
let prefix = (sender, Interfix);
|
||||
self.db
|
||||
.senderkey_pusher
|
||||
|
@ -160,14 +175,16 @@ impl Service {
|
|||
let response = self.services.client.pusher.execute(reqwest_request).await;
|
||||
|
||||
match response {
|
||||
Ok(mut response) => {
|
||||
| Ok(mut response) => {
|
||||
// reqwest::Response -> http::Response conversion
|
||||
|
||||
trace!("Checking response destination's IP");
|
||||
if let Some(remote_addr) = response.remote_addr() {
|
||||
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
|
||||
if !self.services.client.valid_cidr_range(&ip) {
|
||||
return Err!(BadServerResponse("Not allowed to send requests to this IP"));
|
||||
return Err!(BadServerResponse(
|
||||
"Not allowed to send requests to this IP"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -197,10 +214,13 @@ impl Service {
|
|||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
response
|
||||
.map_err(|e| err!(BadServerResponse(warn!("Push gateway {dest} returned invalid response: {e}"))))
|
||||
response.map_err(|e| {
|
||||
err!(BadServerResponse(warn!(
|
||||
"Push gateway {dest} returned invalid response: {e}"
|
||||
)))
|
||||
})
|
||||
},
|
||||
Err(e) => {
|
||||
| Err(e) => {
|
||||
warn!("Could not send request to pusher {dest}: {e}");
|
||||
Err(e.into())
|
||||
},
|
||||
|
@ -209,7 +229,12 @@ impl Service {
|
|||
|
||||
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
|
||||
pub async fn send_push_notice(
|
||||
&self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
|
||||
&self,
|
||||
user: &UserId,
|
||||
unread: UInt,
|
||||
pusher: &Pusher,
|
||||
ruleset: Ruleset,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
let mut notify = None;
|
||||
let mut tweaks = Vec::new();
|
||||
|
@ -220,8 +245,9 @@ impl Service {
|
|||
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")
|
||||
.await
|
||||
.and_then(|ev| {
|
||||
serde_json::from_str(ev.content.get())
|
||||
.map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}"))))
|
||||
serde_json::from_str(ev.content.get()).map_err(|e| {
|
||||
err!(Database(error!("invalid m.room.power_levels event: {e:?}")))
|
||||
})
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
|
@ -230,12 +256,12 @@ impl Service {
|
|||
.await
|
||||
{
|
||||
let n = match action {
|
||||
Action::Notify => true,
|
||||
Action::SetTweak(tweak) => {
|
||||
| Action::Notify => true,
|
||||
| Action::SetTweak(tweak) => {
|
||||
tweaks.push(tweak.clone());
|
||||
continue;
|
||||
},
|
||||
_ => false,
|
||||
| _ => false,
|
||||
};
|
||||
|
||||
if notify.is_some() {
|
||||
|
@ -257,8 +283,12 @@ impl Service {
|
|||
|
||||
#[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")]
|
||||
pub async fn get_actions<'a>(
|
||||
&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
|
||||
&self,
|
||||
user: &UserId,
|
||||
ruleset: &'a Ruleset,
|
||||
power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncTimelineEvent>,
|
||||
room_id: &RoomId,
|
||||
) -> &'a [Action] {
|
||||
let power_levels = PushConditionPowerLevelsCtx {
|
||||
users: power_levels.users.clone(),
|
||||
|
@ -294,14 +324,21 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
|
||||
async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, event: &PduEvent) -> Result<()> {
|
||||
async fn send_notice(
|
||||
&self,
|
||||
unread: UInt,
|
||||
pusher: &Pusher,
|
||||
tweaks: Vec<Tweak>,
|
||||
event: &PduEvent,
|
||||
) -> Result<()> {
|
||||
// TODO: email
|
||||
match &pusher.kind {
|
||||
PusherKind::Http(http) => {
|
||||
| PusherKind::Http(http) => {
|
||||
// TODO (timo): can pusher/devices have conflicting formats
|
||||
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
|
||||
|
||||
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
|
||||
let mut device =
|
||||
Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
|
||||
device.data.default_payload = http.default_payload.clone();
|
||||
device.data.format.clone_from(&http.format);
|
||||
|
||||
|
@ -319,8 +356,11 @@ impl Service {
|
|||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||
|
||||
if event_id_only {
|
||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
|
||||
.await?;
|
||||
self.send_request(
|
||||
&http.url,
|
||||
send_event_notification::v1::Request::new(notifi),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
if event.kind == TimelineEventType::RoomEncrypted
|
||||
|| tweaks
|
||||
|
@ -336,10 +376,12 @@ impl Service {
|
|||
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
|
||||
|
||||
if event.kind == TimelineEventType::RoomMember {
|
||||
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
notifi.user_is_target =
|
||||
event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
}
|
||||
|
||||
notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok();
|
||||
notifi.sender_display_name =
|
||||
self.services.users.displayname(&event.sender).await.ok();
|
||||
|
||||
notifi.room_name = self
|
||||
.services
|
||||
|
@ -355,15 +397,18 @@ impl Service {
|
|||
.await
|
||||
.ok();
|
||||
|
||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
|
||||
.await?;
|
||||
self.send_request(
|
||||
&http.url,
|
||||
send_event_notification::v1::Request::new(notifi),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
// TODO: Handle email
|
||||
//PusherKind::Email(_) => Ok(()),
|
||||
_ => Ok(()),
|
||||
| _ => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,17 +35,9 @@ impl super::Service {
|
|||
(self.resolve_actual_dest(server_name, true).await?, false)
|
||||
};
|
||||
|
||||
let CachedDest {
|
||||
dest,
|
||||
host,
|
||||
..
|
||||
} = result;
|
||||
let CachedDest { dest, host, .. } = result;
|
||||
|
||||
Ok(ActualDest {
|
||||
dest,
|
||||
host,
|
||||
cached,
|
||||
})
|
||||
Ok(ActualDest { dest, host, cached })
|
||||
}
|
||||
|
||||
/// Returns: `actual_destination`, host header
|
||||
|
@ -53,12 +45,16 @@ impl super::Service {
|
|||
/// Numbers in comments below refer to bullet points in linked section of
|
||||
/// specification
|
||||
#[tracing::instrument(skip_all, name = "actual")]
|
||||
pub async fn resolve_actual_dest(&self, dest: &ServerName, cache: bool) -> Result<CachedDest> {
|
||||
pub async fn resolve_actual_dest(
|
||||
&self,
|
||||
dest: &ServerName,
|
||||
cache: bool,
|
||||
) -> Result<CachedDest> {
|
||||
trace!("Finding actual destination for {dest}");
|
||||
let mut host = dest.as_str().to_owned();
|
||||
let actual_dest = match get_ip_with_port(dest.as_str()) {
|
||||
Some(host_port) => Self::actual_dest_1(host_port)?,
|
||||
None => {
|
||||
| Some(host_port) => Self::actual_dest_1(host_port)?,
|
||||
| None =>
|
||||
if let Some(pos) = dest.as_str().find(':') {
|
||||
self.actual_dest_2(dest, cache, pos).await?
|
||||
} else if let Some(delegated) = self.request_well_known(dest.as_str()).await? {
|
||||
|
@ -67,8 +63,7 @@ impl super::Service {
|
|||
self.actual_dest_4(&host, cache, overrider).await?
|
||||
} else {
|
||||
self.actual_dest_5(dest, cache).await?
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Can't use get_ip_with_port here because we don't want to add a port
|
||||
|
@ -79,7 +74,10 @@ impl super::Service {
|
|||
FedDest::Named(addr.to_string(), FedDest::default_port())
|
||||
} else if let Some(pos) = host.find(':') {
|
||||
let (host, port) = host.split_at(pos);
|
||||
FedDest::Named(host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port()))
|
||||
FedDest::Named(
|
||||
host.to_owned(),
|
||||
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
|
||||
)
|
||||
} else {
|
||||
FedDest::Named(host, FedDest::default_port())
|
||||
};
|
||||
|
@ -100,20 +98,30 @@ impl super::Service {
|
|||
async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> {
|
||||
debug!("2: Hostname with included port");
|
||||
let (host, port) = dest.as_str().split_at(pos);
|
||||
self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache)
|
||||
.await?;
|
||||
self.conditional_query_and_cache_override(
|
||||
host,
|
||||
host,
|
||||
port.parse::<u16>().unwrap_or(8448),
|
||||
cache,
|
||||
)
|
||||
.await?;
|
||||
Ok(FedDest::Named(
|
||||
host.to_owned(),
|
||||
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
|
||||
))
|
||||
}
|
||||
|
||||
async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result<FedDest> {
|
||||
async fn actual_dest_3(
|
||||
&self,
|
||||
host: &mut String,
|
||||
cache: bool,
|
||||
delegated: String,
|
||||
) -> Result<FedDest> {
|
||||
debug!("3: A .well-known file is available");
|
||||
*host = add_port_to_hostname(&delegated).uri_string();
|
||||
match get_ip_with_port(&delegated) {
|
||||
Some(host_and_port) => Self::actual_dest_3_1(host_and_port),
|
||||
None => {
|
||||
| Some(host_and_port) => Self::actual_dest_3_1(host_and_port),
|
||||
| None =>
|
||||
if let Some(pos) = delegated.find(':') {
|
||||
self.actual_dest_3_2(cache, delegated, pos).await
|
||||
} else {
|
||||
|
@ -123,8 +131,7 @@ impl super::Service {
|
|||
} else {
|
||||
self.actual_dest_3_4(cache, delegated).await
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -133,22 +140,42 @@ impl super::Service {
|
|||
Ok(host_and_port)
|
||||
}
|
||||
|
||||
async fn actual_dest_3_2(&self, cache: bool, delegated: String, pos: usize) -> Result<FedDest> {
|
||||
async fn actual_dest_3_2(
|
||||
&self,
|
||||
cache: bool,
|
||||
delegated: String,
|
||||
pos: usize,
|
||||
) -> Result<FedDest> {
|
||||
debug!("3.2: Hostname with port in .well-known file");
|
||||
let (host, port) = delegated.split_at(pos);
|
||||
self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache)
|
||||
.await?;
|
||||
self.conditional_query_and_cache_override(
|
||||
host,
|
||||
host,
|
||||
port.parse::<u16>().unwrap_or(8448),
|
||||
cache,
|
||||
)
|
||||
.await?;
|
||||
Ok(FedDest::Named(
|
||||
host.to_owned(),
|
||||
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
|
||||
))
|
||||
}
|
||||
|
||||
async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result<FedDest> {
|
||||
async fn actual_dest_3_3(
|
||||
&self,
|
||||
cache: bool,
|
||||
delegated: String,
|
||||
overrider: FedDest,
|
||||
) -> Result<FedDest> {
|
||||
debug!("3.3: SRV lookup successful");
|
||||
let force_port = overrider.port();
|
||||
self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache)
|
||||
.await?;
|
||||
self.conditional_query_and_cache_override(
|
||||
&delegated,
|
||||
&overrider.hostname(),
|
||||
force_port.unwrap_or(8448),
|
||||
cache,
|
||||
)
|
||||
.await?;
|
||||
if let Some(port) = force_port {
|
||||
Ok(FedDest::Named(
|
||||
delegated,
|
||||
|
@ -169,11 +196,21 @@ impl super::Service {
|
|||
Ok(add_port_to_hostname(&delegated))
|
||||
}
|
||||
|
||||
async fn actual_dest_4(&self, host: &str, cache: bool, overrider: FedDest) -> Result<FedDest> {
|
||||
async fn actual_dest_4(
|
||||
&self,
|
||||
host: &str,
|
||||
cache: bool,
|
||||
overrider: FedDest,
|
||||
) -> Result<FedDest> {
|
||||
debug!("4: No .well-known; SRV record found");
|
||||
let force_port = overrider.port();
|
||||
self.conditional_query_and_cache_override(host, &overrider.hostname(), force_port.unwrap_or(8448), cache)
|
||||
.await?;
|
||||
self.conditional_query_and_cache_override(
|
||||
host,
|
||||
&overrider.hostname(),
|
||||
force_port.unwrap_or(8448),
|
||||
cache,
|
||||
)
|
||||
.await?;
|
||||
if let Some(port) = force_port {
|
||||
let port = format!(":{port}");
|
||||
Ok(FedDest::Named(
|
||||
|
@ -245,7 +282,11 @@ impl super::Service {
|
|||
|
||||
#[inline]
|
||||
async fn conditional_query_and_cache_override(
|
||||
&self, overname: &str, hostname: &str, port: u16, cache: bool,
|
||||
&self,
|
||||
overname: &str,
|
||||
hostname: &str,
|
||||
port: u16,
|
||||
cache: bool,
|
||||
) -> Result<()> {
|
||||
if cache {
|
||||
self.query_and_cache_override(overname, hostname, port)
|
||||
|
@ -256,22 +297,24 @@ impl super::Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "ip")]
|
||||
async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
|
||||
async fn query_and_cache_override(
|
||||
&self,
|
||||
overname: &'_ str,
|
||||
hostname: &'_ str,
|
||||
port: u16,
|
||||
) -> Result<()> {
|
||||
match self.resolver.resolver.lookup_ip(hostname.to_owned()).await {
|
||||
Err(e) => Self::handle_resolve_error(&e, hostname),
|
||||
Ok(override_ip) => {
|
||||
| Err(e) => Self::handle_resolve_error(&e, hostname),
|
||||
| Ok(override_ip) => {
|
||||
if hostname != overname {
|
||||
debug_info!("{overname:?} overriden by {hostname:?}");
|
||||
}
|
||||
|
||||
self.set_cached_override(
|
||||
overname,
|
||||
CachedOverride {
|
||||
ips: override_ip.into_iter().take(MAX_IPS).collect(),
|
||||
port,
|
||||
expire: CachedOverride::default_expire(),
|
||||
},
|
||||
);
|
||||
self.set_cached_override(overname, CachedOverride {
|
||||
ips: override_ip.into_iter().take(MAX_IPS).collect(),
|
||||
port,
|
||||
expire: CachedOverride::default_expire(),
|
||||
});
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|
@ -280,14 +323,15 @@ impl super::Service {
|
|||
|
||||
#[tracing::instrument(skip_all, name = "srv")]
|
||||
async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> {
|
||||
let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
|
||||
let hostnames =
|
||||
[format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
|
||||
|
||||
for hostname in hostnames {
|
||||
debug!("querying SRV for {hostname:?}");
|
||||
let hostname = hostname.trim_end_matches('.');
|
||||
match self.resolver.resolver.srv_lookup(hostname).await {
|
||||
Err(e) => Self::handle_resolve_error(&e, hostname)?,
|
||||
Ok(result) => {
|
||||
| Err(e) => Self::handle_resolve_error(&e, hostname)?,
|
||||
| Ok(result) =>
|
||||
return Ok(result.iter().next().map(|result| {
|
||||
FedDest::Named(
|
||||
result.target().to_string().trim_end_matches('.').to_owned(),
|
||||
|
@ -296,8 +340,7 @@ impl super::Service {
|
|||
.try_into()
|
||||
.unwrap_or_else(|_| FedDest::default_port()),
|
||||
)
|
||||
}))
|
||||
},
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -308,25 +351,24 @@ impl super::Service {
|
|||
use hickory_resolver::error::ResolveErrorKind;
|
||||
|
||||
match *e.kind() {
|
||||
ResolveErrorKind::NoRecordsFound {
|
||||
..
|
||||
} => {
|
||||
| ResolveErrorKind::NoRecordsFound { .. } => {
|
||||
// Raise to debug_warn if we can find out the result wasn't from cache
|
||||
debug!(%host, "No DNS records found: {e}");
|
||||
Ok(())
|
||||
},
|
||||
ResolveErrorKind::Timeout => {
|
||||
| ResolveErrorKind::Timeout => {
|
||||
Err!(warn!(%host, "DNS {e}"))
|
||||
},
|
||||
ResolveErrorKind::NoConnections => {
|
||||
| ResolveErrorKind::NoConnections => {
|
||||
error!(
|
||||
"Your DNS server is overloaded and has ran out of connections. It is strongly recommended you \
|
||||
remediate this issue to ensure proper federation connectivity."
|
||||
"Your DNS server is overloaded and has ran out of connections. It is \
|
||||
strongly recommended you remediate this issue to ensure proper federation \
|
||||
connectivity."
|
||||
);
|
||||
|
||||
Err!(error!(%host, "DNS error: {e}"))
|
||||
},
|
||||
_ => Err!(error!(%host, "DNS error: {e}")),
|
||||
| _ => Err!(error!(%host, "DNS error: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -349,8 +391,9 @@ impl super::Service {
|
|||
dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
|
||||
"Destination is not an IP literal."
|
||||
);
|
||||
let ip = IPAddress::parse(dest.host())
|
||||
.map_err(|e| err!(BadServerResponse(debug_error!("Failed to parse IP literal from string: {e}"))))?;
|
||||
let ip = IPAddress::parse(dest.host()).map_err(|e| {
|
||||
err!(BadServerResponse(debug_error!("Failed to parse IP literal from string: {e}")))
|
||||
})?;
|
||||
|
||||
self.validate_ip(&ip)?;
|
||||
|
||||
|
|
|
@ -46,7 +46,11 @@ impl Cache {
|
|||
}
|
||||
|
||||
impl super::Service {
|
||||
pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
|
||||
pub fn set_cached_destination(
|
||||
&self,
|
||||
name: OwnedServerName,
|
||||
dest: CachedDest,
|
||||
) -> Option<CachedDest> {
|
||||
trace!(?name, ?dest, "set cached destination");
|
||||
self.cache
|
||||
.destinations
|
||||
|
@ -65,7 +69,11 @@ impl super::Service {
|
|||
.cloned()
|
||||
}
|
||||
|
||||
pub fn set_cached_override(&self, name: &str, over: CachedOverride) -> Option<CachedOverride> {
|
||||
pub fn set_cached_override(
|
||||
&self,
|
||||
name: &str,
|
||||
over: CachedOverride,
|
||||
) -> Option<CachedOverride> {
|
||||
trace!(?name, ?over, "set cached override");
|
||||
self.cache
|
||||
.overrides
|
||||
|
@ -102,7 +110,9 @@ impl CachedDest {
|
|||
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) }
|
||||
pub(crate) fn default_expire() -> SystemTime {
|
||||
rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36)
|
||||
}
|
||||
}
|
||||
|
||||
impl CachedOverride {
|
||||
|
@ -113,5 +123,7 @@ impl CachedOverride {
|
|||
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) }
|
||||
pub(crate) fn default_expire() -> SystemTime {
|
||||
rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,27 +60,26 @@ impl Resolver {
|
|||
opts.shuffle_dns_servers = true;
|
||||
opts.rotate = true;
|
||||
opts.ip_strategy = match config.ip_lookup_strategy {
|
||||
1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
|
||||
2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
|
||||
3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
|
||||
4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
|
||||
_ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
|
||||
| 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
|
||||
| 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
|
||||
| 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
|
||||
| 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
|
||||
| _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
|
||||
};
|
||||
opts.authentic_data = false;
|
||||
|
||||
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
|
||||
Ok(Arc::new(Self {
|
||||
resolver: resolver.clone(),
|
||||
hooked: Arc::new(Hooked {
|
||||
resolver,
|
||||
cache,
|
||||
}),
|
||||
hooked: Arc::new(Hooked { resolver, cache }),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for Resolver {
|
||||
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name).boxed() }
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
resolve_to_reqwest(self.resolver.clone(), name).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for Hooked {
|
||||
|
|
|
@ -29,8 +29,8 @@ pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> {
|
|||
|
||||
pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest {
|
||||
let (host, port) = match dest.find(':') {
|
||||
None => (dest, DEFAULT_PORT),
|
||||
Some(pos) => dest.split_at(pos),
|
||||
| None => (dest, DEFAULT_PORT),
|
||||
| Some(pos) => dest.split_at(pos),
|
||||
};
|
||||
|
||||
FedDest::Named(
|
||||
|
@ -42,23 +42,23 @@ pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest {
|
|||
impl FedDest {
|
||||
pub(crate) fn https_string(&self) -> String {
|
||||
match self {
|
||||
Self::Literal(addr) => format!("https://{addr}"),
|
||||
Self::Named(host, port) => format!("https://{host}{port}"),
|
||||
| Self::Literal(addr) => format!("https://{addr}"),
|
||||
| Self::Named(host, port) => format!("https://{host}{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn uri_string(&self) -> String {
|
||||
match self {
|
||||
Self::Literal(addr) => addr.to_string(),
|
||||
Self::Named(host, port) => format!("{host}{port}"),
|
||||
| Self::Literal(addr) => addr.to_string(),
|
||||
| Self::Named(host, port) => format!("{host}{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn hostname(&self) -> Cow<'_, str> {
|
||||
match &self {
|
||||
Self::Literal(addr) => addr.ip().to_string().into(),
|
||||
Self::Named(host, _) => host.into(),
|
||||
| Self::Literal(addr) => addr.ip().to_string().into(),
|
||||
| Self::Named(host, _) => host.into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,16 +66,20 @@ impl FedDest {
|
|||
#[allow(clippy::string_slice)]
|
||||
pub(crate) fn port(&self) -> Option<u16> {
|
||||
match &self {
|
||||
Self::Literal(addr) => Some(addr.port()),
|
||||
Self::Named(_, port) => port[1..].parse().ok(),
|
||||
| Self::Literal(addr) => Some(addr.port()),
|
||||
| Self::Named(_, port) => port[1..].parse().ok(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn default_port() -> PortString { PortString::from(DEFAULT_PORT).expect("default port string") }
|
||||
pub fn default_port() -> PortString {
|
||||
PortString::from(DEFAULT_PORT).expect("default port string")
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FedDest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.uri_string().as_str()) }
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.uri_string().as_str())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,7 +52,8 @@ impl crate::Service for Service {
|
|||
appservice: args.depend::<appservice::Service>("appservice"),
|
||||
globals: args.depend::<globals::Service>("globals"),
|
||||
sending: args.depend::<sending::Service>("sending"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
@ -62,8 +63,15 @@ impl crate::Service for Service {
|
|||
|
||||
impl Service {
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
|
||||
if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
|
||||
pub fn set_alias(
|
||||
&self,
|
||||
alias: &RoomAliasId,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<()> {
|
||||
if alias == self.services.globals.admin_alias
|
||||
&& user_id != self.services.globals.server_user
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
"Only the server user can set this alias",
|
||||
|
@ -120,7 +128,9 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn resolve_with_servers(
|
||||
&self, room: &RoomOrAliasId, servers: Option<Vec<OwnedServerName>>,
|
||||
&self,
|
||||
room: &RoomOrAliasId,
|
||||
servers: Option<Vec<OwnedServerName>>,
|
||||
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
|
||||
if room.is_room_id() {
|
||||
let room_id = RoomId::parse(room).expect("valid RoomId");
|
||||
|
@ -133,14 +143,16 @@ impl Service {
|
|||
|
||||
#[tracing::instrument(skip(self), name = "resolve")]
|
||||
pub async fn resolve_alias(
|
||||
&self, room_alias: &RoomAliasId, servers: Option<Vec<OwnedServerName>>,
|
||||
&self,
|
||||
room_alias: &RoomAliasId,
|
||||
servers: Option<Vec<OwnedServerName>>,
|
||||
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
|
||||
let server_name = room_alias.server_name();
|
||||
let server_is_ours = self.services.globals.server_is_ours(server_name);
|
||||
let servers_contains_ours = || {
|
||||
servers
|
||||
.as_ref()
|
||||
.is_some_and(|servers| servers.contains(&self.services.globals.config.server_name))
|
||||
servers.as_ref().is_some_and(|servers| {
|
||||
servers.contains(&self.services.globals.config.server_name)
|
||||
})
|
||||
};
|
||||
|
||||
if !server_is_ours && !servers_contains_ours() {
|
||||
|
@ -150,8 +162,8 @@ impl Service {
|
|||
}
|
||||
|
||||
let room_id = match self.resolve_local_alias(room_alias).await {
|
||||
Ok(r) => Some(r),
|
||||
Err(_) => self.resolve_appservice_alias(room_alias).await?,
|
||||
| Ok(r) => Some(r),
|
||||
| Err(_) => self.resolve_appservice_alias(room_alias).await?,
|
||||
};
|
||||
|
||||
room_id.map_or_else(
|
||||
|
@ -166,7 +178,10 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
|
||||
pub fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
|
||||
let prefix = (room_id, Interfix);
|
||||
self.db
|
||||
.aliasid_alias
|
||||
|
@ -208,10 +223,15 @@ impl Service {
|
|||
if let Ok(content) = self
|
||||
.services
|
||||
.state_accessor
|
||||
.room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "")
|
||||
.room_state_get_content::<RoomPowerLevelsEventContent>(
|
||||
&room_id,
|
||||
&StateEventType::RoomPowerLevels,
|
||||
"",
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Ok(RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias));
|
||||
return Ok(RoomPowerLevels::from(content)
|
||||
.user_can_send_state(user_id, StateEventType::RoomCanonicalAlias));
|
||||
}
|
||||
|
||||
// If there is no power levels event, only the room creator can change
|
||||
|
@ -232,7 +252,10 @@ impl Service {
|
|||
self.db.alias_userid.get(alias.alias()).await.deserialized()
|
||||
}
|
||||
|
||||
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
async fn resolve_appservice_alias(
|
||||
&self,
|
||||
room_alias: &RoomAliasId,
|
||||
) -> Result<Option<OwnedRoomId>> {
|
||||
use ruma::api::appservice::query::query_room_alias;
|
||||
|
||||
for appservice in self.services.appservice.read().await.values() {
|
||||
|
@ -242,9 +265,7 @@ impl Service {
|
|||
.sending
|
||||
.send_appservice_request(
|
||||
appservice.registration.clone(),
|
||||
query_room_alias::v1::Request {
|
||||
room_alias: room_alias.to_owned(),
|
||||
},
|
||||
query_room_alias::v1::Request { room_alias: room_alias.to_owned() },
|
||||
)
|
||||
.await,
|
||||
Ok(Some(_opt_result))
|
||||
|
@ -261,19 +282,27 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn appservice_checks(
|
||||
&self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>,
|
||||
&self,
|
||||
room_alias: &RoomAliasId,
|
||||
appservice_info: &Option<RegistrationInfo>,
|
||||
) -> Result<()> {
|
||||
if !self
|
||||
.services
|
||||
.globals
|
||||
.server_is_ours(room_alias.server_name())
|
||||
{
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref info) = appservice_info {
|
||||
if !info.aliases.is_match(room_alias.as_str()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Exclusive,
|
||||
"Room alias is not in namespace.",
|
||||
));
|
||||
}
|
||||
} else if self
|
||||
.services
|
||||
|
@ -281,7 +310,10 @@ impl Service {
|
|||
.is_exclusive_alias(room_alias)
|
||||
.await
|
||||
{
|
||||
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Exclusive,
|
||||
"Room alias reserved by appservice.",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -6,7 +6,9 @@ use ruma::{api::federation, OwnedRoomId, OwnedServerName, RoomAliasId, ServerNam
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub(super) async fn remote_resolve(
|
||||
&self, room_alias: &RoomAliasId, servers: Vec<OwnedServerName>,
|
||||
&self,
|
||||
room_alias: &RoomAliasId,
|
||||
servers: Vec<OwnedServerName>,
|
||||
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
|
||||
debug!(?room_alias, servers = ?servers, "resolve");
|
||||
let servers = once(room_alias.server_name())
|
||||
|
@ -17,12 +19,12 @@ pub(super) async fn remote_resolve(
|
|||
let mut resolved_room_id: Option<OwnedRoomId> = None;
|
||||
for server in servers {
|
||||
match self.remote_request(room_alias, &server).await {
|
||||
Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"),
|
||||
Ok(Response {
|
||||
room_id,
|
||||
servers,
|
||||
}) => {
|
||||
debug!("Server {server} answered with {room_id:?} for {room_alias:?} servers: {servers:?}");
|
||||
| Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"),
|
||||
| Ok(Response { room_id, servers }) => {
|
||||
debug!(
|
||||
"Server {server} answered with {room_id:?} for {room_alias:?} servers: \
|
||||
{servers:?}"
|
||||
);
|
||||
|
||||
resolved_room_id.get_or_insert(room_id);
|
||||
add_server(&mut resolved_servers, server);
|
||||
|
@ -37,16 +39,20 @@ pub(super) async fn remote_resolve(
|
|||
|
||||
resolved_room_id
|
||||
.map(|room_id| (room_id, resolved_servers))
|
||||
.ok_or_else(|| err!(Request(NotFound("No servers could assist in resolving the room alias"))))
|
||||
.ok_or_else(|| {
|
||||
err!(Request(NotFound("No servers could assist in resolving the room alias")))
|
||||
})
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn remote_request(&self, room_alias: &RoomAliasId, server: &ServerName) -> Result<Response> {
|
||||
async fn remote_request(
|
||||
&self,
|
||||
room_alias: &RoomAliasId,
|
||||
server: &ServerName,
|
||||
) -> Result<Response> {
|
||||
use federation::query::get_room_information::v1::Request;
|
||||
|
||||
let request = Request {
|
||||
room_alias: room_alias.to_owned(),
|
||||
};
|
||||
let request = Request { room_alias: room_alias.to_owned() };
|
||||
|
||||
self.services
|
||||
.sending
|
||||
|
|
|
@ -19,14 +19,18 @@ impl Data {
|
|||
let db = &args.db;
|
||||
let config = &args.server.config;
|
||||
let cache_size = f64::from(config.auth_chain_cache_capacity);
|
||||
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size");
|
||||
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier)
|
||||
.expect("valid cache size");
|
||||
Self {
|
||||
shorteventid_authchain: db["shorteventid_authchain"].clone(),
|
||||
auth_chain_cache: Mutex::new(LruCache::new(cache_size)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
|
||||
pub(super) async fn get_cached_eventid_authchain(
|
||||
&self,
|
||||
key: &[u64],
|
||||
) -> Result<Arc<[ShortEventId]>> {
|
||||
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
|
||||
|
||||
// Check RAM cache
|
||||
|
|
|
@ -43,7 +43,9 @@ impl crate::Service for Service {
|
|||
|
||||
impl Service {
|
||||
pub async fn event_ids_iter<'a, I>(
|
||||
&'a self, room_id: &RoomId, starting_events: I,
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
starting_events: I,
|
||||
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_>
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
|
@ -57,7 +59,11 @@ impl Service {
|
|||
Ok(stream)
|
||||
}
|
||||
|
||||
pub async fn get_event_ids<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<Arc<EventId>>>
|
||||
pub async fn get_event_ids<'a, I>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
starting_events: I,
|
||||
) -> Result<Vec<Arc<EventId>>>
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
{
|
||||
|
@ -74,7 +80,11 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "auth_chain")]
|
||||
pub async fn get_auth_chain<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<ShortEventId>>
|
||||
pub async fn get_auth_chain<'a, I>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
starting_events: I,
|
||||
) -> Result<Vec<ShortEventId>>
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
{
|
||||
|
@ -110,7 +120,8 @@ impl Service {
|
|||
continue;
|
||||
}
|
||||
|
||||
let chunk_key: Vec<ShortEventId> = chunk.iter().map(|(short, _)| short).copied().collect();
|
||||
let chunk_key: Vec<ShortEventId> =
|
||||
chunk.iter().map(|(short, _)| short).copied().collect();
|
||||
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
|
||||
trace!("Found cache entry for whole chunk");
|
||||
full_auth_chain.extend(cached.iter().copied());
|
||||
|
@ -169,7 +180,11 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self, room_id))]
|
||||
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<ShortEventId>> {
|
||||
async fn get_auth_chain_inner(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
) -> Result<HashSet<ShortEventId>> {
|
||||
let mut todo = vec![Arc::from(event_id)];
|
||||
let mut found = HashSet::new();
|
||||
|
||||
|
@ -177,8 +192,10 @@ impl Service {
|
|||
trace!(?event_id, "processing auth event");
|
||||
|
||||
match self.services.timeline.get_pdu(&event_id).await {
|
||||
Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"),
|
||||
Ok(pdu) => {
|
||||
| Err(e) => {
|
||||
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events");
|
||||
},
|
||||
| Ok(pdu) => {
|
||||
if pdu.room_id != room_id {
|
||||
return Err!(Request(Forbidden(error!(
|
||||
?event_id,
|
||||
|
@ -196,7 +213,11 @@ impl Service {
|
|||
.await;
|
||||
|
||||
if found.insert(sauthevent) {
|
||||
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
|
||||
trace!(
|
||||
?event_id,
|
||||
?auth_event,
|
||||
"adding auth event to processing queue"
|
||||
);
|
||||
todo.push(auth_event.clone());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,10 +32,14 @@ pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_i
|
|||
pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); }
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send { self.db.publicroomids.keys().ignore_err() }
|
||||
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
|
||||
self.db.publicroomids.keys().ignore_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.visibility(room_id).await == Visibility::Public }
|
||||
pub async fn is_public_room(&self, room_id: &RoomId) -> bool {
|
||||
self.visibility(room_id).await == Visibility::Public
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn visibility(&self, room_id: &RoomId) -> Visibility {
|
||||
|
|
|
@ -5,10 +5,14 @@ use std::{
|
|||
};
|
||||
|
||||
use conduwuit::{
|
||||
debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent,
|
||||
debug, debug_error, implement, info, pdu, trace,
|
||||
utils::math::continue_exponential_backoff_secs, warn, PduEvent,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName};
|
||||
use ruma::{
|
||||
api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId,
|
||||
ServerName,
|
||||
};
|
||||
|
||||
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
|
||||
/// it is appended to the outliers Tree.
|
||||
|
@ -21,7 +25,11 @@ use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomI
|
|||
/// d. TODO: Ask other servers over federation?
|
||||
#[implement(super::Service)]
|
||||
pub(super) async fn fetch_and_handle_outliers<'a>(
|
||||
&self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId,
|
||||
&self,
|
||||
origin: &'a ServerName,
|
||||
events: &'a [Arc<EventId>],
|
||||
create_event: &'a PduEvent,
|
||||
room_id: &'a RoomId,
|
||||
room_version_id: &'a RoomVersionId,
|
||||
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
|
||||
let back_off = |id| match self
|
||||
|
@ -32,10 +40,12 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
|
|||
.expect("locked")
|
||||
.entry(id)
|
||||
{
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
| hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)),
|
||||
| hash_map::Entry::Occupied(mut e) => {
|
||||
*e.get_mut() = (Instant::now(), e.get().1.saturating_add(1));
|
||||
},
|
||||
};
|
||||
|
||||
let mut events_with_auth_events = Vec::with_capacity(events.len());
|
||||
|
@ -67,7 +77,12 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
|
|||
// Exponential backoff
|
||||
const MIN_DURATION: u64 = 5 * 60;
|
||||
const MAX_DURATION: u64 = 60 * 60 * 24;
|
||||
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
|
||||
if continue_exponential_backoff_secs(
|
||||
MIN_DURATION,
|
||||
MAX_DURATION,
|
||||
time.elapsed(),
|
||||
*tries,
|
||||
) {
|
||||
info!("Backing off from {next_id}");
|
||||
continue;
|
||||
}
|
||||
|
@ -86,18 +101,16 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
|
|||
match self
|
||||
.services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
origin,
|
||||
get_event::v1::Request {
|
||||
event_id: (*next_id).to_owned(),
|
||||
include_unredacted_content: None,
|
||||
},
|
||||
)
|
||||
.send_federation_request(origin, get_event::v1::Request {
|
||||
event_id: (*next_id).to_owned(),
|
||||
include_unredacted_content: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
| Ok(res) => {
|
||||
debug!("Got {next_id} over federation");
|
||||
let Ok((calculated_event_id, value)) = pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
|
||||
let Ok((calculated_event_id, value)) =
|
||||
pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
|
||||
else {
|
||||
back_off((*next_id).to_owned());
|
||||
continue;
|
||||
|
@ -105,15 +118,18 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
|
|||
|
||||
if calculated_event_id != *next_id {
|
||||
warn!(
|
||||
"Server didn't return event id we requested: requested: {next_id}, we got \
|
||||
{calculated_event_id}. Event: {:?}",
|
||||
"Server didn't return event id we requested: requested: {next_id}, \
|
||||
we got {calculated_event_id}. Event: {:?}",
|
||||
&res.pdu
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) {
|
||||
if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array())
|
||||
{
|
||||
for auth_event in auth_events {
|
||||
if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) {
|
||||
if let Ok(auth_event) =
|
||||
serde_json::from_value(auth_event.clone().into())
|
||||
{
|
||||
let a: Arc<EventId> = auth_event;
|
||||
todo_auth_events.push(a);
|
||||
} else {
|
||||
|
@ -127,7 +143,7 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
|
|||
events_in_reverse_order.push((next_id.clone(), value));
|
||||
events_all.insert(next_id);
|
||||
},
|
||||
Err(e) => {
|
||||
| Err(e) => {
|
||||
debug_error!("Failed to fetch event {next_id}: {e}");
|
||||
back_off((*next_id).to_owned());
|
||||
},
|
||||
|
@ -158,20 +174,32 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
|
|||
// Exponential backoff
|
||||
const MIN_DURATION: u64 = 5 * 60;
|
||||
const MAX_DURATION: u64 = 60 * 60 * 24;
|
||||
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
|
||||
if continue_exponential_backoff_secs(
|
||||
MIN_DURATION,
|
||||
MAX_DURATION,
|
||||
time.elapsed(),
|
||||
*tries,
|
||||
) {
|
||||
debug!("Backing off from {next_id}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)).await
|
||||
match Box::pin(self.handle_outlier_pdu(
|
||||
origin,
|
||||
create_event,
|
||||
&next_id,
|
||||
room_id,
|
||||
value.clone(),
|
||||
true,
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok((pdu, json)) => {
|
||||
| Ok((pdu, json)) =>
|
||||
if next_id == *id {
|
||||
pdus.push((pdu, Some(json)));
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
},
|
||||
| Err(e) => {
|
||||
warn!("Authentication of event {next_id} failed: {e:?}");
|
||||
back_off(next_id.into());
|
||||
},
|
||||
|
|
|
@ -8,7 +8,8 @@ use futures::{future, FutureExt};
|
|||
use ruma::{
|
||||
int,
|
||||
state_res::{self},
|
||||
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName,
|
||||
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId,
|
||||
ServerName,
|
||||
};
|
||||
|
||||
use super::check_room_id;
|
||||
|
@ -17,7 +18,11 @@ use super::check_room_id;
|
|||
#[allow(clippy::type_complexity)]
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(super) async fn fetch_prev(
|
||||
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
create_event: &PduEvent,
|
||||
room_id: &RoomId,
|
||||
room_version_id: &RoomVersionId,
|
||||
initial_set: Vec<Arc<EventId>>,
|
||||
) -> Result<(
|
||||
Vec<Arc<EventId>>,
|
||||
|
@ -35,7 +40,13 @@ pub(super) async fn fetch_prev(
|
|||
self.services.server.check_running()?;
|
||||
|
||||
if let Some((pdu, mut json_opt)) = self
|
||||
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id)
|
||||
.fetch_and_handle_outliers(
|
||||
origin,
|
||||
&[prev_event_id.clone()],
|
||||
create_event,
|
||||
room_id,
|
||||
room_version_id,
|
||||
)
|
||||
.boxed()
|
||||
.await
|
||||
.pop()
|
||||
|
@ -67,7 +78,8 @@ pub(super) async fn fetch_prev(
|
|||
}
|
||||
}
|
||||
|
||||
graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect());
|
||||
graph
|
||||
.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect());
|
||||
} else {
|
||||
// Time based check failed
|
||||
graph.insert(prev_event_id.clone(), HashSet::new());
|
||||
|
|
|
@ -6,7 +6,8 @@ use std::{
|
|||
use conduwuit::{debug, implement, warn, Err, Error, PduEvent, Result};
|
||||
use futures::FutureExt;
|
||||
use ruma::{
|
||||
api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId, RoomVersionId, ServerName,
|
||||
api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId,
|
||||
RoomVersionId, ServerName,
|
||||
};
|
||||
|
||||
/// Call /state_ids to find out what the state at this pdu is. We trust the
|
||||
|
@ -15,20 +16,21 @@ use ruma::{
|
|||
#[implement(super::Service)]
|
||||
#[tracing::instrument(skip(self, create_event, room_version_id))]
|
||||
pub(super) async fn fetch_state(
|
||||
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
create_event: &PduEvent,
|
||||
room_id: &RoomId,
|
||||
room_version_id: &RoomVersionId,
|
||||
event_id: &EventId,
|
||||
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
|
||||
debug!("Fetching state ids");
|
||||
let res = self
|
||||
.services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
origin,
|
||||
get_room_state_ids::v1::Request {
|
||||
room_id: room_id.to_owned(),
|
||||
event_id: (*event_id).to_owned(),
|
||||
},
|
||||
)
|
||||
.send_federation_request(origin, get_room_state_ids::v1::Request {
|
||||
room_id: room_id.to_owned(),
|
||||
event_id: (*event_id).to_owned(),
|
||||
})
|
||||
.await
|
||||
.inspect_err(|e| warn!("Fetching state for event failed: {e}"))?;
|
||||
|
||||
|
@ -58,14 +60,13 @@ pub(super) async fn fetch_state(
|
|||
.await;
|
||||
|
||||
match state.entry(shortstatekey) {
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
| hash_map::Entry::Vacant(v) => {
|
||||
v.insert(Arc::from(&*pdu.event_id));
|
||||
},
|
||||
hash_map::Entry::Occupied(_) => {
|
||||
| hash_map::Entry::Occupied(_) =>
|
||||
return Err(Error::bad_database(
|
||||
"State event's type and state_key combination exists multiple times.",
|
||||
))
|
||||
},
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,8 @@ use std::{
|
|||
use conduwuit::{debug, err, implement, warn, Error, Result};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId,
|
||||
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId,
|
||||
ServerName, UserId,
|
||||
};
|
||||
|
||||
use super::{check_room_id, get_room_version_id};
|
||||
|
@ -43,8 +44,12 @@ use crate::rooms::timeline::RawPduId;
|
|||
#[implement(super::Service)]
|
||||
#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")]
|
||||
pub async fn handle_incoming_pdu<'a>(
|
||||
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
|
||||
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
|
||||
&self,
|
||||
origin: &'a ServerName,
|
||||
room_id: &'a RoomId,
|
||||
event_id: &'a EventId,
|
||||
value: BTreeMap<String, CanonicalJsonValue>,
|
||||
is_timeline_event: bool,
|
||||
) -> Result<Option<RawPduId>> {
|
||||
// 1. Skip the PDU if we already have it as a timeline event
|
||||
if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await {
|
||||
|
@ -144,10 +149,10 @@ pub async fn handle_incoming_pdu<'a>(
|
|||
.expect("locked")
|
||||
.entry(prev_id.into())
|
||||
{
|
||||
Entry::Vacant(e) => {
|
||||
| Entry::Vacant(e) => {
|
||||
e.insert((now, 1));
|
||||
},
|
||||
Entry::Occupied(mut e) => {
|
||||
| Entry::Occupied(mut e) => {
|
||||
*e.get_mut() = (now, e.get().1.saturating_add(1));
|
||||
},
|
||||
};
|
||||
|
|
|
@ -17,8 +17,13 @@ use super::{check_room_id, get_room_version_id, to_room_version};
|
|||
#[implement(super::Service)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) async fn handle_outlier_pdu<'a>(
|
||||
&self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId,
|
||||
mut value: CanonicalJsonObject, auth_events_known: bool,
|
||||
&self,
|
||||
origin: &'a ServerName,
|
||||
create_event: &'a PduEvent,
|
||||
event_id: &'a EventId,
|
||||
room_id: &'a RoomId,
|
||||
mut value: CanonicalJsonObject,
|
||||
auth_events_known: bool,
|
||||
) -> Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)> {
|
||||
// 1. Remove unsigned field
|
||||
value.remove("unsigned");
|
||||
|
@ -34,8 +39,8 @@ pub(super) async fn handle_outlier_pdu<'a>(
|
|||
.verify_event(&value, Some(&room_version_id))
|
||||
.await
|
||||
{
|
||||
Ok(ruma::signatures::Verified::All) => value,
|
||||
Ok(ruma::signatures::Verified::Signatures) => {
|
||||
| Ok(ruma::signatures::Verified::All) => value,
|
||||
| Ok(ruma::signatures::Verified::Signatures) => {
|
||||
// Redact
|
||||
debug_info!("Calculated hash does not match (redaction): {event_id}");
|
||||
let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else {
|
||||
|
@ -44,24 +49,26 @@ pub(super) async fn handle_outlier_pdu<'a>(
|
|||
|
||||
// Skip the PDU if it is redacted and we already have it as an outlier event
|
||||
if self.services.timeline.pdu_exists(event_id).await {
|
||||
return Err!(Request(InvalidParam("Event was redacted and we already knew about it")));
|
||||
return Err!(Request(InvalidParam(
|
||||
"Event was redacted and we already knew about it"
|
||||
)));
|
||||
}
|
||||
|
||||
obj
|
||||
},
|
||||
Err(e) => {
|
||||
| Err(e) =>
|
||||
return Err!(Request(InvalidParam(debug_error!(
|
||||
"Signature verification failed for {event_id}: {e}"
|
||||
))))
|
||||
},
|
||||
)))),
|
||||
};
|
||||
|
||||
// Now that we have checked the signature and hashes we can add the eventID and
|
||||
// convert to our PduEvent type
|
||||
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
|
||||
let incoming_pdu =
|
||||
serde_json::from_value::<PduEvent>(serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"))
|
||||
.map_err(|_| Error::bad_database("Event is not a valid PDU."))?;
|
||||
let incoming_pdu = serde_json::from_value::<PduEvent>(
|
||||
serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Event is not a valid PDU."))?;
|
||||
|
||||
check_room_id(room_id, &incoming_pdu)?;
|
||||
|
||||
|
@ -108,10 +115,10 @@ pub(super) async fn handle_outlier_pdu<'a>(
|
|||
.clone()
|
||||
.expect("all auth events have state keys"),
|
||||
)) {
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
| hash_map::Entry::Vacant(v) => {
|
||||
v.insert(auth_event);
|
||||
},
|
||||
hash_map::Entry::Occupied(_) => {
|
||||
| hash_map::Entry::Occupied(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Auth event's type and state_key combination exists multiple times.",
|
||||
|
|
|
@ -4,7 +4,9 @@ use std::{
|
|||
time::Instant,
|
||||
};
|
||||
|
||||
use conduwuit::{debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result};
|
||||
use conduwuit::{
|
||||
debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result,
|
||||
};
|
||||
use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, ServerName};
|
||||
|
||||
#[implement(super::Service)]
|
||||
|
@ -15,15 +17,23 @@ use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, S
|
|||
name = "prev"
|
||||
)]
|
||||
pub(super) async fn handle_prev_pdu<'a>(
|
||||
&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
|
||||
eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
|
||||
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
|
||||
&self,
|
||||
origin: &'a ServerName,
|
||||
event_id: &'a EventId,
|
||||
room_id: &'a RoomId,
|
||||
eventid_info: &mut HashMap<
|
||||
Arc<EventId>,
|
||||
(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>),
|
||||
>,
|
||||
create_event: &Arc<PduEvent>,
|
||||
first_pdu_in_room: &Arc<PduEvent>,
|
||||
prev_id: &EventId,
|
||||
) -> Result {
|
||||
// Check for disabled again because it might have changed
|
||||
if self.services.metadata.is_disabled(room_id).await {
|
||||
debug!(
|
||||
"Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and event \
|
||||
ID {event_id}"
|
||||
"Federaton of room {room_id} is currently disabled on this server. Request by \
|
||||
origin {origin} and event ID {event_id}"
|
||||
);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
|
|
|
@ -23,8 +23,8 @@ use conduwuit::{
|
|||
};
|
||||
use futures::TryFutureExt;
|
||||
use ruma::{
|
||||
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId,
|
||||
RoomVersionId,
|
||||
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId,
|
||||
OwnedRoomId, RoomId, RoomVersionId,
|
||||
};
|
||||
|
||||
use crate::{globals, rooms, sending, server_keys, Dep};
|
||||
|
@ -69,8 +69,10 @@ impl crate::Service for Service {
|
|||
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
|
||||
short: args.depend::<rooms::short::Service>("rooms::short"),
|
||||
state: args.depend::<rooms::state::Service>("rooms::state"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_compressor: args
|
||||
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
|
||||
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
|
||||
server: args.server.clone(),
|
||||
},
|
||||
|
@ -95,7 +97,9 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
impl Service {
|
||||
async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await }
|
||||
async fn event_exists(&self, event_id: Arc<EventId>) -> bool {
|
||||
self.services.timeline.pdu_exists(&event_id).await
|
||||
}
|
||||
|
||||
async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> {
|
||||
self.services
|
||||
|
|
|
@ -3,9 +3,13 @@ use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, R
|
|||
use serde_json::value::RawValue as RawJsonValue;
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
|
||||
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get())
|
||||
.map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?;
|
||||
pub async fn parse_incoming_pdu(
|
||||
&self,
|
||||
pdu: &RawJsonValue,
|
||||
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
|
||||
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()).map_err(|e| {
|
||||
err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}")))
|
||||
})?;
|
||||
|
||||
let room_id: OwnedRoomId = value
|
||||
.get("room_id")
|
||||
|
@ -20,8 +24,9 @@ pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEvent
|
|||
.await
|
||||
.map_err(|_| err!("Server is not in room {room_id}"))?;
|
||||
|
||||
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id)
|
||||
.map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?;
|
||||
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id).map_err(|e| {
|
||||
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
|
||||
})?;
|
||||
|
||||
Ok((event_id, value, room_id))
|
||||
}
|
||||
|
|
|
@ -20,7 +20,10 @@ use crate::rooms::state_compressor::CompressedStateEvent;
|
|||
#[implement(super::Service)]
|
||||
#[tracing::instrument(skip_all, name = "resolve")]
|
||||
pub async fn resolve_state(
|
||||
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
room_version_id: &RoomVersionId,
|
||||
incoming_state: HashMap<u64, Arc<EventId>>,
|
||||
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
|
||||
debug!("Loading current room state ids");
|
||||
let current_sstatehash = self
|
||||
|
@ -76,10 +79,16 @@ pub async fn resolve_state(
|
|||
|
||||
let event_fetch = |event_id| self.event_fetch(event_id);
|
||||
let event_exists = |event_id| self.event_exists(event_id);
|
||||
let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
|
||||
.boxed()
|
||||
.await
|
||||
.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
|
||||
let state = state_res::resolve(
|
||||
room_version_id,
|
||||
&fork_states,
|
||||
&auth_chain_sets,
|
||||
&event_fetch,
|
||||
&event_exists,
|
||||
)
|
||||
.boxed()
|
||||
.await
|
||||
.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
|
||||
|
||||
drop(lock);
|
||||
|
||||
|
|
|
@ -21,7 +21,8 @@ use ruma::{
|
|||
// request and build the state from a known point and resolve if > 1 prev_event
|
||||
#[tracing::instrument(skip_all, name = "state")]
|
||||
pub(super) async fn state_at_incoming_degree_one(
|
||||
&self, incoming_pdu: &Arc<PduEvent>,
|
||||
&self,
|
||||
incoming_pdu: &Arc<PduEvent>,
|
||||
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
|
||||
let prev_event = &*incoming_pdu.prev_events[0];
|
||||
let Ok(prev_event_sstatehash) = self
|
||||
|
@ -70,7 +71,10 @@ pub(super) async fn state_at_incoming_degree_one(
|
|||
#[implement(super::Service)]
|
||||
#[tracing::instrument(skip_all, name = "state")]
|
||||
pub(super) async fn state_at_incoming_resolved(
|
||||
&self, incoming_pdu: &Arc<PduEvent>, room_id: &RoomId, room_version_id: &RoomVersionId,
|
||||
&self,
|
||||
incoming_pdu: &Arc<PduEvent>,
|
||||
room_id: &RoomId,
|
||||
room_version_id: &RoomVersionId,
|
||||
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
|
||||
debug!("Calculating state at event using state res");
|
||||
let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len());
|
||||
|
@ -157,10 +161,16 @@ pub(super) async fn state_at_incoming_resolved(
|
|||
|
||||
let event_fetch = |event_id| self.event_fetch(event_id);
|
||||
let event_exists = |event_id| self.event_exists(event_id);
|
||||
let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
|
||||
.boxed()
|
||||
.await
|
||||
.map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed."))));
|
||||
let result = state_res::resolve(
|
||||
room_version_id,
|
||||
&fork_states,
|
||||
&auth_chain_sets,
|
||||
&event_fetch,
|
||||
&event_exists,
|
||||
)
|
||||
.boxed()
|
||||
.await
|
||||
.map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed."))));
|
||||
|
||||
drop(lock);
|
||||
|
||||
|
|
|
@ -19,8 +19,12 @@ use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPdu
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
|
||||
origin: &ServerName, room_id: &RoomId,
|
||||
&self,
|
||||
incoming_pdu: Arc<PduEvent>,
|
||||
val: BTreeMap<String, CanonicalJsonValue>,
|
||||
create_event: &PduEvent,
|
||||
origin: &ServerName,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<RawPduId>> {
|
||||
// Skip the PDU if we already have it as a timeline event
|
||||
if let Ok(pduid) = self
|
||||
|
@ -63,7 +67,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
|||
.await?;
|
||||
}
|
||||
|
||||
let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above");
|
||||
let state_at_incoming_event =
|
||||
state_at_incoming_event.expect("we always set this to some above");
|
||||
let room_version = to_room_version(&room_version_id);
|
||||
|
||||
debug!("Performing auth check");
|
||||
|
@ -124,24 +129,34 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
|||
!auth_check
|
||||
|| incoming_pdu.kind == TimelineEventType::RoomRedaction
|
||||
&& match room_version_id {
|
||||
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
|
||||
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
|
||||
if let Some(redact_id) = &incoming_pdu.redacts {
|
||||
!self
|
||||
.services
|
||||
.state_accessor
|
||||
.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
|
||||
.user_can_redact(
|
||||
redact_id,
|
||||
&incoming_pdu.sender,
|
||||
&incoming_pdu.room_id,
|
||||
true,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
false
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
| _ => {
|
||||
let content: RoomRedactionEventContent = incoming_pdu.get_content()?;
|
||||
if let Some(redact_id) = &content.redacts {
|
||||
!self
|
||||
.services
|
||||
.state_accessor
|
||||
.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
|
||||
.user_can_redact(
|
||||
redact_id,
|
||||
&incoming_pdu.sender,
|
||||
&incoming_pdu.room_id,
|
||||
true,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
false
|
||||
|
@ -229,11 +244,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
|||
|
||||
// Set the new room state to the resolved state
|
||||
debug!("Forcing new room state");
|
||||
let HashSetCompressStateEvent {
|
||||
shortstatehash,
|
||||
added,
|
||||
removed,
|
||||
} = self
|
||||
let HashSetCompressStateEvent { shortstatehash, added, removed } = self
|
||||
.services
|
||||
.state_compressor
|
||||
.save_state(room_id, new_room_state)
|
||||
|
|
|
@ -51,7 +51,11 @@ impl crate::Service for Service {
|
|||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
#[inline]
|
||||
pub async fn lazy_load_was_sent_before(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
ll_user: &UserId,
|
||||
) -> bool {
|
||||
let key = (user_id, device_id, room_id, ll_user);
|
||||
self.db.lazyloadedids.qry(&key).await.is_ok()
|
||||
|
@ -60,7 +64,12 @@ pub async fn lazy_load_was_sent_before(
|
|||
#[implement(Service)]
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn lazy_load_mark_sent(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
lazy_load: HashSet<OwnedUserId>,
|
||||
count: PduCount,
|
||||
) {
|
||||
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count);
|
||||
|
||||
|
@ -72,7 +81,13 @@ pub fn lazy_load_mark_sent(
|
|||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) {
|
||||
pub fn lazy_load_confirm_delivery(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
since: PduCount,
|
||||
) {
|
||||
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since);
|
||||
|
||||
let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else {
|
||||
|
|
|
@ -58,7 +58,9 @@ pub async fn exists(&self, room_id: &RoomId) -> bool {
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() }
|
||||
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
|
||||
self.db.roomid_shortroomid.keys().ignore_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[inline]
|
||||
|
@ -81,12 +83,18 @@ pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() }
|
||||
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
|
||||
self.db.bannedroomids.keys().ignore_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[inline]
|
||||
pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.get(room_id).await.is_ok() }
|
||||
pub async fn is_disabled(&self, room_id: &RoomId) -> bool {
|
||||
self.db.disabledroomids.get(room_id).await.is_ok()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[inline]
|
||||
pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.get(room_id).await.is_ok() }
|
||||
pub async fn is_banned(&self, room_id: &RoomId) -> bool {
|
||||
self.db.bannedroomids.get(room_id).await.is_ok()
|
||||
}
|
||||
|
|
|
@ -56,26 +56,27 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) fn get_relations<'a>(
|
||||
&'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction,
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
shortroomid: ShortRoomId,
|
||||
target: ShortEventId,
|
||||
from: PduCount,
|
||||
dir: Direction,
|
||||
) -> impl Stream<Item = PdusIterItem> + Send + '_ {
|
||||
let mut current = ArrayVec::<u8, 16>::new();
|
||||
current.extend(target.to_be_bytes());
|
||||
current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes());
|
||||
let current = current.as_slice();
|
||||
match dir {
|
||||
Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
|
||||
Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
|
||||
| Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
|
||||
| Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
|
||||
}
|
||||
.ignore_err()
|
||||
.ready_take_while(move |key| key.starts_with(&target.to_be_bytes()))
|
||||
.map(|to_from| u64_from_u8(&to_from[8..16]))
|
||||
.map(PduCount::from_unsigned)
|
||||
.wide_filter_map(move |shorteventid| async move {
|
||||
let pdu_id: RawPduId = PduId {
|
||||
shortroomid,
|
||||
shorteventid,
|
||||
}
|
||||
.into();
|
||||
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
|
||||
|
||||
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
|
||||
|
||||
|
@ -99,7 +100,9 @@ impl Data {
|
|||
self.referencedevents.qry(&key).await.is_ok()
|
||||
}
|
||||
|
||||
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { self.softfailedeventids.insert(event_id, []); }
|
||||
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) {
|
||||
self.softfailedeventids.insert(event_id, []);
|
||||
}
|
||||
|
||||
pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
|
||||
self.softfailedeventids.get(event_id).await.is_ok()
|
||||
|
|
|
@ -36,8 +36,8 @@ impl Service {
|
|||
#[tracing::instrument(skip(self, from, to), level = "debug")]
|
||||
pub fn add_relation(&self, from: PduCount, to: PduCount) {
|
||||
match (from, to) {
|
||||
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
|
||||
_ => {
|
||||
| (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
|
||||
| _ => {
|
||||
// TODO: Relations with backfilled pdus
|
||||
},
|
||||
}
|
||||
|
@ -45,15 +45,21 @@ impl Service {
|
|||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn get_relations(
|
||||
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
target: &EventId,
|
||||
from: PduCount,
|
||||
limit: usize,
|
||||
max_depth: u8,
|
||||
dir: Direction,
|
||||
) -> Vec<PdusIterItem> {
|
||||
let room_id = self.services.short.get_or_create_shortroomid(room_id).await;
|
||||
|
||||
let target = match self.services.timeline.get_pdu_count(target).await {
|
||||
Ok(PduCount::Normal(c)) => c,
|
||||
| Ok(PduCount::Normal(c)) => c,
|
||||
// TODO: Support backfilled relations
|
||||
_ => 0, // This will result in an empty iterator
|
||||
| _ => 0, // This will result in an empty iterator
|
||||
};
|
||||
|
||||
let mut pdus: Vec<_> = self
|
||||
|
@ -66,9 +72,9 @@ impl Service {
|
|||
|
||||
'limit: while let Some(stack_pdu) = stack.pop() {
|
||||
let target = match stack_pdu.0 .0 {
|
||||
PduCount::Normal(c) => c,
|
||||
| PduCount::Normal(c) => c,
|
||||
// TODO: Support backfilled relations
|
||||
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
|
||||
| PduCount::Backfilled(_) => 0, // This will result in an empty iterator
|
||||
};
|
||||
|
||||
let relations: Vec<_> = self
|
||||
|
@ -106,7 +112,9 @@ impl Service {
|
|||
|
||||
#[inline]
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) }
|
||||
pub fn mark_event_soft_failed(&self, event_id: &EventId) {
|
||||
self.db.mark_event_soft_failed(event_id);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
|
|
|
@ -40,7 +40,12 @@ impl Data {
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
|
||||
pub(super) async fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: &ReceiptEvent,
|
||||
) {
|
||||
// Remove old entry
|
||||
let last_possible_key = (room_id, u64::MAX);
|
||||
self.readreceiptid_readreceipt
|
||||
|
@ -57,7 +62,9 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) fn readreceipts_since<'a>(
|
||||
&'a self, room_id: &'a RoomId, since: u64,
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
since: u64,
|
||||
) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
|
||||
type Key<'a> = (&'a RoomId, u64, &'a UserId);
|
||||
type KeyVal<'a> = (Key<'a>, CanonicalJsonObject);
|
||||
|
@ -87,12 +94,20 @@ impl Data {
|
|||
self.roomuserid_lastprivatereadupdate.put(key, next_count);
|
||||
}
|
||||
|
||||
pub(super) async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
|
||||
pub(super) async fn private_read_get_count(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<u64> {
|
||||
let key = (room_id, user_id);
|
||||
self.roomuserid_privateread.qry(&key).await.deserialized()
|
||||
}
|
||||
|
||||
pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
|
||||
pub(super) async fn last_privateread_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> u64 {
|
||||
let key = (room_id, user_id);
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.qry(&key)
|
||||
|
|
|
@ -44,7 +44,12 @@ impl crate::Service for Service {
|
|||
|
||||
impl Service {
|
||||
/// Replaces the previous read receipt.
|
||||
pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
|
||||
pub async fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: &ReceiptEvent,
|
||||
) {
|
||||
self.db.readreceipt_update(user_id, room_id, event).await;
|
||||
self.services
|
||||
.sending
|
||||
|
@ -54,23 +59,21 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Gets the latest private read receipt from the user in the room
|
||||
pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
|
||||
let pdu_count = self
|
||||
.private_read_get_count(room_id, user_id)
|
||||
.map_err(|e| err!(Database(warn!("No private read receipt was set in {room_id}: {e}"))));
|
||||
let shortroomid = self
|
||||
.services
|
||||
.short
|
||||
.get_shortroomid(room_id)
|
||||
.map_err(|e| err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}"))));
|
||||
pub async fn private_read_get(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
|
||||
let pdu_count = self.private_read_get_count(room_id, user_id).map_err(|e| {
|
||||
err!(Database(warn!("No private read receipt was set in {room_id}: {e}")))
|
||||
});
|
||||
let shortroomid = self.services.short.get_shortroomid(room_id).map_err(|e| {
|
||||
err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}")))
|
||||
});
|
||||
let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?;
|
||||
|
||||
let shorteventid = PduCount::Normal(pdu_count);
|
||||
let pdu_id: RawPduId = PduId {
|
||||
shortroomid,
|
||||
shorteventid,
|
||||
}
|
||||
.into();
|
||||
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
|
||||
|
||||
let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?;
|
||||
|
||||
|
@ -80,21 +83,17 @@ impl Service {
|
|||
event_id,
|
||||
BTreeMap::from_iter([(
|
||||
ruma::events::receipt::ReceiptType::ReadPrivate,
|
||||
BTreeMap::from_iter([(
|
||||
user_id,
|
||||
ruma::events::receipt::Receipt {
|
||||
ts: None, // TODO: start storing the timestamp so we can return one
|
||||
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
|
||||
},
|
||||
)]),
|
||||
BTreeMap::from_iter([(user_id, ruma::events::receipt::Receipt {
|
||||
ts: None, // TODO: start storing the timestamp so we can return one
|
||||
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
|
||||
})]),
|
||||
)]),
|
||||
)]);
|
||||
let receipt_event_content = ReceiptEventContent(content);
|
||||
let receipt_sync_event = SyncEphemeralRoomEvent {
|
||||
content: receipt_event_content,
|
||||
};
|
||||
let receipt_sync_event = SyncEphemeralRoomEvent { content: receipt_event_content };
|
||||
|
||||
let event = serde_json::value::to_raw_value(&receipt_sync_event).expect("receipt created manually");
|
||||
let event = serde_json::value::to_raw_value(&receipt_sync_event)
|
||||
.expect("receipt created manually");
|
||||
|
||||
Ok(Raw::from_json(event))
|
||||
}
|
||||
|
@ -104,7 +103,9 @@ impl Service {
|
|||
#[inline]
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn readreceipts_since<'a>(
|
||||
&'a self, room_id: &'a RoomId, since: u64,
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
since: u64,
|
||||
) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
|
||||
self.db.readreceipts_since(room_id, since)
|
||||
}
|
||||
|
@ -119,7 +120,11 @@ impl Service {
|
|||
/// Returns the private read marker PDU count.
|
||||
#[inline]
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
|
||||
pub async fn private_read_get_count(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<u64> {
|
||||
self.db.private_read_get_count(room_id, user_id).await
|
||||
}
|
||||
|
||||
|
@ -137,7 +142,9 @@ where
|
|||
{
|
||||
let mut json = BTreeMap::new();
|
||||
for value in receipts {
|
||||
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get());
|
||||
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(
|
||||
value.json().get(),
|
||||
);
|
||||
if let Ok(value) = receipt {
|
||||
for (event, receipt) in value.content {
|
||||
json.insert(event, receipt);
|
||||
|
@ -149,9 +156,7 @@ where
|
|||
let content = ReceiptEventContent::from_iter(json);
|
||||
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent {
|
||||
content,
|
||||
})
|
||||
.expect("received valid json"),
|
||||
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent { content })
|
||||
.expect("received valid json"),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -49,18 +49,18 @@ pub struct RoomQuery<'a> {
|
|||
|
||||
type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
|
||||
|
||||
const TOKEN_ID_MAX_LEN: usize = size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
|
||||
const TOKEN_ID_MAX_LEN: usize =
|
||||
size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
|
||||
const WORD_MAX_LEN: usize = 50;
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
db: Data {
|
||||
tokenids: args.db["tokenids"].clone(),
|
||||
},
|
||||
db: Data { tokenids: args.db["tokenids"].clone() },
|
||||
services: Services {
|
||||
short: args.depend::<rooms::short::Service>("rooms::short"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
|
||||
},
|
||||
}))
|
||||
|
@ -103,7 +103,8 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b
|
|||
|
||||
#[implement(Service)]
|
||||
pub async fn search_pdus<'a>(
|
||||
&'a self, query: &'a RoomQuery<'a>,
|
||||
&'a self,
|
||||
query: &'a RoomQuery<'a>,
|
||||
) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> {
|
||||
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
|
||||
|
||||
|
@ -136,7 +137,10 @@ pub async fn search_pdus<'a>(
|
|||
// result is modeled as a stream such that callers don't have to be refactored
|
||||
// though an additional async/wrap still exists for now
|
||||
#[implement(Service)]
|
||||
pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
|
||||
pub async fn search_pdu_ids(
|
||||
&self,
|
||||
query: &RoomQuery<'_>,
|
||||
) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
|
||||
let shortroomid = self.services.short.get_shortroomid(query.room_id).await?;
|
||||
|
||||
let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await;
|
||||
|
@ -147,7 +151,11 @@ pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec<Vec<RawPduId>> {
|
||||
async fn search_pdu_ids_query_room(
|
||||
&self,
|
||||
query: &RoomQuery<'_>,
|
||||
shortroomid: ShortRoomId,
|
||||
) -> Vec<Vec<RawPduId>> {
|
||||
tokenize(&query.criteria.search_term)
|
||||
.stream()
|
||||
.wide_then(|word| async move {
|
||||
|
@ -162,7 +170,9 @@ async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: Sh
|
|||
/// Iterate over PduId's containing a word
|
||||
#[implement(Service)]
|
||||
fn search_pdu_ids_query_words<'a>(
|
||||
&'a self, shortroomid: ShortRoomId, word: &'a str,
|
||||
&'a self,
|
||||
shortroomid: ShortRoomId,
|
||||
word: &'a str,
|
||||
) -> impl Stream<Item = RawPduId> + Send + '_ {
|
||||
self.search_pdu_ids_query_word(shortroomid, word)
|
||||
.map(move |key| -> RawPduId {
|
||||
|
@ -173,7 +183,11 @@ fn search_pdu_ids_query_words<'a>(
|
|||
|
||||
/// Iterate over raw database results for a word
|
||||
#[implement(Service)]
|
||||
fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream<Item = Val<'_>> + Send + '_ {
|
||||
fn search_pdu_ids_query_word(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
word: &str,
|
||||
) -> impl Stream<Item = Val<'_>> + Send + '_ {
|
||||
// rustc says const'ing this not yet stable
|
||||
let end_id: RawPduId = PduId {
|
||||
shortroomid,
|
||||
|
|
|
@ -60,7 +60,10 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEvent
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn multi_get_or_create_shorteventid<'a, I>(&'a self, event_ids: I) -> impl Stream<Item = ShortEventId> + Send + '_
|
||||
pub fn multi_get_or_create_shorteventid<'a, I>(
|
||||
&'a self,
|
||||
event_ids: I,
|
||||
) -> impl Stream<Item = ShortEventId> + Send + '_
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
<I as Iterator>::Item: AsRef<[u8]> + Send + Sync + 'a,
|
||||
|
@ -72,8 +75,8 @@ where
|
|||
event_ids.next().map(|event_id| (event_id, result))
|
||||
})
|
||||
.map(|(event_id, result)| match result {
|
||||
Ok(ref short) => utils::u64_from_u8(short),
|
||||
Err(_) => self.create_shorteventid(event_id),
|
||||
| Ok(ref short) => utils::u64_from_u8(short),
|
||||
| Err(_) => self.create_shorteventid(event_id),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -104,7 +107,11 @@ pub async fn get_shorteventid(&self, event_id: &EventId) -> Result<ShortEventId>
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> ShortStateKey {
|
||||
pub async fn get_or_create_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> ShortStateKey {
|
||||
const BUFSIZE: usize = size_of::<ShortStateKey>();
|
||||
|
||||
if let Ok(shortstatekey) = self.get_shortstatekey(event_type, state_key).await {
|
||||
|
@ -127,7 +134,11 @@ pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, sta
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<ShortStateKey> {
|
||||
pub async fn get_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<ShortStateKey> {
|
||||
let key = (event_type, state_key);
|
||||
self.db
|
||||
.statekey_shortstatekey
|
||||
|
@ -153,7 +164,10 @@ where
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn multi_get_eventid_from_short<'a, Id, I>(&'a self, shorteventid: I) -> impl Stream<Item = Result<Id>> + Send + 'a
|
||||
pub fn multi_get_eventid_from_short<'a, Id, I>(
|
||||
&'a self,
|
||||
shorteventid: I,
|
||||
) -> impl Stream<Item = Result<Id>> + Send + 'a
|
||||
where
|
||||
I: Iterator<Item = &'a ShortEventId> + Send + 'a,
|
||||
Id: for<'de> Deserialize<'de> + Sized + ToOwned + 'a,
|
||||
|
@ -168,7 +182,10 @@ where
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_statekey_from_short(&self, shortstatekey: ShortStateKey) -> Result<(StateEventType, String)> {
|
||||
pub async fn get_statekey_from_short(
|
||||
&self,
|
||||
shortstatekey: ShortStateKey,
|
||||
) -> Result<(StateEventType, String)> {
|
||||
const BUFSIZE: usize = size_of::<ShortStateKey>();
|
||||
|
||||
self.db
|
||||
|
|
|
@ -125,7 +125,8 @@ enum Identifier<'a> {
|
|||
|
||||
pub struct Service {
|
||||
services: Services,
|
||||
pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
|
||||
pub roomid_spacehierarchy_cache:
|
||||
Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
|
||||
}
|
||||
|
||||
struct Services {
|
||||
|
@ -145,11 +146,13 @@ impl crate::Service for Service {
|
|||
let cache_size = cache_size * config.cache_capacity_modifier;
|
||||
Ok(Arc::new(Self {
|
||||
services: Services {
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
||||
state: args.depend::<rooms::state::Service>("rooms::state"),
|
||||
short: args.depend::<rooms::short::Service>("rooms::short"),
|
||||
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
|
||||
event_handler: args
|
||||
.depend::<rooms::event_handler::Service>("rooms::event_handler"),
|
||||
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
|
||||
sending: args.depend::<sending::Service>("sending"),
|
||||
},
|
||||
|
@ -166,28 +169,37 @@ impl Service {
|
|||
/// Errors if the room does not exist, so a check if the room exists should
|
||||
/// be done
|
||||
pub async fn get_federation_hierarchy(
|
||||
&self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
server_name: &ServerName,
|
||||
suggested_only: bool,
|
||||
) -> Result<federation::space::get_hierarchy::v1::Response> {
|
||||
match self
|
||||
.get_summary_and_children_local(&room_id.to_owned(), Identifier::ServerName(server_name))
|
||||
.get_summary_and_children_local(
|
||||
&room_id.to_owned(),
|
||||
Identifier::ServerName(server_name),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(SummaryAccessibility::Accessible(room)) => {
|
||||
| Some(SummaryAccessibility::Accessible(room)) => {
|
||||
let mut children = Vec::new();
|
||||
let mut inaccessible_children = Vec::new();
|
||||
|
||||
for (child, _via) in get_parent_children_via(&room, suggested_only) {
|
||||
match self
|
||||
.get_summary_and_children_local(&child, Identifier::ServerName(server_name))
|
||||
.get_summary_and_children_local(
|
||||
&child,
|
||||
Identifier::ServerName(server_name),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(SummaryAccessibility::Accessible(summary)) => {
|
||||
| Some(SummaryAccessibility::Accessible(summary)) => {
|
||||
children.push((*summary).into());
|
||||
},
|
||||
Some(SummaryAccessibility::Inaccessible) => {
|
||||
| Some(SummaryAccessibility::Inaccessible) => {
|
||||
inaccessible_children.push(child);
|
||||
},
|
||||
None => (),
|
||||
| None => (),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,16 +209,18 @@ impl Service {
|
|||
inaccessible_children,
|
||||
})
|
||||
},
|
||||
Some(SummaryAccessibility::Inaccessible) => {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible"))
|
||||
},
|
||||
None => Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")),
|
||||
| Some(SummaryAccessibility::Inaccessible) =>
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible")),
|
||||
| None =>
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the summary of a space using solely local information
|
||||
async fn get_summary_and_children_local(
|
||||
&self, current_room: &OwnedRoomId, identifier: Identifier<'_>,
|
||||
&self,
|
||||
current_room: &OwnedRoomId,
|
||||
identifier: Identifier<'_>,
|
||||
) -> Result<Option<SummaryAccessibility>> {
|
||||
if let Some(cached) = self
|
||||
.roomid_spacehierarchy_cache
|
||||
|
@ -241,9 +255,7 @@ impl Service {
|
|||
if let Ok(summary) = summary {
|
||||
self.roomid_spacehierarchy_cache.lock().await.insert(
|
||||
current_room.clone(),
|
||||
Some(CachedSpaceHierarchySummary {
|
||||
summary: summary.clone(),
|
||||
}),
|
||||
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
|
||||
);
|
||||
|
||||
Ok(Some(SummaryAccessibility::Accessible(Box::new(summary))))
|
||||
|
@ -258,20 +270,21 @@ impl Service {
|
|||
/// Gets the summary of a space using solely federation
|
||||
#[tracing::instrument(skip(self))]
|
||||
async fn get_summary_and_children_federation(
|
||||
&self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName],
|
||||
&self,
|
||||
current_room: &OwnedRoomId,
|
||||
suggested_only: bool,
|
||||
user_id: &UserId,
|
||||
via: &[OwnedServerName],
|
||||
) -> Result<Option<SummaryAccessibility>> {
|
||||
for server in via {
|
||||
debug_info!("Asking {server} for /hierarchy");
|
||||
let Ok(response) = self
|
||||
.services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server,
|
||||
federation::space::get_hierarchy::v1::Request {
|
||||
room_id: current_room.to_owned(),
|
||||
suggested_only,
|
||||
},
|
||||
)
|
||||
.send_federation_request(server, federation::space::get_hierarchy::v1::Request {
|
||||
room_id: current_room.to_owned(),
|
||||
suggested_only,
|
||||
})
|
||||
.await
|
||||
else {
|
||||
continue;
|
||||
|
@ -282,9 +295,7 @@ impl Service {
|
|||
|
||||
self.roomid_spacehierarchy_cache.lock().await.insert(
|
||||
current_room.clone(),
|
||||
Some(CachedSpaceHierarchySummary {
|
||||
summary: summary.clone(),
|
||||
}),
|
||||
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
|
||||
);
|
||||
|
||||
for child in response.children {
|
||||
|
@ -356,7 +367,11 @@ impl Service {
|
|||
/// Gets the summary of a space using either local or remote (federation)
|
||||
/// sources
|
||||
async fn get_summary_and_children_client(
|
||||
&self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName],
|
||||
&self,
|
||||
current_room: &OwnedRoomId,
|
||||
suggested_only: bool,
|
||||
user_id: &UserId,
|
||||
via: &[OwnedServerName],
|
||||
) -> Result<Option<SummaryAccessibility>> {
|
||||
if let Ok(Some(response)) = self
|
||||
.get_summary_and_children_local(current_room, Identifier::UserId(user_id))
|
||||
|
@ -370,7 +385,9 @@ impl Service {
|
|||
}
|
||||
|
||||
async fn get_room_summary(
|
||||
&self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
|
||||
&self,
|
||||
current_room: &OwnedRoomId,
|
||||
children_state: Vec<Raw<HierarchySpaceChildEvent>>,
|
||||
identifier: &Identifier<'_>,
|
||||
) -> Result<SpaceHierarchyParentSummary, Error> {
|
||||
let room_id: &RoomId = current_room;
|
||||
|
@ -388,12 +405,20 @@ impl Service {
|
|||
.allowed_room_ids(join_rule.clone());
|
||||
|
||||
if !self
|
||||
.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids)
|
||||
.is_accessible_child(
|
||||
current_room,
|
||||
&join_rule.clone().into(),
|
||||
identifier,
|
||||
&allowed_room_ids,
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug_info!("User is not allowed to see room {room_id}");
|
||||
// This error will be caught later
|
||||
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
"User is not allowed to see the room",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(SpaceHierarchyParentSummary {
|
||||
|
@ -446,7 +471,12 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn get_client_hierarchy(
|
||||
&self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<ShortRoomId>, max_depth: u64,
|
||||
&self,
|
||||
sender_user: &UserId,
|
||||
room_id: &RoomId,
|
||||
limit: usize,
|
||||
short_room_ids: Vec<ShortRoomId>,
|
||||
max_depth: u64,
|
||||
suggested_only: bool,
|
||||
) -> Result<client::space::get_hierarchy::v1::Response> {
|
||||
let mut parents = VecDeque::new();
|
||||
|
@ -454,27 +484,30 @@ impl Service {
|
|||
// Don't start populating the results if we have to start at a specific room.
|
||||
let mut populate_results = short_room_ids.is_empty();
|
||||
|
||||
let mut stack = vec![vec![(
|
||||
room_id.to_owned(),
|
||||
match room_id.server_name() {
|
||||
Some(server_name) => vec![server_name.into()],
|
||||
None => vec![],
|
||||
},
|
||||
)]];
|
||||
let mut stack = vec![vec![(room_id.to_owned(), match room_id.server_name() {
|
||||
| Some(server_name) => vec![server_name.into()],
|
||||
| None => vec![],
|
||||
})]];
|
||||
|
||||
let mut results = Vec::with_capacity(limit);
|
||||
|
||||
while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } {
|
||||
while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) }
|
||||
{
|
||||
if results.len() >= limit {
|
||||
break;
|
||||
}
|
||||
|
||||
match (
|
||||
self.get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via)
|
||||
.await?,
|
||||
self.get_summary_and_children_client(
|
||||
¤t_room,
|
||||
suggested_only,
|
||||
sender_user,
|
||||
&via,
|
||||
)
|
||||
.await?,
|
||||
current_room == room_id,
|
||||
) {
|
||||
(Some(SummaryAccessibility::Accessible(summary)), _) => {
|
||||
| (Some(SummaryAccessibility::Accessible(summary)), _) => {
|
||||
let mut children: Vec<(OwnedRoomId, Vec<OwnedServerName>)> =
|
||||
get_parent_children_via(&summary, suggested_only)
|
||||
.into_iter()
|
||||
|
@ -493,7 +526,9 @@ impl Service {
|
|||
self.services
|
||||
.short
|
||||
.get_shortroomid(room)
|
||||
.map_ok(|short| Some(&short) != short_room_ids.get(parents.len()))
|
||||
.map_ok(|short| {
|
||||
Some(&short) != short_room_ids.get(parents.len())
|
||||
})
|
||||
.unwrap_or_else(|_| false)
|
||||
})
|
||||
.map(Clone::clone)
|
||||
|
@ -525,14 +560,20 @@ impl Service {
|
|||
// Root room in the space hierarchy, we return an error
|
||||
// if this one fails.
|
||||
},
|
||||
(Some(SummaryAccessibility::Inaccessible), true) => {
|
||||
return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible"));
|
||||
| (Some(SummaryAccessibility::Inaccessible), true) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
"The requested room is inaccessible",
|
||||
));
|
||||
},
|
||||
(None, true) => {
|
||||
return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found"));
|
||||
| (None, true) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
"The requested room was not found",
|
||||
));
|
||||
},
|
||||
// Just ignore other unavailable rooms
|
||||
(None | Some(SummaryAccessibility::Inaccessible), false) => (),
|
||||
| (None | Some(SummaryAccessibility::Inaccessible), false) => (),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -544,15 +585,19 @@ impl Service {
|
|||
let short_room_ids: Vec<_> = parents
|
||||
.iter()
|
||||
.stream()
|
||||
.filter_map(|room_id| async move { self.services.short.get_shortroomid(room_id).await.ok() })
|
||||
.filter_map(|room_id| async move {
|
||||
self.services.short.get_shortroomid(room_id).await.ok()
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
Some(
|
||||
PaginationToken {
|
||||
short_room_ids,
|
||||
limit: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"),
|
||||
max_depth: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"),
|
||||
limit: UInt::new(max_depth)
|
||||
.expect("When sent in request it must have been valid UInt"),
|
||||
max_depth: UInt::new(max_depth)
|
||||
.expect("When sent in request it must have been valid UInt"),
|
||||
suggested_only,
|
||||
}
|
||||
.to_string(),
|
||||
|
@ -566,9 +611,12 @@ impl Service {
|
|||
|
||||
/// Simply returns the stripped m.space.child events of a room
|
||||
async fn get_stripped_space_child_events(
|
||||
&self, room_id: &RoomId,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
|
||||
let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else {
|
||||
let Ok(current_shortstatehash) =
|
||||
self.services.state.get_room_shortstatehash(room_id).await
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
|
@ -581,18 +629,17 @@ impl Service {
|
|||
|
||||
let mut children_pdus = Vec::with_capacity(state.len());
|
||||
for (key, id) in state {
|
||||
let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?;
|
||||
let (event_type, state_key) =
|
||||
self.services.short.get_statekey_from_short(key).await?;
|
||||
|
||||
if event_type != StateEventType::SpaceChild {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pdu = self
|
||||
.services
|
||||
.timeline
|
||||
.get_pdu(&id)
|
||||
.await
|
||||
.map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?;
|
||||
let pdu =
|
||||
self.services.timeline.get_pdu(&id).await.map_err(|e| {
|
||||
err!(Database("Event {id:?} in space state not found: {e:?}"))
|
||||
})?;
|
||||
|
||||
if let Ok(content) = pdu.get_content::<SpaceChildEventContent>() {
|
||||
if content.via.is_empty() {
|
||||
|
@ -610,11 +657,14 @@ impl Service {
|
|||
|
||||
/// With the given identifier, checks if a room is accessable
|
||||
async fn is_accessible_child(
|
||||
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
|
||||
&self,
|
||||
current_room: &OwnedRoomId,
|
||||
join_rule: &SpaceRoomJoinRule,
|
||||
identifier: &Identifier<'_>,
|
||||
allowed_room_ids: &Vec<OwnedRoomId>,
|
||||
) -> bool {
|
||||
match identifier {
|
||||
Identifier::ServerName(server_name) => {
|
||||
| Identifier::ServerName(server_name) => {
|
||||
// Checks if ACLs allow for the server to participate
|
||||
if self
|
||||
.services
|
||||
|
@ -626,7 +676,7 @@ impl Service {
|
|||
return false;
|
||||
}
|
||||
},
|
||||
Identifier::UserId(user_id) => {
|
||||
| Identifier::UserId(user_id) => {
|
||||
if self
|
||||
.services
|
||||
.state_cache
|
||||
|
@ -642,16 +692,18 @@ impl Service {
|
|||
},
|
||||
}
|
||||
match &join_rule {
|
||||
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
|
||||
SpaceRoomJoinRule::Restricted => {
|
||||
| SpaceRoomJoinRule::Public
|
||||
| SpaceRoomJoinRule::Knock
|
||||
| SpaceRoomJoinRule::KnockRestricted => true,
|
||||
| SpaceRoomJoinRule::Restricted => {
|
||||
for room in allowed_room_ids {
|
||||
match identifier {
|
||||
Identifier::UserId(user) => {
|
||||
| Identifier::UserId(user) => {
|
||||
if self.services.state_cache.is_joined(user, room).await {
|
||||
return true;
|
||||
}
|
||||
},
|
||||
Identifier::ServerName(server) => {
|
||||
| Identifier::ServerName(server) => {
|
||||
if self.services.state_cache.server_in_room(server, room).await {
|
||||
return true;
|
||||
}
|
||||
|
@ -661,7 +713,7 @@ impl Service {
|
|||
false
|
||||
},
|
||||
// Invite only, Private, or Custom join rule
|
||||
_ => false,
|
||||
| _ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -737,7 +789,8 @@ fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRooms
|
|||
/// Returns the children of a SpaceHierarchyParentSummary, making use of the
|
||||
/// children_state field
|
||||
fn get_parent_children_via(
|
||||
parent: &SpaceHierarchyParentSummary, suggested_only: bool,
|
||||
parent: &SpaceHierarchyParentSummary,
|
||||
suggested_only: bool,
|
||||
) -> Vec<(OwnedRoomId, Vec<OwnedServerName>)> {
|
||||
parent
|
||||
.children_state
|
||||
|
@ -755,7 +808,8 @@ fn get_parent_children_via(
|
|||
}
|
||||
|
||||
fn next_room_to_traverse(
|
||||
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>,
|
||||
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>,
|
||||
parents: &mut VecDeque<OwnedRoomId>,
|
||||
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> {
|
||||
while stack.last().is_some_and(Vec::is_empty) {
|
||||
stack.pop();
|
||||
|
|
|
@ -69,18 +69,15 @@ fn get_summary_children() {
|
|||
}
|
||||
.into();
|
||||
|
||||
assert_eq!(
|
||||
get_parent_children_via(&summary, false),
|
||||
vec![
|
||||
(owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]),
|
||||
(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]),
|
||||
(owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")])
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
get_parent_children_via(&summary, true),
|
||||
vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])]
|
||||
);
|
||||
assert_eq!(get_parent_children_via(&summary, false), vec![
|
||||
(owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]),
|
||||
(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]),
|
||||
(owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")])
|
||||
]);
|
||||
assert_eq!(get_parent_children_via(&summary, true), vec![(
|
||||
owned_room_id!("!bar:example.org"),
|
||||
vec![owned_server_name!("example.org")]
|
||||
)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -16,7 +16,9 @@ use conduwuit::{
|
|||
warn, PduEvent, Result,
|
||||
};
|
||||
use database::{Deserialized, Ignore, Interfix, Map};
|
||||
use futures::{future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use futures::{
|
||||
future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
|
||||
};
|
||||
use ruma::{
|
||||
events::{
|
||||
room::{create::RoomCreateEventContent, member::RoomMemberEventContent},
|
||||
|
@ -70,8 +72,10 @@ impl crate::Service for Service {
|
|||
short: args.depend::<rooms::short::Service>("rooms::short"),
|
||||
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
|
||||
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_compressor: args
|
||||
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
|
||||
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
|
||||
},
|
||||
db: Data {
|
||||
|
@ -100,7 +104,8 @@ impl Service {
|
|||
shortstatehash: u64,
|
||||
statediffnew: Arc<HashSet<CompressedStateEvent>>,
|
||||
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
|
||||
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
|
||||
* mutex */
|
||||
) -> Result {
|
||||
let event_ids = statediffnew
|
||||
.iter()
|
||||
|
@ -120,8 +125,9 @@ impl Service {
|
|||
};
|
||||
|
||||
match pdu.kind {
|
||||
TimelineEventType::RoomMember => {
|
||||
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok() else {
|
||||
| TimelineEventType::RoomMember => {
|
||||
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok()
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
|
@ -131,10 +137,18 @@ impl Service {
|
|||
|
||||
self.services
|
||||
.state_cache
|
||||
.update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false)
|
||||
.update_membership(
|
||||
room_id,
|
||||
&user_id,
|
||||
membership_event,
|
||||
&pdu.sender,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
},
|
||||
TimelineEventType::SpaceChild => {
|
||||
| TimelineEventType::SpaceChild => {
|
||||
self.services
|
||||
.spaces
|
||||
.roomid_spacehierarchy_cache
|
||||
|
@ -142,7 +156,7 @@ impl Service {
|
|||
.await
|
||||
.remove(&pdu.room_id);
|
||||
},
|
||||
_ => continue,
|
||||
| _ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -159,7 +173,10 @@ impl Service {
|
|||
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
|
||||
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")]
|
||||
pub async fn set_event_state(
|
||||
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
|
||||
&self,
|
||||
event_id: &EventId,
|
||||
room_id: &RoomId,
|
||||
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
|
||||
) -> Result<ShortStateHash> {
|
||||
const KEY_LEN: usize = size_of::<ShortEventId>();
|
||||
const VAL_LEN: usize = size_of::<ShortStateHash>();
|
||||
|
@ -190,22 +207,23 @@ impl Service {
|
|||
Vec::new()
|
||||
};
|
||||
|
||||
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew: HashSet<_> = state_ids_compressed
|
||||
.difference(&parent_stateinfo.full_state)
|
||||
.copied()
|
||||
.collect();
|
||||
let (statediffnew, statediffremoved) =
|
||||
if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew: HashSet<_> = state_ids_compressed
|
||||
.difference(&parent_stateinfo.full_state)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let statediffremoved: HashSet<_> = parent_stateinfo
|
||||
.full_state
|
||||
.difference(&state_ids_compressed)
|
||||
.copied()
|
||||
.collect();
|
||||
let statediffremoved: HashSet<_> = parent_stateinfo
|
||||
.full_state
|
||||
.difference(&state_ids_compressed)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
(Arc::new(statediffnew), Arc::new(statediffremoved))
|
||||
} else {
|
||||
(state_ids_compressed, Arc::new(HashSet::new()))
|
||||
};
|
||||
(Arc::new(statediffnew), Arc::new(statediffremoved))
|
||||
} else {
|
||||
(state_ids_compressed, Arc::new(HashSet::new()))
|
||||
};
|
||||
self.services.state_compressor.save_state_from_diff(
|
||||
shortstatehash,
|
||||
statediffnew,
|
||||
|
@ -338,7 +356,8 @@ impl Service {
|
|||
&self,
|
||||
room_id: &RoomId,
|
||||
shortstatehash: u64,
|
||||
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
|
||||
* state mutex */
|
||||
) {
|
||||
const BUFSIZE: usize = size_of::<u64>();
|
||||
|
||||
|
@ -366,7 +385,10 @@ impl Service {
|
|||
.deserialized()
|
||||
}
|
||||
|
||||
pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &EventId> + Send + '_ {
|
||||
pub fn get_forward_extremities<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &EventId> + Send + '_ {
|
||||
let prefix = (room_id, Interfix);
|
||||
|
||||
self.db
|
||||
|
@ -380,7 +402,8 @@ impl Service {
|
|||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<OwnedEventId>,
|
||||
_state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
_state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
|
||||
* state mutex */
|
||||
) {
|
||||
let prefix = (room_id, Interfix);
|
||||
self.db
|
||||
|
@ -399,26 +422,33 @@ impl Service {
|
|||
/// This fetches auth events from the current state.
|
||||
#[tracing::instrument(skip(self, content), level = "debug")]
|
||||
pub async fn get_auth_events(
|
||||
&self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
kind: &TimelineEventType,
|
||||
sender: &UserId,
|
||||
state_key: Option<&str>,
|
||||
content: &serde_json::value::RawValue,
|
||||
) -> Result<StateMap<Arc<PduEvent>>> {
|
||||
let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
|
||||
return Ok(HashMap::new());
|
||||
};
|
||||
|
||||
let mut sauthevents: HashMap<_, _> = state_res::auth_types_for_event(kind, sender, state_key, content)?
|
||||
.iter()
|
||||
.stream()
|
||||
.broad_filter_map(|(event_type, state_key)| {
|
||||
self.services
|
||||
.short
|
||||
.get_shortstatekey(event_type, state_key)
|
||||
.map_ok(move |ssk| (ssk, (event_type, state_key)))
|
||||
.map(Result::ok)
|
||||
})
|
||||
.map(|(ssk, (event_type, state_key))| (ssk, (event_type.to_owned(), state_key.to_owned())))
|
||||
.collect()
|
||||
.await;
|
||||
let mut sauthevents: HashMap<_, _> =
|
||||
state_res::auth_types_for_event(kind, sender, state_key, content)?
|
||||
.iter()
|
||||
.stream()
|
||||
.broad_filter_map(|(event_type, state_key)| {
|
||||
self.services
|
||||
.short
|
||||
.get_shortstatekey(event_type, state_key)
|
||||
.map_ok(move |ssk| (ssk, (event_type, state_key)))
|
||||
.map(Result::ok)
|
||||
})
|
||||
.map(|(ssk, (event_type, state_key))| {
|
||||
(ssk, (event_type.to_owned(), state_key.to_owned()))
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let (state_keys, event_ids): (Vec<_>, Vec<_>) = self
|
||||
.services
|
||||
|
|
|
@ -39,14 +39,16 @@ impl Data {
|
|||
services: Services {
|
||||
short: args.depend::<rooms::short::Service>("rooms::short"),
|
||||
state: args.depend::<rooms::state::Service>("rooms::state"),
|
||||
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
|
||||
state_compressor: args
|
||||
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
|
||||
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn state_full(
|
||||
&self, shortstatehash: ShortStateHash,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
|
||||
let state = self
|
||||
.state_full_pdus(shortstatehash)
|
||||
|
@ -58,7 +60,10 @@ impl Data {
|
|||
Ok(state)
|
||||
}
|
||||
|
||||
pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> {
|
||||
pub(super) async fn state_full_pdus(
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<Vec<PduEvent>> {
|
||||
let short_ids = self.state_full_shortids(shortstatehash).await?;
|
||||
|
||||
let full_pdus = self
|
||||
|
@ -66,16 +71,19 @@ impl Data {
|
|||
.short
|
||||
.multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1)))
|
||||
.ready_filter_map(Result::ok)
|
||||
.broad_filter_map(
|
||||
|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() },
|
||||
)
|
||||
.broad_filter_map(|event_id: OwnedEventId| async move {
|
||||
self.services.timeline.get_pdu(&event_id).await.ok()
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
Ok(full_pdus)
|
||||
}
|
||||
|
||||
pub(super) async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
|
||||
pub(super) async fn state_full_ids<Id>(
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<HashMap<ShortStateKey, Id>>
|
||||
where
|
||||
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
|
||||
<Id as ToOwned>::Owned: Borrow<EventId>,
|
||||
|
@ -96,7 +104,8 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) async fn state_full_shortids(
|
||||
&self, shortstatehash: ShortStateHash,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
|
||||
let shortids = self
|
||||
.services
|
||||
|
@ -118,7 +127,10 @@ impl Data {
|
|||
/// Returns a single EventId from `room_id` with key
|
||||
/// (`event_type`,`state_key`).
|
||||
pub(super) async fn state_get_id<Id>(
|
||||
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Id>
|
||||
where
|
||||
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
|
||||
|
@ -155,10 +167,15 @@ impl Data {
|
|||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
|
||||
pub(super) async fn state_get(
|
||||
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<PduEvent> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)
|
||||
.and_then(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await })
|
||||
.and_then(|event_id: OwnedEventId| async move {
|
||||
self.services.timeline.get_pdu(&event_id).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
|
@ -179,7 +196,8 @@ impl Data {
|
|||
|
||||
/// Returns the full room state.
|
||||
pub(super) async fn room_state_full(
|
||||
&self, room_id: &RoomId,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
|
||||
self.services
|
||||
.state
|
||||
|
@ -203,7 +221,10 @@ impl Data {
|
|||
/// Returns a single EventId from `room_id` with key
|
||||
/// (`event_type`,`state_key`).
|
||||
pub(super) async fn room_state_get_id<Id>(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Id>
|
||||
where
|
||||
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
|
||||
|
@ -218,7 +239,10 @@ impl Data {
|
|||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
|
||||
pub(super) async fn room_state_get(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<PduEvent> {
|
||||
self.services
|
||||
.state
|
||||
|
|
|
@ -34,8 +34,8 @@ use ruma::{
|
|||
},
|
||||
room::RoomType,
|
||||
space::SpaceRoomJoinRule,
|
||||
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
|
||||
ServerName, UserId,
|
||||
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName,
|
||||
OwnedUserId, RoomId, ServerName, UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
|
@ -75,8 +75,12 @@ impl crate::Service for Service {
|
|||
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
|
||||
},
|
||||
db: Data::new(&args),
|
||||
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)),
|
||||
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)),
|
||||
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
|
||||
server_visibility_cache_capacity,
|
||||
)?)),
|
||||
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
|
||||
user_visibility_cache_capacity,
|
||||
)?)),
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -102,7 +106,10 @@ impl Service {
|
|||
/// Builds a StateMap by iterating over all keys that start
|
||||
/// with state_hash, this gives the full state for the given state_hash.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
|
||||
pub async fn state_full_ids<Id>(
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<HashMap<ShortStateKey, Id>>
|
||||
where
|
||||
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
|
||||
<Id as ToOwned>::Owned: Borrow<EventId>,
|
||||
|
@ -112,13 +119,15 @@ impl Service {
|
|||
|
||||
#[inline]
|
||||
pub async fn state_full_shortids(
|
||||
&self, shortstatehash: ShortStateHash,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
|
||||
self.db.state_full_shortids(shortstatehash).await
|
||||
}
|
||||
|
||||
pub async fn state_full(
|
||||
&self, shortstatehash: ShortStateHash,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
|
||||
self.db.state_full(shortstatehash).await
|
||||
}
|
||||
|
@ -127,7 +136,10 @@ impl Service {
|
|||
/// `state_key`).
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn state_get_id<Id>(
|
||||
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Id>
|
||||
where
|
||||
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
|
||||
|
@ -142,7 +154,10 @@ impl Service {
|
|||
/// `state_key`).
|
||||
#[inline]
|
||||
pub async fn state_get(
|
||||
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<PduEvent> {
|
||||
self.db
|
||||
.state_get(shortstatehash, event_type, state_key)
|
||||
|
@ -151,7 +166,10 @@ impl Service {
|
|||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
|
||||
pub async fn state_get_content<T>(
|
||||
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<T>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
|
@ -162,7 +180,11 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Get membership for given user in state
|
||||
async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState {
|
||||
async fn user_membership(
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
user_id: &UserId,
|
||||
) -> MembershipState {
|
||||
self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str())
|
||||
.await
|
||||
.map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership)
|
||||
|
@ -185,7 +207,12 @@ impl Service {
|
|||
/// Whether a server is allowed to see an event through federation, based on
|
||||
/// the room's history_visibility at that event's state.
|
||||
#[tracing::instrument(skip_all, level = "trace")]
|
||||
pub async fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> bool {
|
||||
pub async fn server_can_see_event(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
) -> bool {
|
||||
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
|
||||
return true;
|
||||
};
|
||||
|
@ -213,20 +240,20 @@ impl Service {
|
|||
.ready_filter(|member| member.server_name() == origin);
|
||||
|
||||
let visibility = match history_visibility {
|
||||
HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
|
||||
HistoryVisibility::Invited => {
|
||||
| HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
|
||||
| HistoryVisibility::Invited => {
|
||||
// Allow if any member on requesting server was AT LEAST invited, else deny
|
||||
current_server_members
|
||||
.any(|member| self.user_was_invited(shortstatehash, member))
|
||||
.await
|
||||
},
|
||||
HistoryVisibility::Joined => {
|
||||
| HistoryVisibility::Joined => {
|
||||
// Allow if any member on requested server was joined, else deny
|
||||
current_server_members
|
||||
.any(|member| self.user_was_joined(shortstatehash, member))
|
||||
.await
|
||||
},
|
||||
_ => {
|
||||
| _ => {
|
||||
error!("Unknown history visibility {history_visibility}");
|
||||
false
|
||||
},
|
||||
|
@ -243,7 +270,12 @@ impl Service {
|
|||
/// Whether a user is allowed to see an event, based on
|
||||
/// the room's history_visibility at that event's state.
|
||||
#[tracing::instrument(skip_all, level = "trace")]
|
||||
pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool {
|
||||
pub async fn user_can_see_event(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
) -> bool {
|
||||
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
|
||||
return true;
|
||||
};
|
||||
|
@ -267,17 +299,17 @@ impl Service {
|
|||
});
|
||||
|
||||
let visibility = match history_visibility {
|
||||
HistoryVisibility::WorldReadable => true,
|
||||
HistoryVisibility::Shared => currently_member,
|
||||
HistoryVisibility::Invited => {
|
||||
| HistoryVisibility::WorldReadable => true,
|
||||
| HistoryVisibility::Shared => currently_member,
|
||||
| HistoryVisibility::Invited => {
|
||||
// Allow if any member on requesting server was AT LEAST invited, else deny
|
||||
self.user_was_invited(shortstatehash, user_id).await
|
||||
},
|
||||
HistoryVisibility::Joined => {
|
||||
| HistoryVisibility::Joined => {
|
||||
// Allow if any member on requested server was joined, else deny
|
||||
self.user_was_joined(shortstatehash, user_id).await
|
||||
},
|
||||
_ => {
|
||||
| _ => {
|
||||
error!("Unknown history visibility {history_visibility}");
|
||||
false
|
||||
},
|
||||
|
@ -307,9 +339,10 @@ impl Service {
|
|||
});
|
||||
|
||||
match history_visibility {
|
||||
HistoryVisibility::Invited => self.services.state_cache.is_invited(user_id, room_id).await,
|
||||
HistoryVisibility::WorldReadable => true,
|
||||
_ => false,
|
||||
| HistoryVisibility::Invited =>
|
||||
self.services.state_cache.is_invited(user_id, room_id).await,
|
||||
| HistoryVisibility::WorldReadable => true,
|
||||
| _ => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -320,7 +353,10 @@ impl Service {
|
|||
|
||||
/// Returns the full room state.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), PduEvent>> {
|
||||
pub async fn room_state_full(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
|
||||
self.db.room_state_full(room_id).await
|
||||
}
|
||||
|
||||
|
@ -334,7 +370,10 @@ impl Service {
|
|||
/// `state_key`).
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn room_state_get_id<Id>(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Id>
|
||||
where
|
||||
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
|
||||
|
@ -349,14 +388,20 @@ impl Service {
|
|||
/// `state_key`).
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn room_state_get(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<PduEvent> {
|
||||
self.db.room_state_get(room_id, event_type, state_key).await
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
|
||||
pub async fn room_state_get_content<T>(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<T>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
|
@ -381,18 +426,29 @@ impl Service {
|
|||
JsOption::from_option(content)
|
||||
}
|
||||
|
||||
pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<RoomMemberEventContent> {
|
||||
pub async fn get_member(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<RoomMemberEventContent> {
|
||||
self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn user_can_invite(
|
||||
&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
sender: &UserId,
|
||||
target_user: &UserId,
|
||||
state_lock: &RoomMutexGuard,
|
||||
) -> bool {
|
||||
self.services
|
||||
.timeline
|
||||
.create_hash_and_sign_event(
|
||||
PduBuilder::state(target_user.into(), &RoomMemberEventContent::new(MembershipState::Invite)),
|
||||
PduBuilder::state(
|
||||
target_user.into(),
|
||||
&RoomMemberEventContent::new(MembershipState::Invite),
|
||||
),
|
||||
sender,
|
||||
room_id,
|
||||
state_lock,
|
||||
|
@ -405,7 +461,9 @@ impl Service {
|
|||
pub async fn is_world_readable(&self, room_id: &RoomId) -> bool {
|
||||
self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "")
|
||||
.await
|
||||
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable)
|
||||
.map(|c: RoomHistoryVisibilityEventContent| {
|
||||
c.history_visibility == HistoryVisibility::WorldReadable
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
|
@ -439,7 +497,11 @@ impl Service {
|
|||
/// If federation is true, it allows redaction events from any user of the
|
||||
/// same server as the original event sender
|
||||
pub async fn user_can_redact(
|
||||
&self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool,
|
||||
&self,
|
||||
redacts: &EventId,
|
||||
sender: &UserId,
|
||||
room_id: &RoomId,
|
||||
federation: bool,
|
||||
) -> Result<bool> {
|
||||
let redacting_event = self.services.timeline.get_pdu(redacts).await;
|
||||
|
||||
|
@ -451,7 +513,11 @@ impl Service {
|
|||
}
|
||||
|
||||
if let Ok(pl_event_content) = self
|
||||
.room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "")
|
||||
.room_state_get_content::<RoomPowerLevelsEventContent>(
|
||||
room_id,
|
||||
&StateEventType::RoomPowerLevels,
|
||||
"",
|
||||
)
|
||||
.await
|
||||
{
|
||||
let pl_event: RoomPowerLevels = pl_event_content.into();
|
||||
|
@ -485,10 +551,15 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Returns the join rule (`SpaceRoomJoinRule`) for a given room
|
||||
pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
|
||||
pub async fn get_join_rule(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
|
||||
self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
|
||||
.await
|
||||
.map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)))
|
||||
.map(|c: RoomJoinRulesEventContent| {
|
||||
(c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))
|
||||
})
|
||||
.or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![])))
|
||||
}
|
||||
|
||||
|
@ -497,10 +568,7 @@ impl Service {
|
|||
let mut room_ids = Vec::with_capacity(1);
|
||||
if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule {
|
||||
for rule in r.allow {
|
||||
if let AllowRule::RoomMembership(RoomMembership {
|
||||
room_id: membership,
|
||||
}) = rule
|
||||
{
|
||||
if let AllowRule::RoomMembership(RoomMembership { room_id: membership }) = rule {
|
||||
room_ids.push(membership.clone());
|
||||
}
|
||||
}
|
||||
|
@ -520,7 +588,10 @@ impl Service {
|
|||
|
||||
/// Gets the room's encryption algorithm if `m.room.encryption` state event
|
||||
/// is found
|
||||
pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result<EventEncryptionAlgorithm> {
|
||||
pub async fn get_room_encryption(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<EventEncryptionAlgorithm> {
|
||||
self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "")
|
||||
.await
|
||||
.map(|content: RoomEncryptionEventContent| content.algorithm)
|
||||
|
|
|
@ -20,7 +20,8 @@ use ruma::{
|
|||
member::{MembershipState, RoomMemberEventContent},
|
||||
power_levels::RoomPowerLevelsEventContent,
|
||||
},
|
||||
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
|
||||
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
|
||||
RoomAccountDataEventType, StateEventType,
|
||||
},
|
||||
int,
|
||||
serde::Raw,
|
||||
|
@ -68,7 +69,8 @@ impl crate::Service for Service {
|
|||
services: Services {
|
||||
account_data: args.depend::<account_data::Service>("account_data"),
|
||||
globals: args.depend::<globals::Service>("globals"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
users: args.depend::<users::Service>("users"),
|
||||
},
|
||||
db: Data {
|
||||
|
@ -96,8 +98,13 @@ impl Service {
|
|||
#[tracing::instrument(skip(self, last_state))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn update_membership(
|
||||
&self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId,
|
||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, invite_via: Option<Vec<OwnedServerName>>,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
membership_event: RoomMemberEventContent,
|
||||
sender: &UserId,
|
||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||
invite_via: Option<Vec<OwnedServerName>>,
|
||||
update_joined_count: bool,
|
||||
) -> Result<()> {
|
||||
let membership = membership_event.membership;
|
||||
|
@ -138,7 +145,7 @@ impl Service {
|
|||
}
|
||||
|
||||
match &membership {
|
||||
MembershipState::Join => {
|
||||
| MembershipState::Join => {
|
||||
// Check if the user never joined this room
|
||||
if !self.once_joined(user_id, room_id).await {
|
||||
// Add the user ID to the join list then
|
||||
|
@ -181,12 +188,21 @@ impl Service {
|
|||
if let Ok(tag_event) = self
|
||||
.services
|
||||
.account_data
|
||||
.get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
|
||||
.get_room(
|
||||
&predecessor.room_id,
|
||||
user_id,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)
|
||||
.await
|
||||
{
|
||||
self.services
|
||||
.account_data
|
||||
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event)
|
||||
.update(
|
||||
Some(room_id),
|
||||
user_id,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&tag_event,
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
};
|
||||
|
@ -195,7 +211,10 @@ impl Service {
|
|||
if let Ok(mut direct_event) = self
|
||||
.services
|
||||
.account_data
|
||||
.get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
|
||||
.get_global::<DirectEvent>(
|
||||
user_id,
|
||||
GlobalAccountDataEventType::Direct,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut room_ids_updated = false;
|
||||
|
@ -213,7 +232,8 @@ impl Service {
|
|||
None,
|
||||
user_id,
|
||||
GlobalAccountDataEventType::Direct.to_string().into(),
|
||||
&serde_json::to_value(&direct_event).expect("to json always works"),
|
||||
&serde_json::to_value(&direct_event)
|
||||
.expect("to json always works"),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
@ -223,7 +243,7 @@ impl Service {
|
|||
|
||||
self.mark_as_joined(user_id, room_id);
|
||||
},
|
||||
MembershipState::Invite => {
|
||||
| MembershipState::Invite => {
|
||||
// We want to know if the sender is ignored by the receiver
|
||||
if self.services.users.user_is_ignored(sender, user_id).await {
|
||||
return Ok(());
|
||||
|
@ -232,10 +252,10 @@ impl Service {
|
|||
self.mark_as_invited(user_id, room_id, last_state, invite_via)
|
||||
.await;
|
||||
},
|
||||
MembershipState::Leave | MembershipState::Ban => {
|
||||
| MembershipState::Leave | MembershipState::Ban => {
|
||||
self.mark_as_left(user_id, room_id);
|
||||
},
|
||||
_ => {},
|
||||
| _ => {},
|
||||
}
|
||||
|
||||
if update_joined_count {
|
||||
|
@ -246,7 +266,11 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
|
||||
pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool {
|
||||
pub async fn appservice_in_room(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
appservice: &RegistrationInfo,
|
||||
) -> bool {
|
||||
if let Some(cached) = self
|
||||
.appservice_in_room_cache
|
||||
.read()
|
||||
|
@ -347,7 +371,10 @@ impl Service {
|
|||
|
||||
/// Returns an iterator of all servers participating in this room.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
|
||||
pub fn room_servers<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &ServerName> + Send + 'a {
|
||||
let prefix = (room_id, Interfix);
|
||||
self.db
|
||||
.roomserverids
|
||||
|
@ -357,7 +384,11 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool {
|
||||
pub async fn server_in_room<'a>(
|
||||
&'a self,
|
||||
server: &'a ServerName,
|
||||
room_id: &'a RoomId,
|
||||
) -> bool {
|
||||
let key = (server, room_id);
|
||||
self.db.serverroomids.qry(&key).await.is_ok()
|
||||
}
|
||||
|
@ -365,7 +396,10 @@ impl Service {
|
|||
/// Returns an iterator of all rooms a server participates in (as far as we
|
||||
/// know).
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream<Item = &RoomId> + Send + 'a {
|
||||
pub fn server_rooms<'a>(
|
||||
&'a self,
|
||||
server: &'a ServerName,
|
||||
) -> impl Stream<Item = &RoomId> + Send + 'a {
|
||||
let prefix = (server, Interfix);
|
||||
self.db
|
||||
.serverroomids
|
||||
|
@ -393,7 +427,9 @@ impl Service {
|
|||
|
||||
/// List the rooms common between two users
|
||||
pub fn get_shared_rooms<'a>(
|
||||
&'a self, user_a: &'a UserId, user_b: &'a UserId,
|
||||
&'a self,
|
||||
user_a: &'a UserId,
|
||||
user_b: &'a UserId,
|
||||
) -> impl Stream<Item = &RoomId> + Send + 'a {
|
||||
use conduwuit::utils::set;
|
||||
|
||||
|
@ -404,7 +440,10 @@ impl Service {
|
|||
|
||||
/// Returns an iterator of all joined members of a room.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
pub fn room_members<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
let prefix = (room_id, Interfix);
|
||||
self.db
|
||||
.roomuserid_joined
|
||||
|
@ -422,7 +461,10 @@ impl Service {
|
|||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
/// Returns an iterator of all our local users in the room, even if they're
|
||||
/// deactivated/guests
|
||||
pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
pub fn local_users_in_room<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
self.room_members(room_id)
|
||||
.ready_filter(|user| self.services.globals.user_is_local(user))
|
||||
}
|
||||
|
@ -430,7 +472,10 @@ impl Service {
|
|||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
/// Returns an iterator of all our local joined users in a room who are
|
||||
/// active (not deactivated, not guest)
|
||||
pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
pub fn active_local_users_in_room<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
self.local_users_in_room(room_id)
|
||||
.filter(|user| self.services.users.is_active(user))
|
||||
}
|
||||
|
@ -447,7 +492,10 @@ impl Service {
|
|||
|
||||
/// Returns an iterator over all User IDs who ever joined a room.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
pub fn room_useroncejoined<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
let prefix = (room_id, Interfix);
|
||||
self.db
|
||||
.roomuseroncejoinedids
|
||||
|
@ -458,7 +506,10 @@ impl Service {
|
|||
|
||||
/// Returns an iterator over all invited members of a room.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
pub fn room_members_invited<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
let prefix = (room_id, Interfix);
|
||||
self.db
|
||||
.roomuserid_invitecount
|
||||
|
@ -485,7 +536,10 @@ impl Service {
|
|||
|
||||
/// Returns an iterator over all rooms this user joined.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &RoomId> + Send + 'a {
|
||||
pub fn rooms_joined<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = &RoomId> + Send + 'a {
|
||||
self.db
|
||||
.userroomid_joined
|
||||
.keys_raw_prefix(user_id)
|
||||
|
@ -495,7 +549,10 @@ impl Service {
|
|||
|
||||
/// Returns an iterator over all rooms a user was invited to.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn rooms_invited<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
|
||||
pub fn rooms_invited<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
|
||||
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
|
||||
type Key<'a> = (&'a UserId, &'a RoomId);
|
||||
|
||||
|
@ -510,30 +567,45 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
|
||||
pub async fn invite_state(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
|
||||
let key = (user_id, room_id);
|
||||
self.db
|
||||
.userroomid_invitestate
|
||||
.qry(&key)
|
||||
.await
|
||||
.deserialized()
|
||||
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
|
||||
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
|
||||
val.deserialize_as().map_err(Into::into)
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
|
||||
pub async fn left_state(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
|
||||
let key = (user_id, room_id);
|
||||
self.db
|
||||
.userroomid_leftstate
|
||||
.qry(&key)
|
||||
.await
|
||||
.deserialized()
|
||||
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
|
||||
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
|
||||
val.deserialize_as().map_err(Into::into)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator over all rooms a user left.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
|
||||
pub fn rooms_left<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
|
||||
type KeyVal<'a> = (Key<'a>, Raw<Vec<Raw<AnySyncStateEvent>>>);
|
||||
type Key<'a> = (&'a UserId, &'a RoomId);
|
||||
|
||||
|
@ -571,7 +643,11 @@ impl Service {
|
|||
self.db.userroomid_leftstate.qry(&key).await.is_ok()
|
||||
}
|
||||
|
||||
pub async fn user_membership(&self, user_id: &UserId, room_id: &RoomId) -> Option<MembershipState> {
|
||||
pub async fn user_membership(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Option<MembershipState> {
|
||||
let states = join4(
|
||||
self.is_joined(user_id, room_id),
|
||||
self.is_left(user_id, room_id),
|
||||
|
@ -581,16 +657,19 @@ impl Service {
|
|||
.await;
|
||||
|
||||
match states {
|
||||
(true, ..) => Some(MembershipState::Join),
|
||||
(_, true, ..) => Some(MembershipState::Leave),
|
||||
(_, _, true, ..) => Some(MembershipState::Invite),
|
||||
(false, false, false, true) => Some(MembershipState::Ban),
|
||||
_ => None,
|
||||
| (true, ..) => Some(MembershipState::Join),
|
||||
| (_, true, ..) => Some(MembershipState::Leave),
|
||||
| (_, _, true, ..) => Some(MembershipState::Invite),
|
||||
| (false, false, false, true) => Some(MembershipState::Ban),
|
||||
| _ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
|
||||
pub fn servers_invite_via<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = &ServerName> + Send + 'a {
|
||||
type KeyVal<'a> = (Ignore, Vec<&'a ServerName>);
|
||||
|
||||
self.db
|
||||
|
@ -711,7 +790,10 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn mark_as_invited(
|
||||
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||
invite_via: Option<Vec<OwnedServerName>>,
|
||||
) {
|
||||
let roomuser_id = (room_id, user_id);
|
||||
|
|
|
@ -69,7 +69,8 @@ pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
|
|||
impl crate::Service for Service {
|
||||
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let config = &args.server.config;
|
||||
let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
|
||||
let cache_capacity =
|
||||
f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
|
||||
Ok(Arc::new(Self {
|
||||
stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(),
|
||||
db: Data {
|
||||
|
@ -85,17 +86,16 @@ impl crate::Service for Service {
|
|||
fn memory_usage(&self, out: &mut dyn Write) -> Result {
|
||||
let (cache_len, ents) = {
|
||||
let cache = self.stateinfo_cache.lock().expect("locked");
|
||||
let ents = cache
|
||||
.iter()
|
||||
.map(at!(1))
|
||||
.flat_map(|vec| vec.iter())
|
||||
.fold(HashMap::new(), |mut ents, ssi| {
|
||||
let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
|
||||
HashMap::new(),
|
||||
|mut ents, ssi| {
|
||||
for cs in &[&ssi.added, &ssi.removed, &ssi.full_state] {
|
||||
ents.insert(Arc::as_ptr(cs), compressed_state_size(cs));
|
||||
}
|
||||
|
||||
ents
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
(cache.len(), ents)
|
||||
};
|
||||
|
@ -117,7 +117,10 @@ impl crate::Service for Service {
|
|||
impl Service {
|
||||
/// Returns a stack with info on shortstatehash, full state, added diff and
|
||||
/// removed diff for the selected shortstatehash and each parent layer.
|
||||
pub async fn load_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> {
|
||||
pub async fn load_shortstatehash_info(
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<ShortStateInfoVec> {
|
||||
if let Some(r) = self
|
||||
.stateinfo_cache
|
||||
.lock()
|
||||
|
@ -143,12 +146,11 @@ impl Service {
|
|||
Ok(stack)
|
||||
}
|
||||
|
||||
async fn new_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> {
|
||||
let StateDiff {
|
||||
parent,
|
||||
added,
|
||||
removed,
|
||||
} = self.get_statediff(shortstatehash).await?;
|
||||
async fn new_shortstatehash_info(
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
) -> Result<ShortStateInfoVec> {
|
||||
let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
|
||||
|
||||
let Some(parent) = parent else {
|
||||
return Ok(vec![ShortStateInfo {
|
||||
|
@ -180,9 +182,17 @@ impl Service {
|
|||
Ok(stack)
|
||||
}
|
||||
|
||||
pub fn compress_state_events<'a, I>(&'a self, state: I) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
|
||||
pub fn compress_state_events<'a, I>(
|
||||
&'a self,
|
||||
state: I,
|
||||
) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
|
||||
where
|
||||
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)>
|
||||
+ Clone
|
||||
+ Debug
|
||||
+ ExactSizeIterator
|
||||
+ Send
|
||||
+ 'a,
|
||||
{
|
||||
let event_ids = state.clone().map(at!(1));
|
||||
|
||||
|
@ -195,10 +205,16 @@ impl Service {
|
|||
.stream()
|
||||
.map(at!(0))
|
||||
.zip(short_event_ids)
|
||||
.map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid))
|
||||
.map(|(shortstatekey, shorteventid)| {
|
||||
compress_state_event(*shortstatekey, shorteventid)
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent {
|
||||
pub async fn compress_state_event(
|
||||
&self,
|
||||
shortstatekey: ShortStateKey,
|
||||
event_id: &EventId,
|
||||
) -> CompressedStateEvent {
|
||||
let shorteventid = self
|
||||
.services
|
||||
.short
|
||||
|
@ -227,8 +243,11 @@ impl Service {
|
|||
/// * `parent_states` - A stack with info on shortstatehash, full state,
|
||||
/// added diff and removed diff for each parent layer
|
||||
pub fn save_state_from_diff(
|
||||
&self, shortstatehash: ShortStateHash, statediffnew: Arc<HashSet<CompressedStateEvent>>,
|
||||
statediffremoved: Arc<HashSet<CompressedStateEvent>>, diff_to_sibling: usize,
|
||||
&self,
|
||||
shortstatehash: ShortStateHash,
|
||||
statediffnew: Arc<HashSet<CompressedStateEvent>>,
|
||||
statediffremoved: Arc<HashSet<CompressedStateEvent>>,
|
||||
diff_to_sibling: usize,
|
||||
mut parent_states: ParentStatesVec,
|
||||
) -> Result {
|
||||
let statediffnew_len = statediffnew.len();
|
||||
|
@ -274,14 +293,11 @@ impl Service {
|
|||
|
||||
if parent_states.is_empty() {
|
||||
// There is no parent layer, create a new state
|
||||
self.save_statediff(
|
||||
shortstatehash,
|
||||
&StateDiff {
|
||||
parent: None,
|
||||
added: statediffnew,
|
||||
removed: statediffremoved,
|
||||
},
|
||||
);
|
||||
self.save_statediff(shortstatehash, &StateDiff {
|
||||
parent: None,
|
||||
added: statediffnew,
|
||||
removed: statediffremoved,
|
||||
});
|
||||
|
||||
return Ok(());
|
||||
};
|
||||
|
@ -327,14 +343,11 @@ impl Service {
|
|||
)?;
|
||||
} else {
|
||||
// Diff small enough, we add diff as layer on top of parent
|
||||
self.save_statediff(
|
||||
shortstatehash,
|
||||
&StateDiff {
|
||||
parent: Some(parent.shortstatehash),
|
||||
added: statediffnew,
|
||||
removed: statediffremoved,
|
||||
},
|
||||
);
|
||||
self.save_statediff(shortstatehash, &StateDiff {
|
||||
parent: Some(parent.shortstatehash),
|
||||
added: statediffnew,
|
||||
removed: statediffremoved,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -344,7 +357,9 @@ impl Service {
|
|||
/// room state
|
||||
#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
|
||||
pub async fn save_state(
|
||||
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
|
||||
) -> Result<HashSetCompressStateEvent> {
|
||||
let previous_shortstatehash = self
|
||||
.services
|
||||
|
@ -353,7 +368,8 @@ impl Service {
|
|||
.await
|
||||
.ok();
|
||||
|
||||
let state_hash = utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
|
||||
let state_hash =
|
||||
utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
|
||||
|
||||
let (new_shortstatehash, already_existed) = self
|
||||
.services
|
||||
|
@ -374,22 +390,23 @@ impl Service {
|
|||
ShortStateInfoVec::new()
|
||||
};
|
||||
|
||||
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew: HashSet<_> = new_state_ids_compressed
|
||||
.difference(&parent_stateinfo.full_state)
|
||||
.copied()
|
||||
.collect();
|
||||
let (statediffnew, statediffremoved) =
|
||||
if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew: HashSet<_> = new_state_ids_compressed
|
||||
.difference(&parent_stateinfo.full_state)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let statediffremoved: HashSet<_> = parent_stateinfo
|
||||
.full_state
|
||||
.difference(&new_state_ids_compressed)
|
||||
.copied()
|
||||
.collect();
|
||||
let statediffremoved: HashSet<_> = parent_stateinfo
|
||||
.full_state
|
||||
.difference(&new_state_ids_compressed)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
(Arc::new(statediffnew), Arc::new(statediffremoved))
|
||||
} else {
|
||||
(new_state_ids_compressed, Arc::new(HashSet::new()))
|
||||
};
|
||||
(Arc::new(statediffnew), Arc::new(statediffremoved))
|
||||
} else {
|
||||
(new_state_ids_compressed, Arc::new(HashSet::new()))
|
||||
};
|
||||
|
||||
if !already_existed {
|
||||
self.save_state_from_diff(
|
||||
|
@ -418,7 +435,9 @@ impl Service {
|
|||
.shortstatehash_statediff
|
||||
.aqry::<BUFSIZE, _>(&shortstatehash)
|
||||
.await
|
||||
.map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?;
|
||||
.map_err(|e| {
|
||||
err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
|
||||
})?;
|
||||
|
||||
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
|
||||
.ok()
|
||||
|
@ -484,7 +503,10 @@ impl Service {
|
|||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
fn compress_state_event(shortstatekey: ShortStateKey, shorteventid: ShortEventId) -> CompressedStateEvent {
|
||||
fn compress_state_event(
|
||||
shortstatekey: ShortStateKey,
|
||||
shorteventid: ShortEventId,
|
||||
) -> CompressedStateEvent {
|
||||
const SIZE: usize = size_of::<CompressedStateEvent>();
|
||||
|
||||
let mut v = ArrayVec::<u8, SIZE>::new();
|
||||
|
@ -497,7 +519,9 @@ fn compress_state_event(shortstatekey: ShortStateKey, shorteventid: ShortEventId
|
|||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn parse_compressed_state_event(compressed_event: CompressedStateEvent) -> (ShortStateKey, ShortEventId) {
|
||||
pub fn parse_compressed_state_event(
|
||||
compressed_event: CompressedStateEvent,
|
||||
) -> (ShortStateKey, ShortEventId) {
|
||||
use utils::u64_from_u8;
|
||||
|
||||
let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<ShortStateKey>()]);
|
||||
|
|
|
@ -11,8 +11,8 @@ use conduwuit::{
|
|||
use database::{Deserialized, Map};
|
||||
use futures::{Stream, StreamExt};
|
||||
use ruma::{
|
||||
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue,
|
||||
EventId, OwnedUserId, RoomId, UserId,
|
||||
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint,
|
||||
CanonicalJsonValue, EventId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
|
@ -55,7 +55,9 @@ impl Service {
|
|||
.timeline
|
||||
.get_pdu_id(root_event_id)
|
||||
.await
|
||||
.map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?;
|
||||
.map_err(|e| {
|
||||
err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}")))
|
||||
})?;
|
||||
|
||||
let root_pdu = self
|
||||
.services
|
||||
|
@ -79,8 +81,9 @@ impl Service {
|
|||
.get("m.relations")
|
||||
.and_then(|r| r.as_object())
|
||||
.and_then(|r| r.get("m.thread"))
|
||||
.and_then(|relations| serde_json::from_value::<BundledThread>(relations.clone().into()).ok())
|
||||
{
|
||||
.and_then(|relations| {
|
||||
serde_json::from_value::<BundledThread>(relations.clone().into()).ok()
|
||||
}) {
|
||||
// Thread already existed
|
||||
relations.count = relations.count.saturating_add(uint!(1));
|
||||
relations.latest_event = pdu.to_message_like_event();
|
||||
|
@ -129,7 +132,11 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn threads_until<'a>(
|
||||
&'a self, user_id: &'a UserId, room_id: &'a RoomId, shorteventid: PduCount, _inc: &'a IncludeThreads,
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
room_id: &'a RoomId,
|
||||
shorteventid: PduCount,
|
||||
_inc: &'a IncludeThreads,
|
||||
) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
|
||||
let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
|
||||
|
||||
|
@ -160,7 +167,11 @@ impl Service {
|
|||
Ok(stream)
|
||||
}
|
||||
|
||||
pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result {
|
||||
pub(super) fn update_participants(
|
||||
&self,
|
||||
root_id: &RawPduId,
|
||||
participants: &[OwnedUserId],
|
||||
) -> Result {
|
||||
let users = participants
|
||||
.iter()
|
||||
.map(|user| user.as_bytes())
|
||||
|
|
|
@ -13,7 +13,9 @@ use conduwuit::{
|
|||
};
|
||||
use database::{Database, Deserialized, Json, KeyVal, Map};
|
||||
use futures::{future::select_ok, FutureExt, Stream, StreamExt};
|
||||
use ruma::{api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
|
||||
use ruma::{
|
||||
api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::{PduId, RawPduId};
|
||||
|
@ -54,15 +56,19 @@ impl Data {
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> {
|
||||
pub(super) async fn last_timeline_count(
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
room_id: &RoomId,
|
||||
) -> Result<PduCount> {
|
||||
match self
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.await
|
||||
.entry(room_id.into())
|
||||
{
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
hash_map::Entry::Vacant(v) => Ok(self
|
||||
| hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
| hash_map::Entry::Vacant(v) => Ok(self
|
||||
.pdus_rev(sender_user, room_id, PduCount::max())
|
||||
.await?
|
||||
.next()
|
||||
|
@ -93,7 +99,10 @@ impl Data {
|
|||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
|
||||
pub(super) async fn get_non_outlier_pdu_json(
|
||||
&self,
|
||||
event_id: &EventId,
|
||||
) -> Result<CanonicalJsonObject> {
|
||||
let pduid = self.get_pdu_id(event_id).await?;
|
||||
|
||||
self.pduid_pdu.get(&pduid).await.deserialized()
|
||||
|
@ -160,12 +169,19 @@ impl Data {
|
|||
}
|
||||
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
|
||||
pub(super) async fn get_pdu_json_from_id(
|
||||
&self,
|
||||
pdu_id: &RawPduId,
|
||||
) -> Result<CanonicalJsonObject> {
|
||||
self.pduid_pdu.get(pdu_id).await.deserialized()
|
||||
}
|
||||
|
||||
pub(super) async fn append_pdu(
|
||||
&self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount,
|
||||
&self,
|
||||
pdu_id: &RawPduId,
|
||||
pdu: &PduEvent,
|
||||
json: &CanonicalJsonObject,
|
||||
count: PduCount,
|
||||
) {
|
||||
debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal");
|
||||
|
||||
|
@ -179,7 +195,12 @@ impl Data {
|
|||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
|
||||
}
|
||||
|
||||
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) {
|
||||
pub(super) fn prepend_backfill_pdu(
|
||||
&self,
|
||||
pdu_id: &RawPduId,
|
||||
event_id: &EventId,
|
||||
json: &CanonicalJsonObject,
|
||||
) {
|
||||
self.pduid_pdu.raw_put(pdu_id, Json(json));
|
||||
self.eventid_pduid.insert(event_id, pdu_id);
|
||||
self.eventid_outlierpdu.remove(event_id);
|
||||
|
@ -187,7 +208,10 @@ impl Data {
|
|||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
pub(super) async fn replace_pdu(
|
||||
&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent,
|
||||
&self,
|
||||
pdu_id: &RawPduId,
|
||||
pdu_json: &CanonicalJsonObject,
|
||||
_pdu: &PduEvent,
|
||||
) -> Result {
|
||||
if self.pduid_pdu.get(pdu_id).await.is_not_found() {
|
||||
return Err!(Request(NotFound("PDU does not exist.")));
|
||||
|
@ -202,7 +226,10 @@ impl Data {
|
|||
/// happened before the event with id `until` in reverse-chronological
|
||||
/// order.
|
||||
pub(super) async fn pdus_rev<'a>(
|
||||
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount,
|
||||
&'a self,
|
||||
user_id: Option<&'a UserId>,
|
||||
room_id: &'a RoomId,
|
||||
until: PduCount,
|
||||
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
|
||||
let current = self
|
||||
.count_to_id(room_id, until, Direction::Backward)
|
||||
|
@ -219,7 +246,10 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) async fn pdus<'a>(
|
||||
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount,
|
||||
&'a self,
|
||||
user_id: Option<&'a UserId>,
|
||||
room_id: &'a RoomId,
|
||||
from: PduCount,
|
||||
) -> Result<impl Stream<Item = PdusIterItem> + Send + Unpin + 'a> {
|
||||
let current = self.count_to_id(room_id, from, Direction::Forward).await?;
|
||||
let prefix = current.shortroomid();
|
||||
|
@ -236,8 +266,8 @@ impl Data {
|
|||
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> PdusIterItem {
|
||||
let pdu_id: RawPduId = pdu_id.into();
|
||||
|
||||
let mut pdu =
|
||||
serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON");
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(pdu)
|
||||
.expect("PduEvent in pduid_pdu database column is invalid JSON");
|
||||
|
||||
if Some(pdu.sender.borrow()) != user_id {
|
||||
pdu.remove_transaction_id().log_err().ok();
|
||||
|
@ -249,7 +279,10 @@ impl Data {
|
|||
}
|
||||
|
||||
pub(super) fn increment_notification_counts(
|
||||
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
notifies: Vec<OwnedUserId>,
|
||||
highlights: Vec<OwnedUserId>,
|
||||
) {
|
||||
let _cork = self.db.cork();
|
||||
|
||||
|
@ -268,7 +301,12 @@ impl Data {
|
|||
}
|
||||
}
|
||||
|
||||
async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount, dir: Direction) -> Result<RawPduId> {
|
||||
async fn count_to_id(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
shorteventid: PduCount,
|
||||
dir: Direction,
|
||||
) -> Result<RawPduId> {
|
||||
let shortroomid: ShortRoomId = self
|
||||
.services
|
||||
.short
|
||||
|
|
|
@ -15,7 +15,9 @@ use conduwuit::{
|
|||
validated, warn, Err, Error, Result, Server,
|
||||
};
|
||||
pub use conduwuit::{PduId, RawPduId};
|
||||
use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use futures::{
|
||||
future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
|
||||
};
|
||||
use ruma::{
|
||||
api::federation,
|
||||
canonical_json::to_canonical_value,
|
||||
|
@ -32,8 +34,8 @@ use ruma::{
|
|||
},
|
||||
push::{Action, Ruleset, Tweak},
|
||||
state_res::{self, Event, RoomVersion},
|
||||
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName,
|
||||
RoomId, RoomVersionId, ServerName, UserId,
|
||||
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId,
|
||||
OwnedServerName, RoomId, RoomVersionId, ServerName, UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
|
||||
|
@ -116,7 +118,8 @@ impl crate::Service for Service {
|
|||
short: args.depend::<rooms::short::Service>("rooms::short"),
|
||||
state: args.depend::<rooms::state::Service>("rooms::state"),
|
||||
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
|
||||
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
|
||||
sending: args.depend::<sending::Service>("sending"),
|
||||
|
@ -127,7 +130,8 @@ impl crate::Service for Service {
|
|||
threads: args.depend::<rooms::threads::Service>("rooms::threads"),
|
||||
search: args.depend::<rooms::search::Service>("rooms::search"),
|
||||
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
|
||||
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
|
||||
event_handler: args
|
||||
.depend::<rooms::event_handler::Service>("rooms::event_handler"),
|
||||
},
|
||||
db: Data::new(&args),
|
||||
mutex_insert: RoomMutexMap::new(),
|
||||
|
@ -185,12 +189,18 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> {
|
||||
pub async fn last_timeline_count(
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
room_id: &RoomId,
|
||||
) -> Result<PduCount> {
|
||||
self.db.last_timeline_count(sender_user, room_id).await
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> { self.db.get_pdu_count(event_id).await }
|
||||
pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
|
||||
self.db.get_pdu_count(event_id).await
|
||||
}
|
||||
|
||||
// TODO Is this the same as the function above?
|
||||
/*
|
||||
|
@ -222,13 +232,18 @@ impl Service {
|
|||
|
||||
/// Returns the json of a pdu.
|
||||
#[inline]
|
||||
pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
|
||||
pub async fn get_non_outlier_pdu_json(
|
||||
&self,
|
||||
event_id: &EventId,
|
||||
) -> Result<CanonicalJsonObject> {
|
||||
self.db.get_non_outlier_pdu_json(event_id).await
|
||||
}
|
||||
|
||||
/// Returns the pdu's id.
|
||||
#[inline]
|
||||
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> { self.db.get_pdu_id(event_id).await }
|
||||
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> {
|
||||
self.db.get_pdu_id(event_id).await
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
|
@ -241,19 +256,26 @@ impl Service {
|
|||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> { self.db.get_pdu(event_id).await }
|
||||
pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
|
||||
self.db.get_pdu(event_id).await
|
||||
}
|
||||
|
||||
/// Checks if pdu exists
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future<Output = bool> + Send + 'a {
|
||||
pub fn pdu_exists<'a>(
|
||||
&'a self,
|
||||
event_id: &'a EventId,
|
||||
) -> impl Future<Output = bool> + Send + 'a {
|
||||
self.db.pdu_exists(event_id)
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> { self.db.get_pdu_from_id(pdu_id).await }
|
||||
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> {
|
||||
self.db.get_pdu_from_id(pdu_id).await
|
||||
}
|
||||
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
|
||||
|
@ -262,7 +284,12 @@ impl Service {
|
|||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
|
||||
pub async fn replace_pdu(
|
||||
&self,
|
||||
pdu_id: &RawPduId,
|
||||
pdu_json: &CanonicalJsonObject,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
self.db.replace_pdu(pdu_id, pdu_json, pdu).await
|
||||
}
|
||||
|
||||
|
@ -278,7 +305,8 @@ impl Service {
|
|||
pdu: &PduEvent,
|
||||
mut pdu_json: CanonicalJsonObject,
|
||||
leaves: Vec<OwnedEventId>,
|
||||
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
|
||||
* mutex */
|
||||
) -> Result<RawPduId> {
|
||||
// Coalesce database writes for the remainder of this scope.
|
||||
let _cork = self.db.db.cork_and_flush();
|
||||
|
@ -313,10 +341,16 @@ impl Service {
|
|||
unsigned.insert(
|
||||
"prev_content".to_owned(),
|
||||
CanonicalJsonValue::Object(
|
||||
utils::to_canonical_object(prev_state.content.clone()).map_err(|e| {
|
||||
error!("Failed to convert prev_state to canonical JSON: {e}");
|
||||
Error::bad_database("Failed to convert prev_state to canonical JSON.")
|
||||
})?,
|
||||
utils::to_canonical_object(prev_state.content.clone()).map_err(
|
||||
|e| {
|
||||
error!(
|
||||
"Failed to convert prev_state to canonical JSON: {e}"
|
||||
);
|
||||
Error::bad_database(
|
||||
"Failed to convert prev_state to canonical JSON.",
|
||||
)
|
||||
},
|
||||
)?,
|
||||
),
|
||||
);
|
||||
unsigned.insert(
|
||||
|
@ -357,11 +391,7 @@ impl Service {
|
|||
.reset_notification_counts(&pdu.sender, &pdu.room_id);
|
||||
|
||||
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap());
|
||||
let pdu_id: RawPduId = PduId {
|
||||
shortroomid,
|
||||
shorteventid: count2,
|
||||
}
|
||||
.into();
|
||||
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
|
||||
|
||||
// Insert pdu
|
||||
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await;
|
||||
|
@ -408,7 +438,10 @@ impl Service {
|
|||
.account_data
|
||||
.get_global(user, GlobalAccountDataEventType::PushRules)
|
||||
.await
|
||||
.map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global);
|
||||
.map_or_else(
|
||||
|_| Ruleset::server_default(user),
|
||||
|ev: PushRulesEvent| ev.content.global,
|
||||
);
|
||||
|
||||
let mut highlight = false;
|
||||
let mut notify = false;
|
||||
|
@ -420,11 +453,11 @@ impl Service {
|
|||
.await
|
||||
{
|
||||
match action {
|
||||
Action::Notify => notify = true,
|
||||
Action::SetTweak(Tweak::Highlight(true)) => {
|
||||
| Action::Notify => notify = true,
|
||||
| Action::SetTweak(Tweak::Highlight(true)) => {
|
||||
highlight = true;
|
||||
},
|
||||
_ => {},
|
||||
| _ => {},
|
||||
};
|
||||
|
||||
// Break early if both conditions are true
|
||||
|
@ -457,12 +490,12 @@ impl Service {
|
|||
.increment_notification_counts(&pdu.room_id, notifies, highlights);
|
||||
|
||||
match pdu.kind {
|
||||
TimelineEventType::RoomRedaction => {
|
||||
| TimelineEventType::RoomRedaction => {
|
||||
use RoomVersionId::*;
|
||||
|
||||
let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?;
|
||||
match room_version_id {
|
||||
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
|
||||
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
|
||||
if let Some(redact_id) = &pdu.redacts {
|
||||
if self
|
||||
.services
|
||||
|
@ -474,7 +507,7 @@ impl Service {
|
|||
}
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
| _ => {
|
||||
let content: RoomRedactionEventContent = pdu.get_content()?;
|
||||
if let Some(redact_id) = &content.redacts {
|
||||
if self
|
||||
|
@ -489,7 +522,7 @@ impl Service {
|
|||
},
|
||||
};
|
||||
},
|
||||
TimelineEventType::SpaceChild => {
|
||||
| TimelineEventType::SpaceChild =>
|
||||
if let Some(_state_key) = &pdu.state_key {
|
||||
self.services
|
||||
.spaces
|
||||
|
@ -497,18 +530,18 @@ impl Service {
|
|||
.lock()
|
||||
.await
|
||||
.remove(&pdu.room_id);
|
||||
}
|
||||
},
|
||||
TimelineEventType::RoomMember => {
|
||||
},
|
||||
| TimelineEventType::RoomMember => {
|
||||
if let Some(state_key) = &pdu.state_key {
|
||||
// if the state_key fails
|
||||
let target_user_id =
|
||||
UserId::parse(state_key.clone()).expect("This state_key was previously validated");
|
||||
let target_user_id = UserId::parse(state_key.clone())
|
||||
.expect("This state_key was previously validated");
|
||||
|
||||
let content: RoomMemberEventContent = pdu.get_content()?;
|
||||
let invite_state = match content.membership {
|
||||
MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(),
|
||||
_ => None,
|
||||
| MembershipState::Invite =>
|
||||
self.services.state.summary_stripped(pdu).await.into(),
|
||||
| _ => None,
|
||||
};
|
||||
|
||||
// Update our membership info, we do this here incase a user is invited
|
||||
|
@ -527,7 +560,7 @@ impl Service {
|
|||
.await?;
|
||||
}
|
||||
},
|
||||
TimelineEventType::RoomMessage => {
|
||||
| TimelineEventType::RoomMessage => {
|
||||
let content: ExtractBody = pdu.get_content()?;
|
||||
if let Some(body) = content.body {
|
||||
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
|
||||
|
@ -539,7 +572,7 @@ impl Service {
|
|||
}
|
||||
}
|
||||
},
|
||||
_ => {},
|
||||
| _ => {},
|
||||
}
|
||||
|
||||
if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() {
|
||||
|
@ -552,24 +585,23 @@ impl Service {
|
|||
|
||||
if let Ok(content) = pdu.get_content::<ExtractRelatesTo>() {
|
||||
match content.relates_to {
|
||||
Relation::Reply {
|
||||
in_reply_to,
|
||||
} => {
|
||||
| Relation::Reply { in_reply_to } => {
|
||||
// We need to do it again here, because replies don't have
|
||||
// event_id as a top level field
|
||||
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await {
|
||||
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await
|
||||
{
|
||||
self.services
|
||||
.pdu_metadata
|
||||
.add_relation(count2, related_pducount);
|
||||
}
|
||||
},
|
||||
Relation::Thread(thread) => {
|
||||
| Relation::Thread(thread) => {
|
||||
self.services
|
||||
.threads
|
||||
.add_to_thread(&thread.event_id, pdu)
|
||||
.await?;
|
||||
},
|
||||
_ => {}, // TODO: Aggregate other types
|
||||
| _ => {}, // TODO: Aggregate other types
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -637,7 +669,8 @@ impl Service {
|
|||
pdu_builder: PduBuilder,
|
||||
sender: &UserId,
|
||||
room_id: &RoomId,
|
||||
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
|
||||
* state mutex */
|
||||
) -> Result<(PduEvent, CanonicalJsonObject)> {
|
||||
let PduBuilder {
|
||||
event_type,
|
||||
|
@ -707,7 +740,8 @@ impl Service {
|
|||
unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value());
|
||||
unsigned.insert(
|
||||
"prev_sender".to_owned(),
|
||||
serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"),
|
||||
serde_json::to_value(&prev_pdu.sender)
|
||||
.expect("UserId::to_value always works"),
|
||||
);
|
||||
unsigned.insert(
|
||||
"replaces_state".to_owned(),
|
||||
|
@ -744,9 +778,7 @@ impl Service {
|
|||
} else {
|
||||
Some(to_raw_value(&unsigned).expect("to_raw_value always works"))
|
||||
},
|
||||
hashes: EventHash {
|
||||
sha256: "aaa".to_owned(),
|
||||
},
|
||||
hashes: EventHash { sha256: "aaa".to_owned() },
|
||||
signatures: None,
|
||||
};
|
||||
|
||||
|
@ -769,13 +801,14 @@ impl Service {
|
|||
}
|
||||
|
||||
// Hash and sign
|
||||
let mut pdu_json = utils::to_canonical_object(&pdu)
|
||||
.map_err(|e| err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}")))))?;
|
||||
let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| {
|
||||
err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}"))))
|
||||
})?;
|
||||
|
||||
// room v3 and above removed the "event_id" field from remote PDU format
|
||||
match room_version_id {
|
||||
RoomVersionId::V1 | RoomVersionId::V2 => {},
|
||||
_ => {
|
||||
| RoomVersionId::V1 | RoomVersionId::V2 => {},
|
||||
| _ => {
|
||||
pdu_json.remove("event_id");
|
||||
},
|
||||
};
|
||||
|
@ -783,7 +816,8 @@ impl Service {
|
|||
// Add origin because synapse likes that (and it's required in the spec)
|
||||
pdu_json.insert(
|
||||
"origin".to_owned(),
|
||||
to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
|
||||
to_canonical_value(self.services.globals.server_name())
|
||||
.expect("server name is a valid CanonicalJsonValue"),
|
||||
);
|
||||
|
||||
if let Err(e) = self
|
||||
|
@ -792,17 +826,18 @@ impl Service {
|
|||
.hash_and_sign_event(&mut pdu_json, &room_version_id)
|
||||
{
|
||||
return match e {
|
||||
Error::Signatures(ruma::signatures::Error::PduSize) => {
|
||||
| Error::Signatures(ruma::signatures::Error::PduSize) => {
|
||||
Err!(Request(TooLarge("Message/PDU is too long (exceeds 65535 bytes)")))
|
||||
},
|
||||
_ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
|
||||
| _ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
|
||||
};
|
||||
}
|
||||
|
||||
// Generate event id
|
||||
pdu.event_id = EventId::parse_arc(format!(
|
||||
"${}",
|
||||
ruma::signatures::reference_hash(&pdu_json, &room_version_id).expect("ruma can calculate reference hashes")
|
||||
ruma::signatures::reference_hash(&pdu_json, &room_version_id)
|
||||
.expect("ruma can calculate reference hashes")
|
||||
))
|
||||
.expect("ruma's reference hashes are valid event ids");
|
||||
|
||||
|
@ -830,7 +865,8 @@ impl Service {
|
|||
pdu_builder: PduBuilder,
|
||||
sender: &UserId,
|
||||
room_id: &RoomId,
|
||||
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
|
||||
* mutex */
|
||||
) -> Result<Arc<EventId>> {
|
||||
let (pdu, pdu_json) = self
|
||||
.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
|
||||
|
@ -844,7 +880,7 @@ impl Service {
|
|||
if pdu.kind == TimelineEventType::RoomRedaction {
|
||||
use RoomVersionId::*;
|
||||
match self.services.state.get_room_version(&pdu.room_id).await? {
|
||||
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
|
||||
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
|
||||
if let Some(redact_id) = &pdu.redacts {
|
||||
if !self
|
||||
.services
|
||||
|
@ -856,7 +892,7 @@ impl Service {
|
|||
}
|
||||
};
|
||||
},
|
||||
_ => {
|
||||
| _ => {
|
||||
let content: RoomRedactionEventContent = pdu.get_content()?;
|
||||
if let Some(redact_id) = &content.redacts {
|
||||
if !self
|
||||
|
@ -937,7 +973,8 @@ impl Service {
|
|||
new_room_leaves: Vec<OwnedEventId>,
|
||||
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
|
||||
soft_fail: bool,
|
||||
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
|
||||
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
|
||||
* mutex */
|
||||
) -> Result<Option<RawPduId>> {
|
||||
// We append to state before appending the pdu, so we don't have a moment in
|
||||
// time with the pdu without it's state. This is okay because append_pdu can't
|
||||
|
@ -971,7 +1008,9 @@ impl Service {
|
|||
/// items.
|
||||
#[inline]
|
||||
pub fn all_pdus<'a>(
|
||||
&'a self, user_id: &'a UserId, room_id: &'a RoomId,
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
room_id: &'a RoomId,
|
||||
) -> impl Stream<Item = PdusIterItem> + Send + Unpin + 'a {
|
||||
self.pdus(Some(user_id), room_id, None)
|
||||
.map_ok(|stream| stream.map(Ok))
|
||||
|
@ -983,7 +1022,10 @@ impl Service {
|
|||
/// Reverse iteration starting at from.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn pdus_rev<'a>(
|
||||
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: Option<PduCount>,
|
||||
&'a self,
|
||||
user_id: Option<&'a UserId>,
|
||||
room_id: &'a RoomId,
|
||||
until: Option<PduCount>,
|
||||
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
|
||||
self.db
|
||||
.pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max))
|
||||
|
@ -993,7 +1035,10 @@ impl Service {
|
|||
/// Forward iteration starting at from.
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn pdus<'a>(
|
||||
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: Option<PduCount>,
|
||||
&'a self,
|
||||
user_id: Option<&'a UserId>,
|
||||
room_id: &'a RoomId,
|
||||
from: Option<PduCount>,
|
||||
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
|
||||
self.db
|
||||
.pdus(user_id, room_id, from.unwrap_or_else(PduCount::min))
|
||||
|
@ -1002,17 +1047,21 @@ impl Service {
|
|||
|
||||
/// Replace a PDU with the redacted form.
|
||||
#[tracing::instrument(skip(self, reason))]
|
||||
pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result {
|
||||
pub async fn redact_pdu(
|
||||
&self,
|
||||
event_id: &EventId,
|
||||
reason: &PduEvent,
|
||||
shortroomid: ShortRoomId,
|
||||
) -> Result {
|
||||
// TODO: Don't reserialize, keep original json
|
||||
let Ok(pdu_id) = self.get_pdu_id(event_id).await else {
|
||||
// If event does not exist, just noop
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut pdu = self
|
||||
.get_pdu_from_id(&pdu_id)
|
||||
.await
|
||||
.map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?;
|
||||
let mut pdu = self.get_pdu_from_id(&pdu_id).await.map_err(|e| {
|
||||
err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU.")))
|
||||
})?;
|
||||
|
||||
if let Ok(content) = pdu.get_content::<ExtractBody>() {
|
||||
if let Some(body) = content.body {
|
||||
|
@ -1026,8 +1075,9 @@ impl Service {
|
|||
|
||||
pdu.redact(&room_version_id, reason)?;
|
||||
|
||||
let obj = utils::to_canonical_object(&pdu)
|
||||
.map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?;
|
||||
let obj = utils::to_canonical_object(&pdu).map_err(|e| {
|
||||
err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON")))
|
||||
})?;
|
||||
|
||||
self.replace_pdu(&pdu_id, &obj, &pdu).await
|
||||
}
|
||||
|
@ -1069,7 +1119,9 @@ impl Service {
|
|||
.unwrap_or_default();
|
||||
|
||||
let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| {
|
||||
if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) {
|
||||
if level > &power_levels.users_default
|
||||
&& !self.services.globals.user_is_local(user_id)
|
||||
{
|
||||
Some(user_id.server_name())
|
||||
} else {
|
||||
None
|
||||
|
@ -1124,7 +1176,7 @@ impl Service {
|
|||
)
|
||||
.await;
|
||||
match response {
|
||||
Ok(response) => {
|
||||
| Ok(response) => {
|
||||
for pdu in response.pdus {
|
||||
if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await {
|
||||
debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}");
|
||||
|
@ -1132,7 +1184,7 @@ impl Service {
|
|||
}
|
||||
return Ok(());
|
||||
},
|
||||
Err(e) => {
|
||||
| Err(e) => {
|
||||
warn!("{backfill_server} failed to provide backfill for room {room_id}: {e}");
|
||||
},
|
||||
}
|
||||
|
@ -1144,7 +1196,8 @@ impl Service {
|
|||
|
||||
#[tracing::instrument(skip(self, pdu))]
|
||||
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
|
||||
let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?;
|
||||
let (event_id, value, room_id) =
|
||||
self.services.event_handler.parse_incoming_pdu(&pdu).await?;
|
||||
|
||||
// Lock so we cannot backfill the same pdu twice at the same time
|
||||
let mutex_lock = self
|
||||
|
@ -1210,10 +1263,10 @@ impl Service {
|
|||
#[tracing::instrument(skip_all, level = "debug")]
|
||||
async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Result<()> {
|
||||
match pdu.event_type() {
|
||||
TimelineEventType::RoomEncryption => {
|
||||
| TimelineEventType::RoomEncryption => {
|
||||
return Err!(Request(Forbidden(error!("Encryption not supported in admins room."))));
|
||||
},
|
||||
TimelineEventType::RoomMember => {
|
||||
| TimelineEventType::RoomMember => {
|
||||
let target = pdu
|
||||
.state_key()
|
||||
.filter(|v| v.starts_with('@'))
|
||||
|
@ -1223,9 +1276,11 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
|
|||
|
||||
let content: RoomMemberEventContent = pdu.get_content()?;
|
||||
match content.membership {
|
||||
MembershipState::Leave => {
|
||||
| MembershipState::Leave => {
|
||||
if target == server_user {
|
||||
return Err!(Request(Forbidden(error!("Server user cannot leave the admins room."))));
|
||||
return Err!(Request(Forbidden(error!(
|
||||
"Server user cannot leave the admins room."
|
||||
))));
|
||||
}
|
||||
|
||||
let count = self
|
||||
|
@ -1239,13 +1294,17 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
|
|||
.await;
|
||||
|
||||
if count < 2 {
|
||||
return Err!(Request(Forbidden(error!("Last admin cannot leave the admins room."))));
|
||||
return Err!(Request(Forbidden(error!(
|
||||
"Last admin cannot leave the admins room."
|
||||
))));
|
||||
}
|
||||
},
|
||||
|
||||
MembershipState::Ban if pdu.state_key().is_some() => {
|
||||
| MembershipState::Ban if pdu.state_key().is_some() => {
|
||||
if target == server_user {
|
||||
return Err!(Request(Forbidden(error!("Server cannot be banned from admins room."))));
|
||||
return Err!(Request(Forbidden(error!(
|
||||
"Server cannot be banned from admins room."
|
||||
))));
|
||||
}
|
||||
|
||||
let count = self
|
||||
|
@ -1259,13 +1318,15 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
|
|||
.await;
|
||||
|
||||
if count < 2 {
|
||||
return Err!(Request(Forbidden(error!("Last admin cannot be banned from admins room."))));
|
||||
return Err!(Request(Forbidden(error!(
|
||||
"Last admin cannot be banned from admins room."
|
||||
))));
|
||||
}
|
||||
},
|
||||
_ => {},
|
||||
| _ => {},
|
||||
};
|
||||
},
|
||||
_ => {},
|
||||
| _ => {},
|
||||
};
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -52,7 +52,12 @@ impl crate::Service for Service {
|
|||
impl Service {
|
||||
/// Sets a user as typing until the timeout timestamp is reached or
|
||||
/// roomtyping_remove is called.
|
||||
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
pub async fn typing_add(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
timeout: u64,
|
||||
) -> Result<()> {
|
||||
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
|
||||
// update clients
|
||||
self.typing
|
||||
|
@ -177,15 +182,15 @@ impl Service {
|
|||
|
||||
/// Returns a new typing EDU.
|
||||
pub async fn typings_all(
|
||||
&self, room_id: &RoomId, sender_user: &UserId,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
sender_user: &UserId,
|
||||
) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> {
|
||||
let room_typing_indicators = self.typing.read().await.get(room_id).cloned();
|
||||
|
||||
let Some(typing_indicators) = room_typing_indicators else {
|
||||
return Ok(SyncEphemeralRoomEvent {
|
||||
content: ruma::events::typing::TypingEventContent {
|
||||
user_ids: Vec::new(),
|
||||
},
|
||||
content: ruma::events::typing::TypingEventContent { user_ids: Vec::new() },
|
||||
});
|
||||
};
|
||||
|
||||
|
@ -204,13 +209,16 @@ impl Service {
|
|||
.await;
|
||||
|
||||
Ok(SyncEphemeralRoomEvent {
|
||||
content: ruma::events::typing::TypingEventContent {
|
||||
user_ids,
|
||||
},
|
||||
content: ruma::events::typing::TypingEventContent { user_ids },
|
||||
})
|
||||
}
|
||||
|
||||
async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
|
||||
async fn federation_send(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
typing: bool,
|
||||
) -> Result<()> {
|
||||
debug_assert!(
|
||||
self.services.globals.user_is_local(user_id),
|
||||
"tried to broadcast typing status of remote user",
|
||||
|
|
|
@ -92,7 +92,12 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) {
|
||||
pub async fn associate_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
shortstatehash: ShortStateHash,
|
||||
) {
|
||||
let shortroomid = self
|
||||
.services
|
||||
.short
|
||||
|
@ -108,7 +113,11 @@ pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64,
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<ShortStateHash> {
|
||||
pub async fn get_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
) -> Result<ShortStateHash> {
|
||||
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
|
||||
|
||||
let key: &[u64] = &[shortroomid, token];
|
||||
|
|
|
@ -3,14 +3,18 @@ use std::{fmt::Debug, mem};
|
|||
use bytes::BytesMut;
|
||||
use conduwuit::{debug_error, err, trace, utils, warn, Err, Result};
|
||||
use reqwest::Client;
|
||||
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
|
||||
use ruma::api::{
|
||||
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
};
|
||||
|
||||
/// Sends a request to an appservice
|
||||
///
|
||||
/// Only returns Ok(None) if there is no url specified in the appservice
|
||||
/// registration file
|
||||
pub(crate) async fn send_request<T>(
|
||||
client: &Client, registration: Registration, request: T,
|
||||
client: &Client,
|
||||
registration: Registration,
|
||||
request: T,
|
||||
) -> Result<Option<T::IncomingResponse>>
|
||||
where
|
||||
T: OutgoingRequest + Debug + Send,
|
||||
|
@ -25,17 +29,17 @@ where
|
|||
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(&dest, SendAccessToken::IfRequired(hs_token), &VERSIONS)
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&dest,
|
||||
SendAccessToken::IfRequired(hs_token),
|
||||
&VERSIONS,
|
||||
)
|
||||
.map_err(|e| err!(BadServerResponse(warn!("Failed to find destination {dest}: {e}"))))?
|
||||
.map(BytesMut::freeze);
|
||||
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
"?"
|
||||
};
|
||||
let symbol = if old_path_and_query.contains('?') { "&" } else { "?" };
|
||||
|
||||
parts.path_and_query = Some(
|
||||
(old_path_and_query + symbol + "access_token=" + hs_token)
|
||||
|
|
|
@ -43,7 +43,9 @@ impl Data {
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); }
|
||||
pub(super) fn delete_active_request(&self, key: &[u8]) {
|
||||
self.servercurrentevent_data.remove(key);
|
||||
}
|
||||
|
||||
pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) {
|
||||
let prefix = destination.get_prefix();
|
||||
|
@ -76,11 +78,7 @@ impl Data {
|
|||
events
|
||||
.filter(|(key, _)| !key.is_empty())
|
||||
.for_each(|(key, val)| {
|
||||
let val = if let SendingEvent::Edu(val) = &val {
|
||||
&**val
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
let val = if let SendingEvent::Edu(val) = &val { &**val } else { &[] };
|
||||
|
||||
self.servercurrentevent_data.insert(key, val);
|
||||
self.servernameevent_data.remove(key);
|
||||
|
@ -93,21 +91,26 @@ impl Data {
|
|||
.raw_stream()
|
||||
.ignore_err()
|
||||
.map(|(key, val)| {
|
||||
let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
|
||||
let (dest, event) =
|
||||
parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
|
||||
|
||||
(key.to_vec(), event, dest)
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn active_requests_for(&self, destination: &Destination) -> impl Stream<Item = SendingItem> + Send + '_ {
|
||||
pub fn active_requests_for(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
) -> impl Stream<Item = SendingItem> + Send + '_ {
|
||||
let prefix = destination.get_prefix();
|
||||
self.servercurrentevent_data
|
||||
.raw_stream_from(&prefix)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
|
||||
.map(|(key, val)| {
|
||||
let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
|
||||
let (_, event) =
|
||||
parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
|
||||
|
||||
(key.to_vec(), event)
|
||||
})
|
||||
|
@ -150,14 +153,18 @@ impl Data {
|
|||
keys
|
||||
}
|
||||
|
||||
pub fn queued_requests(&self, destination: &Destination) -> impl Stream<Item = QueueItem> + Send + '_ {
|
||||
pub fn queued_requests(
|
||||
&self,
|
||||
destination: &Destination,
|
||||
) -> impl Stream<Item = QueueItem> + Send + '_ {
|
||||
let prefix = destination.get_prefix();
|
||||
self.servernameevent_data
|
||||
.raw_stream_from(&prefix)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
|
||||
.map(|(key, val)| {
|
||||
let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
|
||||
let (_, event) =
|
||||
parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
|
||||
|
||||
(key.to_vec(), event)
|
||||
})
|
||||
|
@ -186,8 +193,9 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
|
|||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
|
||||
(
|
||||
Destination::Appservice(server),
|
||||
|
@ -203,8 +211,8 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
|
|||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id =
|
||||
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
let user_id = UserId::parse(user_string)
|
||||
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
|
||||
let pushkey = parts
|
||||
.next()
|
||||
|
@ -233,14 +241,14 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
|
|||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
|
||||
(
|
||||
Destination::Normal(
|
||||
ServerName::parse(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
|
||||
),
|
||||
Destination::Normal(ServerName::parse(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
||||
})?),
|
||||
if value.is_empty() {
|
||||
SendingEvent::Pdu(event.into())
|
||||
} else {
|
||||
|
|
|
@ -14,7 +14,7 @@ pub enum Destination {
|
|||
#[must_use]
|
||||
pub(super) fn get_prefix(&self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Normal(server) => {
|
||||
| Self::Normal(server) => {
|
||||
let len = server.as_bytes().len().saturating_add(1);
|
||||
|
||||
let mut p = Vec::with_capacity(len);
|
||||
|
@ -22,7 +22,7 @@ pub(super) fn get_prefix(&self) -> Vec<u8> {
|
|||
p.push(0xFF);
|
||||
p
|
||||
},
|
||||
Self::Appservice(server) => {
|
||||
| Self::Appservice(server) => {
|
||||
let sigil = b"+";
|
||||
let len = sigil.len().saturating_add(server.len()).saturating_add(1);
|
||||
|
||||
|
@ -32,7 +32,7 @@ pub(super) fn get_prefix(&self) -> Vec<u8> {
|
|||
p.push(0xFF);
|
||||
p
|
||||
},
|
||||
Self::Push(user, pushkey) => {
|
||||
| Self::Push(user, pushkey) => {
|
||||
let sigil = b"$";
|
||||
let len = sigil
|
||||
.len()
|
||||
|
|
|
@ -25,8 +25,8 @@ pub use self::{
|
|||
sender::{EDU_LIMIT, PDU_LIMIT},
|
||||
};
|
||||
use crate::{
|
||||
account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId, server_keys, users,
|
||||
Dep,
|
||||
account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId,
|
||||
server_keys, users, Dep,
|
||||
};
|
||||
|
||||
pub struct Service {
|
||||
|
@ -156,18 +156,16 @@ impl Service {
|
|||
{
|
||||
let _cork = self.db.db.cork();
|
||||
let requests = servers
|
||||
.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned())))
|
||||
.map(|server| {
|
||||
(Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
|
||||
|
||||
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
|
||||
self.dispatch(Msg {
|
||||
dest,
|
||||
event,
|
||||
queue_id,
|
||||
})?;
|
||||
self.dispatch(Msg { dest, event, queue_id })?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -204,18 +202,16 @@ impl Service {
|
|||
{
|
||||
let _cork = self.db.db.cork();
|
||||
let requests = servers
|
||||
.map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone())))
|
||||
.map(|server| {
|
||||
(Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
|
||||
|
||||
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
|
||||
self.dispatch(Msg {
|
||||
dest,
|
||||
event,
|
||||
queue_id,
|
||||
})?;
|
||||
self.dispatch(Msg { dest, event, queue_id })?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -253,7 +249,11 @@ impl Service {
|
|||
|
||||
/// Sends a request to a federation server
|
||||
#[tracing::instrument(skip_all, name = "request")]
|
||||
pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
|
||||
pub async fn send_federation_request<T>(
|
||||
&self,
|
||||
dest: &ServerName,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: OutgoingRequest + Debug + Send,
|
||||
{
|
||||
|
@ -263,7 +263,11 @@ impl Service {
|
|||
|
||||
/// Like send_federation_request() but with a very large timeout
|
||||
#[tracing::instrument(skip_all, name = "synapse")]
|
||||
pub async fn send_synapse_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
|
||||
pub async fn send_synapse_request<T>(
|
||||
&self,
|
||||
dest: &ServerName,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: OutgoingRequest + Debug + Send,
|
||||
{
|
||||
|
@ -276,7 +280,9 @@ impl Service {
|
|||
/// Only returns None if there is no url specified in the appservice
|
||||
/// registration file
|
||||
pub async fn send_appservice_request<T>(
|
||||
&self, registration: Registration, request: T,
|
||||
&self,
|
||||
registration: Registration,
|
||||
request: T,
|
||||
) -> Result<Option<T::IncomingResponse>>
|
||||
where
|
||||
T: OutgoingRequest + Debug + Send,
|
||||
|
@ -291,24 +297,30 @@ impl Service {
|
|||
/// key
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub async fn cleanup_events(
|
||||
&self, appservice_id: Option<&str>, user_id: Option<&UserId>, push_key: Option<&str>,
|
||||
&self,
|
||||
appservice_id: Option<&str>,
|
||||
user_id: Option<&UserId>,
|
||||
push_key: Option<&str>,
|
||||
) -> Result {
|
||||
match (appservice_id, user_id, push_key) {
|
||||
(None, Some(user_id), Some(push_key)) => {
|
||||
| (None, Some(user_id), Some(push_key)) => {
|
||||
self.db
|
||||
.delete_all_requests_for(&Destination::Push(user_id.to_owned(), push_key.to_owned()))
|
||||
.delete_all_requests_for(&Destination::Push(
|
||||
user_id.to_owned(),
|
||||
push_key.to_owned(),
|
||||
))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
(Some(appservice_id), None, None) => {
|
||||
| (Some(appservice_id), None, None) => {
|
||||
self.db
|
||||
.delete_all_requests_for(&Destination::Appservice(appservice_id.to_owned()))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
_ => {
|
||||
| _ => {
|
||||
debug_warn!("cleanup_events called with too many or too few arguments");
|
||||
Ok(())
|
||||
},
|
||||
|
|
|
@ -2,16 +2,16 @@ use std::mem;
|
|||
|
||||
use bytes::Bytes;
|
||||
use conduwuit::{
|
||||
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error,
|
||||
Result,
|
||||
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace,
|
||||
utils::string::EMPTY, Err, Error, Result,
|
||||
};
|
||||
use http::{header::AUTHORIZATION, HeaderValue};
|
||||
use ipaddress::IPAddress;
|
||||
use reqwest::{Client, Method, Request, Response, Url};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
|
||||
SendAccessToken,
|
||||
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion,
|
||||
OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
serde::Base64,
|
||||
server_util::authorization::XMatrix,
|
||||
|
@ -25,7 +25,12 @@ use crate::{
|
|||
|
||||
impl super::Service {
|
||||
#[tracing::instrument(skip_all, level = "debug")]
|
||||
pub async fn send<T>(&self, client: &Client, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
|
||||
pub async fn send<T>(
|
||||
&self,
|
||||
client: &Client,
|
||||
dest: &ServerName,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: OutgoingRequest + Send,
|
||||
{
|
||||
|
@ -39,7 +44,9 @@ impl super::Service {
|
|||
.forbidden_remote_server_names
|
||||
.contains(dest)
|
||||
{
|
||||
return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
|
||||
return Err!(Request(Forbidden(debug_warn!(
|
||||
"Federation with {dest} is not allowed."
|
||||
))));
|
||||
}
|
||||
|
||||
let actual = self.services.resolver.get_actual_dest(dest).await?;
|
||||
|
@ -49,7 +56,11 @@ impl super::Service {
|
|||
}
|
||||
|
||||
async fn execute<T>(
|
||||
&self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client,
|
||||
&self,
|
||||
dest: &ServerName,
|
||||
actual: &ActualDest,
|
||||
request: Request,
|
||||
client: &Client,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: OutgoingRequest + Send,
|
||||
|
@ -59,8 +70,18 @@ impl super::Service {
|
|||
|
||||
debug!(?method, ?url, "Sending request");
|
||||
match client.execute(request).await {
|
||||
Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await,
|
||||
Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
|
||||
| Ok(response) =>
|
||||
handle_response::<T>(
|
||||
&self.services.resolver,
|
||||
dest,
|
||||
actual,
|
||||
&method,
|
||||
&url,
|
||||
response,
|
||||
)
|
||||
.await,
|
||||
| Err(error) =>
|
||||
Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -86,7 +107,11 @@ impl super::Service {
|
|||
}
|
||||
|
||||
async fn handle_response<T>(
|
||||
resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url,
|
||||
resolver: &resolver::Service,
|
||||
dest: &ServerName,
|
||||
actual: &ActualDest,
|
||||
method: &Method,
|
||||
url: &Url,
|
||||
response: Response,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
|
@ -96,21 +121,22 @@ where
|
|||
let result = T::IncomingResponse::try_from_http_response(response);
|
||||
|
||||
if result.is_ok() && !actual.cached {
|
||||
resolver.set_cached_destination(
|
||||
dest.to_owned(),
|
||||
CachedDest {
|
||||
dest: actual.dest.clone(),
|
||||
host: actual.host.clone(),
|
||||
expire: CachedDest::default_expire(),
|
||||
},
|
||||
);
|
||||
resolver.set_cached_destination(dest.to_owned(), CachedDest {
|
||||
dest: actual.dest.clone(),
|
||||
host: actual.host.clone(),
|
||||
expire: CachedDest::default_expire(),
|
||||
});
|
||||
}
|
||||
|
||||
result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
|
||||
}
|
||||
|
||||
async fn into_http_response(
|
||||
dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response,
|
||||
dest: &ServerName,
|
||||
actual: &ActualDest,
|
||||
method: &Method,
|
||||
url: &Url,
|
||||
mut response: Response,
|
||||
) -> Result<http::Response<Bytes>> {
|
||||
let status = response.status();
|
||||
trace!(
|
||||
|
@ -146,13 +172,21 @@ async fn into_http_response(
|
|||
|
||||
debug!("Got {status:?} for {method} {url}");
|
||||
if !status.is_success() {
|
||||
return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response)));
|
||||
return Err(Error::Federation(
|
||||
dest.to_owned(),
|
||||
RumaError::from_http_response(http_response),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(http_response)
|
||||
}
|
||||
|
||||
fn handle_error(actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error) -> Result {
|
||||
fn handle_error(
|
||||
actual: &ActualDest,
|
||||
method: &Method,
|
||||
url: &Url,
|
||||
mut e: reqwest::Error,
|
||||
) -> Result {
|
||||
if e.is_timeout() || e.is_connect() {
|
||||
e = e.without_url();
|
||||
debug_warn!("{e:?}");
|
||||
|
@ -186,7 +220,8 @@ fn sign_request(&self, http_request: &mut http::Request<Vec<u8>>, dest: &ServerN
|
|||
.expect("http::Request missing path_and_query");
|
||||
|
||||
let mut req: Object = if !body.is_empty() {
|
||||
let content: CanonicalJsonValue = serde_json::from_slice(body).expect("failed to serialize body");
|
||||
let content: CanonicalJsonValue =
|
||||
serde_json::from_slice(body).expect("failed to serialize body");
|
||||
|
||||
let authorization: [Member; 5] = [
|
||||
("content".into(), content),
|
||||
|
|
|
@ -24,15 +24,19 @@ use ruma::{
|
|||
appservice::event::push_events::v1::Edu as RumaEdu,
|
||||
federation::transactions::{
|
||||
edu::{
|
||||
DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap,
|
||||
DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent,
|
||||
ReceiptData, ReceiptMap,
|
||||
},
|
||||
send_transaction_message,
|
||||
},
|
||||
},
|
||||
device_id,
|
||||
events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType},
|
||||
push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
|
||||
RoomVersionId, ServerName, UInt,
|
||||
events::{
|
||||
push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent,
|
||||
GlobalAccountDataEventType,
|
||||
},
|
||||
push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName,
|
||||
OwnedUserId, RoomId, RoomVersionId, ServerName, UInt,
|
||||
};
|
||||
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
|
||||
|
||||
|
@ -86,11 +90,14 @@ impl Service {
|
|||
}
|
||||
|
||||
async fn handle_response<'a>(
|
||||
&'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
|
||||
&'a self,
|
||||
response: SendingResult,
|
||||
futures: &mut SendingFutures<'a>,
|
||||
statuses: &mut CurTransactionStatus,
|
||||
) {
|
||||
match response {
|
||||
Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
|
||||
Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
|
||||
| Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
|
||||
| Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -98,16 +105,22 @@ impl Service {
|
|||
debug!(dest = ?dest, "{e:?}");
|
||||
statuses.entry(dest).and_modify(|e| {
|
||||
*e = match e {
|
||||
TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()),
|
||||
TransactionStatus::Retrying(ref n) => TransactionStatus::Failed(n.saturating_add(1), Instant::now()),
|
||||
TransactionStatus::Failed(..) => panic!("Request that was not even running failed?!"),
|
||||
| TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()),
|
||||
| TransactionStatus::Retrying(ref n) =>
|
||||
TransactionStatus::Failed(n.saturating_add(1), Instant::now()),
|
||||
| TransactionStatus::Failed(..) => {
|
||||
panic!("Request that was not even running failed?!")
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_ref_mut)]
|
||||
async fn handle_response_ok<'a>(
|
||||
&'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
|
||||
&'a self,
|
||||
dest: &Destination,
|
||||
futures: &mut SendingFutures<'a>,
|
||||
statuses: &mut CurTransactionStatus,
|
||||
) {
|
||||
let _cork = self.db.db.cork();
|
||||
self.db.delete_all_active_requests_for(dest).await;
|
||||
|
@ -133,7 +146,10 @@ impl Service {
|
|||
|
||||
#[allow(clippy::needless_pass_by_ref_mut)]
|
||||
async fn handle_request<'a>(
|
||||
&'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
|
||||
&'a self,
|
||||
msg: Msg,
|
||||
futures: &mut SendingFutures<'a>,
|
||||
statuses: &mut CurTransactionStatus,
|
||||
) {
|
||||
let iv = vec![(msg.queue_id, msg.event)];
|
||||
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
|
||||
|
@ -168,8 +184,13 @@ impl Service {
|
|||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_ref_mut)]
|
||||
async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
|
||||
let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
|
||||
async fn initial_requests<'a>(
|
||||
&'a self,
|
||||
futures: &mut SendingFutures<'a>,
|
||||
statuses: &mut CurTransactionStatus,
|
||||
) {
|
||||
let keep =
|
||||
usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
|
||||
let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new();
|
||||
let mut active = self.db.active_requests().boxed();
|
||||
|
||||
|
@ -240,7 +261,11 @@ impl Service {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip_all, level = "debug")]
|
||||
fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> {
|
||||
fn select_events_current(
|
||||
&self,
|
||||
dest: Destination,
|
||||
statuses: &mut CurTransactionStatus,
|
||||
) -> Result<(bool, bool)> {
|
||||
let (mut allow, mut retry) = (true, false);
|
||||
statuses
|
||||
.entry(dest.clone()) // TODO: can we avoid cloning?
|
||||
|
@ -278,7 +303,8 @@ impl Service {
|
|||
let events_len = AtomicUsize::default();
|
||||
let max_edu_count = AtomicU64::new(since);
|
||||
|
||||
let device_changes = self.select_edus_device_changes(server_name, batch, &max_edu_count, &events_len);
|
||||
let device_changes =
|
||||
self.select_edus_device_changes(server_name, batch, &max_edu_count, &events_len);
|
||||
|
||||
let receipts: OptionFuture<_> = self
|
||||
.server
|
||||
|
@ -305,7 +331,11 @@ impl Service {
|
|||
|
||||
/// Look for presence
|
||||
async fn select_edus_device_changes(
|
||||
&self, server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64, events_len: &AtomicUsize,
|
||||
&self,
|
||||
server_name: &ServerName,
|
||||
since: (u64, u64),
|
||||
max_edu_count: &AtomicU64,
|
||||
events_len: &AtomicUsize,
|
||||
) -> Vec<Vec<u8>> {
|
||||
let mut events = Vec::new();
|
||||
let server_rooms = self.services.state_cache.server_rooms(server_name);
|
||||
|
@ -342,7 +372,8 @@ impl Service {
|
|||
keys: None,
|
||||
});
|
||||
|
||||
let edu = serde_json::to_vec(&edu).expect("failed to serialize device list update to JSON");
|
||||
let edu = serde_json::to_vec(&edu)
|
||||
.expect("failed to serialize device list update to JSON");
|
||||
|
||||
events.push(edu);
|
||||
if events_len.fetch_add(1, Ordering::Relaxed) >= SELECT_EDU_LIMIT - 1 {
|
||||
|
@ -356,7 +387,10 @@ impl Service {
|
|||
|
||||
/// Look for read receipts in this room
|
||||
async fn select_edus_receipts(
|
||||
&self, server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64,
|
||||
&self,
|
||||
server_name: &ServerName,
|
||||
since: (u64, u64),
|
||||
max_edu_count: &AtomicU64,
|
||||
) -> Option<Vec<u8>> {
|
||||
let server_rooms = self.services.state_cache.server_rooms(server_name);
|
||||
|
||||
|
@ -377,19 +411,21 @@ impl Service {
|
|||
return None;
|
||||
}
|
||||
|
||||
let receipt_content = Edu::Receipt(ReceiptContent {
|
||||
receipts,
|
||||
});
|
||||
let receipt_content = Edu::Receipt(ReceiptContent { receipts });
|
||||
|
||||
let receipt_content =
|
||||
serde_json::to_vec(&receipt_content).expect("Failed to serialize Receipt EDU to JSON vec");
|
||||
let receipt_content = serde_json::to_vec(&receipt_content)
|
||||
.expect("Failed to serialize Receipt EDU to JSON vec");
|
||||
|
||||
Some(receipt_content)
|
||||
}
|
||||
|
||||
/// Look for read receipts in this room
|
||||
async fn select_edus_receipts_room(
|
||||
&self, room_id: &RoomId, since: (u64, u64), max_edu_count: &AtomicU64, num: &mut usize,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
since: (u64, u64),
|
||||
max_edu_count: &AtomicU64,
|
||||
num: &mut usize,
|
||||
) -> ReceiptMap {
|
||||
let receipts = self
|
||||
.services
|
||||
|
@ -444,14 +480,15 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
ReceiptMap {
|
||||
read,
|
||||
}
|
||||
ReceiptMap { read }
|
||||
}
|
||||
|
||||
/// Look for presence
|
||||
async fn select_edus_presence(
|
||||
&self, server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64,
|
||||
&self,
|
||||
server_name: &ServerName,
|
||||
since: (u64, u64),
|
||||
max_edu_count: &AtomicU64,
|
||||
) -> Option<Vec<u8>> {
|
||||
let presence_since = self.services.presence.presence_since(since.0);
|
||||
|
||||
|
@ -511,7 +548,8 @@ impl Service {
|
|||
push: presence_updates.into_values().collect(),
|
||||
});
|
||||
|
||||
let presence_content = serde_json::to_vec(&presence_content).expect("failed to serialize Presence EDU to JSON");
|
||||
let presence_content = serde_json::to_vec(&presence_content)
|
||||
.expect("failed to serialize Presence EDU to JSON");
|
||||
|
||||
Some(presence_content)
|
||||
}
|
||||
|
@ -519,21 +557,28 @@ impl Service {
|
|||
async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
|
||||
//debug_assert!(!events.is_empty(), "sending empty transaction");
|
||||
match dest {
|
||||
Destination::Normal(ref server) => self.send_events_dest_normal(&dest, server, events).await,
|
||||
Destination::Appservice(ref id) => self.send_events_dest_appservice(&dest, id, events).await,
|
||||
Destination::Push(ref userid, ref pushkey) => {
|
||||
| Destination::Normal(ref server) =>
|
||||
self.send_events_dest_normal(&dest, server, events).await,
|
||||
| Destination::Appservice(ref id) =>
|
||||
self.send_events_dest_appservice(&dest, id, events).await,
|
||||
| Destination::Push(ref userid, ref pushkey) =>
|
||||
self.send_events_dest_push(&dest, userid, pushkey, events)
|
||||
.await
|
||||
},
|
||||
.await,
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, dest, events), name = "appservice")]
|
||||
async fn send_events_dest_appservice(
|
||||
&self, dest: &Destination, id: &str, events: Vec<SendingEvent>,
|
||||
&self,
|
||||
dest: &Destination,
|
||||
id: &str,
|
||||
events: Vec<SendingEvent>,
|
||||
) -> SendingResult {
|
||||
let Some(appservice) = self.services.appservice.get_registration(id).await else {
|
||||
return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration")))));
|
||||
return Err((
|
||||
dest.clone(),
|
||||
err!(Database(warn!(?id, "Missing appservice registration"))),
|
||||
));
|
||||
};
|
||||
|
||||
let mut pdu_jsons = Vec::with_capacity(
|
||||
|
@ -550,12 +595,12 @@ impl Service {
|
|||
);
|
||||
for event in &events {
|
||||
match event {
|
||||
SendingEvent::Pdu(pdu_id) => {
|
||||
| SendingEvent::Pdu(pdu_id) => {
|
||||
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
|
||||
pdu_jsons.push(pdu.to_room_event());
|
||||
}
|
||||
},
|
||||
SendingEvent::Edu(edu) => {
|
||||
| SendingEvent::Edu(edu) => {
|
||||
if appservice
|
||||
.receive_ephemeral
|
||||
.is_some_and(|receive_edus| receive_edus)
|
||||
|
@ -565,14 +610,14 @@ impl Service {
|
|||
}
|
||||
}
|
||||
},
|
||||
SendingEvent::Flush => {}, // flush only; no new content
|
||||
| SendingEvent::Flush => {}, // flush only; no new content
|
||||
}
|
||||
}
|
||||
|
||||
let txn_hash = calculate_hash(events.iter().filter_map(|e| match e {
|
||||
SendingEvent::Edu(b) => Some(&**b),
|
||||
SendingEvent::Pdu(b) => Some(b.as_ref()),
|
||||
SendingEvent::Flush => None,
|
||||
| SendingEvent::Edu(b) => Some(&**b),
|
||||
| SendingEvent::Pdu(b) => Some(b.as_ref()),
|
||||
| SendingEvent::Flush => None,
|
||||
}));
|
||||
|
||||
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash);
|
||||
|
@ -592,28 +637,35 @@ impl Service {
|
|||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(dest.clone()),
|
||||
Err(e) => Err((dest.clone(), e)),
|
||||
| Ok(_) => Ok(dest.clone()),
|
||||
| Err(e) => Err((dest.clone(), e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, dest, events), name = "push")]
|
||||
async fn send_events_dest_push(
|
||||
&self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
|
||||
&self,
|
||||
dest: &Destination,
|
||||
userid: &OwnedUserId,
|
||||
pushkey: &str,
|
||||
events: Vec<SendingEvent>,
|
||||
) -> SendingResult {
|
||||
let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else {
|
||||
return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher")))));
|
||||
return Err((
|
||||
dest.clone(),
|
||||
err!(Database(error!(?userid, ?pushkey, "Missing pusher"))),
|
||||
));
|
||||
};
|
||||
|
||||
let mut pdus = Vec::new();
|
||||
for event in &events {
|
||||
match event {
|
||||
SendingEvent::Pdu(pdu_id) => {
|
||||
| SendingEvent::Pdu(pdu_id) => {
|
||||
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
|
||||
pdus.push(pdu);
|
||||
}
|
||||
},
|
||||
SendingEvent::Edu(_) | SendingEvent::Flush => {
|
||||
| SendingEvent::Edu(_) | SendingEvent::Flush => {
|
||||
// Push gateways don't need EDUs (?) and flush only;
|
||||
// no new content
|
||||
},
|
||||
|
@ -657,7 +709,10 @@ impl Service {
|
|||
|
||||
#[tracing::instrument(skip(self, dest, events), name = "", level = "debug")]
|
||||
async fn send_events_dest_normal(
|
||||
&self, dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>,
|
||||
&self,
|
||||
dest: &Destination,
|
||||
server: &OwnedServerName,
|
||||
events: Vec<SendingEvent>,
|
||||
) -> SendingResult {
|
||||
let mut pdu_jsons = Vec::with_capacity(
|
||||
events
|
||||
|
@ -675,17 +730,16 @@ impl Service {
|
|||
for event in &events {
|
||||
match event {
|
||||
// TODO: check room version and remove event_id if needed
|
||||
SendingEvent::Pdu(pdu_id) => {
|
||||
| SendingEvent::Pdu(pdu_id) => {
|
||||
if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await {
|
||||
pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await);
|
||||
}
|
||||
},
|
||||
SendingEvent::Edu(edu) => {
|
||||
| SendingEvent::Edu(edu) =>
|
||||
if let Ok(raw) = serde_json::from_slice(edu) {
|
||||
edu_jsons.push(raw);
|
||||
}
|
||||
},
|
||||
SendingEvent::Flush => {}, // flush only; no new content
|
||||
},
|
||||
| SendingEvent::Flush => {}, // flush only; no new content
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -693,9 +747,9 @@ impl Service {
|
|||
// transaction");
|
||||
|
||||
let txn_hash = calculate_hash(events.iter().filter_map(|e| match e {
|
||||
SendingEvent::Edu(b) => Some(&**b),
|
||||
SendingEvent::Pdu(b) => Some(b.as_ref()),
|
||||
SendingEvent::Flush => None,
|
||||
| SendingEvent::Edu(b) => Some(&**b),
|
||||
| SendingEvent::Pdu(b) => Some(b.as_ref()),
|
||||
| SendingEvent::Flush => None,
|
||||
}));
|
||||
|
||||
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash);
|
||||
|
@ -725,7 +779,10 @@ impl Service {
|
|||
}
|
||||
|
||||
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
|
||||
pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
|
||||
pub async fn convert_to_outgoing_federation_event(
|
||||
&self,
|
||||
mut pdu_json: CanonicalJsonObject,
|
||||
) -> Box<RawJsonValue> {
|
||||
if let Some(unsigned) = pdu_json
|
||||
.get_mut("unsigned")
|
||||
.and_then(|val| val.as_object_mut())
|
||||
|
@ -739,11 +796,11 @@ impl Service {
|
|||
.and_then(|val| RoomId::parse(val.as_str()?).ok())
|
||||
{
|
||||
match self.services.state.get_room_version(&room_id).await {
|
||||
Ok(room_version_id) => match room_version_id {
|
||||
RoomVersionId::V1 | RoomVersionId::V2 => {},
|
||||
_ => _ = pdu_json.remove("event_id"),
|
||||
| Ok(room_version_id) => match room_version_id {
|
||||
| RoomVersionId::V1 | RoomVersionId::V2 => {},
|
||||
| _ => _ = pdu_json.remove("event_id"),
|
||||
},
|
||||
Err(_) => _ = pdu_json.remove("event_id"),
|
||||
| Err(_) => _ = pdu_json.remove("event_id"),
|
||||
}
|
||||
} else {
|
||||
pdu_json.remove("event_id");
|
||||
|
|
|
@ -4,11 +4,13 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
|
||||
use conduwuit::{debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn};
|
||||
use conduwuit::{
|
||||
debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn,
|
||||
};
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use ruma::{
|
||||
api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName,
|
||||
OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
|
||||
api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject,
|
||||
OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId,
|
||||
};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
use tokio::time::{timeout_at, Instant};
|
||||
|
@ -79,7 +81,9 @@ where
|
|||
return;
|
||||
}
|
||||
|
||||
warn!("missing {missing_keys} keys for {missing_servers} servers from all notaries first");
|
||||
warn!(
|
||||
"missing {missing_keys} keys for {missing_servers} servers from all notaries first"
|
||||
);
|
||||
}
|
||||
|
||||
if !notary_only {
|
||||
|
@ -101,13 +105,15 @@ where
|
|||
return;
|
||||
}
|
||||
|
||||
debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries.");
|
||||
debug_warn!(
|
||||
"still missing {missing_keys} keys for {missing_servers} servers from all notaries."
|
||||
);
|
||||
}
|
||||
|
||||
if missing_keys > 0 {
|
||||
warn!(
|
||||
"did not obtain {missing_keys} keys for {missing_servers} servers out of {requested_keys} total keys for \
|
||||
{requested_servers} total servers."
|
||||
"did not obtain {missing_keys} keys for {missing_servers} servers out of \
|
||||
{requested_keys} total keys for {requested_servers} total servers."
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -162,12 +168,15 @@ where
|
|||
|
||||
#[implement(super::Service)]
|
||||
async fn acquire_origin(
|
||||
&self, origin: OwnedServerName, mut key_ids: Vec<OwnedServerSigningKeyId>, timeout: Instant,
|
||||
&self,
|
||||
origin: OwnedServerName,
|
||||
mut key_ids: Vec<OwnedServerSigningKeyId>,
|
||||
timeout: Instant,
|
||||
) -> (OwnedServerName, Vec<OwnedServerSigningKeyId>) {
|
||||
match timeout_at(timeout, self.server_request(&origin)).await {
|
||||
Err(e) => debug_warn!(?origin, "timed out: {e}"),
|
||||
Ok(Err(e)) => debug_error!(?origin, "{e}"),
|
||||
Ok(Ok(server_keys)) => {
|
||||
| Err(e) => debug_warn!(?origin, "timed out: {e}"),
|
||||
| Ok(Err(e)) => debug_error!(?origin, "{e}"),
|
||||
| Ok(Ok(server_keys)) => {
|
||||
trace!(
|
||||
%origin,
|
||||
?key_ids,
|
||||
|
@ -192,19 +201,21 @@ where
|
|||
for notary in self.services.globals.trusted_servers() {
|
||||
let missing_keys = keys_count(&missing);
|
||||
let missing_servers = missing.len();
|
||||
debug!("Asking notary {notary} for {missing_keys} missing keys from {missing_servers} servers");
|
||||
debug!(
|
||||
"Asking notary {notary} for {missing_keys} missing keys from {missing_servers} \
|
||||
servers"
|
||||
);
|
||||
|
||||
let batch = missing
|
||||
.iter()
|
||||
.map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow)));
|
||||
|
||||
match self.batch_notary_request(notary, batch).await {
|
||||
Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
|
||||
Ok(results) => {
|
||||
| Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
|
||||
| Ok(results) =>
|
||||
for server_keys in results {
|
||||
self.acquire_notary_result(&mut missing, server_keys).await;
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -224,4 +235,6 @@ async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSi
|
|||
}
|
||||
}
|
||||
|
||||
fn keys_count(batch: &Batch) -> usize { batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count() }
|
||||
fn keys_count(batch: &Batch) -> usize {
|
||||
batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count()
|
||||
}
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
use std::borrow::Borrow;
|
||||
|
||||
use conduwuit::{implement, Err, Result};
|
||||
use ruma::{api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName, ServerSigningKeyId};
|
||||
use ruma::{
|
||||
api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName,
|
||||
ServerSigningKeyId,
|
||||
};
|
||||
|
||||
use super::{extract_key, PubKeyMap, PubKeys};
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_event_keys(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> Result<PubKeyMap> {
|
||||
pub async fn get_event_keys(
|
||||
&self,
|
||||
object: &CanonicalJsonObject,
|
||||
version: &RoomVersionId,
|
||||
) -> Result<PubKeyMap> {
|
||||
use ruma::signatures::required_keys;
|
||||
|
||||
let required = match required_keys(object, version) {
|
||||
Ok(required) => required,
|
||||
Err(e) => return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")),
|
||||
| Ok(required) => required,
|
||||
| Err(e) =>
|
||||
return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")),
|
||||
};
|
||||
|
||||
let batch = required
|
||||
|
@ -52,7 +60,11 @@ where
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
|
||||
pub async fn get_verify_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
key_id: &ServerSigningKeyId,
|
||||
) -> Result<VerifyKey> {
|
||||
let notary_first = self.services.server.config.query_trusted_key_servers_first;
|
||||
let notary_only = self.services.server.config.only_query_trusted_key_servers;
|
||||
|
||||
|
@ -86,7 +98,11 @@ pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKe
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
|
||||
async fn get_verify_key_from_notaries(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
key_id: &ServerSigningKeyId,
|
||||
) -> Result<VerifyKey> {
|
||||
for notary in self.services.globals.trusted_servers() {
|
||||
if let Ok(server_keys) = self.notary_request(notary, origin).await {
|
||||
for server_key in server_keys.clone() {
|
||||
|
@ -105,7 +121,11 @@ async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &Serve
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
async fn get_verify_key_from_origin(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
|
||||
async fn get_verify_key_from_origin(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
key_id: &ServerSigningKeyId,
|
||||
) -> Result<VerifyKey> {
|
||||
if let Ok(server_key) = self.server_request(origin).await {
|
||||
self.add_signing_keys(server_key.clone()).await;
|
||||
if let Some(result) = extract_key(server_key, key_id) {
|
||||
|
|
|
@ -39,14 +39,15 @@ fn load(db: &Arc<Database>) -> Result<Box<Ed25519KeyPair>> {
|
|||
create(db)
|
||||
})?;
|
||||
|
||||
let key =
|
||||
Ed25519KeyPair::from_der(&key, version).map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?;
|
||||
let key = Ed25519KeyPair::from_der(&key, version)
|
||||
.map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?;
|
||||
|
||||
Ok(Box::new(key))
|
||||
}
|
||||
|
||||
fn create(db: &Arc<Database>) -> Result<(String, Vec<u8>)> {
|
||||
let keypair = Ed25519KeyPair::generate().map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?;
|
||||
let keypair = Ed25519KeyPair::generate()
|
||||
.map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?;
|
||||
|
||||
let id = utils::rand::string(8);
|
||||
debug_info!("Generated new Ed25519 keypair: {id:?}");
|
||||
|
|
|
@ -18,8 +18,8 @@ use ruma::{
|
|||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
serde::Raw,
|
||||
signatures::{Ed25519KeyPair, PublicKeyMap, PublicKeySet},
|
||||
CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, RoomVersionId, ServerName,
|
||||
ServerSigningKeyId,
|
||||
CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, RoomVersionId,
|
||||
ServerName, ServerSigningKeyId,
|
||||
};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
|
||||
|
@ -113,7 +113,11 @@ async fn add_signing_keys(&self, new_keys: ServerSigningKeys) {
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn required_keys_exist(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> bool {
|
||||
pub async fn required_keys_exist(
|
||||
&self,
|
||||
object: &CanonicalJsonObject,
|
||||
version: &RoomVersionId,
|
||||
) -> bool {
|
||||
use ruma::signatures::required_keys;
|
||||
|
||||
let Ok(required_keys) = required_keys(object, version) else {
|
||||
|
@ -179,7 +183,8 @@ pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSignin
|
|||
|
||||
#[implement(Service)]
|
||||
fn minimum_valid_ts(&self) -> MilliSecondsSinceUnixEpoch {
|
||||
let timepoint = timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow");
|
||||
let timepoint =
|
||||
timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow");
|
||||
MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow")
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,9 @@ use ruma::{
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub(super) async fn batch_notary_request<'a, S, K>(
|
||||
&self, notary: &ServerName, batch: S,
|
||||
&self,
|
||||
notary: &ServerName,
|
||||
batch: S,
|
||||
) -> Result<Vec<ServerSigningKeys>>
|
||||
where
|
||||
S: Iterator<Item = (&'a ServerName, K)> + Send,
|
||||
|
@ -74,7 +76,9 @@ where
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub async fn notary_request(
|
||||
&self, notary: &ServerName, target: &ServerName,
|
||||
&self,
|
||||
notary: &ServerName,
|
||||
target: &ServerName,
|
||||
) -> Result<impl Iterator<Item = ServerSigningKeys> + Clone + Debug + Send> {
|
||||
use get_remote_server_keys::v2::Request;
|
||||
|
||||
|
|
|
@ -10,7 +10,11 @@ pub fn sign_json(&self, object: &mut CanonicalJsonObject) -> Result {
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub fn hash_and_sign_event(&self, object: &mut CanonicalJsonObject, room_version: &RoomVersionId) -> Result {
|
||||
pub fn hash_and_sign_event(
|
||||
&self,
|
||||
object: &mut CanonicalJsonObject,
|
||||
room_version: &RoomVersionId,
|
||||
) -> Result {
|
||||
use ruma::signatures::hash_and_sign_event;
|
||||
|
||||
let server_name = self.services.globals.server_name().as_str();
|
||||
|
|
|
@ -1,14 +1,20 @@
|
|||
use conduwuit::{implement, pdu::gen_event_id_canonical_json, Err, Result};
|
||||
use ruma::{signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId};
|
||||
use ruma::{
|
||||
signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId,
|
||||
};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn validate_and_add_event_id(
|
||||
&self, pdu: &RawJsonValue, room_version: &RoomVersionId,
|
||||
&self,
|
||||
pdu: &RawJsonValue,
|
||||
room_version: &RoomVersionId,
|
||||
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
|
||||
let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
|
||||
if let Err(e) = self.verify_event(&value, Some(room_version)).await {
|
||||
return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}")));
|
||||
return Err!(BadServerResponse(debug_error!(
|
||||
"Event {event_id} failed verification: {e:?}"
|
||||
)));
|
||||
}
|
||||
|
||||
value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into()));
|
||||
|
@ -18,7 +24,9 @@ pub async fn validate_and_add_event_id(
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub async fn validate_and_add_event_id_no_fetch(
|
||||
&self, pdu: &RawJsonValue, room_version: &RoomVersionId,
|
||||
&self,
|
||||
pdu: &RawJsonValue,
|
||||
room_version: &RoomVersionId,
|
||||
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
|
||||
let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
|
||||
if !self.required_keys_exist(&value, room_version).await {
|
||||
|
@ -28,7 +36,9 @@ pub async fn validate_and_add_event_id_no_fetch(
|
|||
}
|
||||
|
||||
if let Err(e) = self.verify_event(&value, Some(room_version)).await {
|
||||
return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}")));
|
||||
return Err!(BadServerResponse(debug_error!(
|
||||
"Event {event_id} failed verification: {e:?}"
|
||||
)));
|
||||
}
|
||||
|
||||
value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into()));
|
||||
|
@ -38,7 +48,9 @@ pub async fn validate_and_add_event_id_no_fetch(
|
|||
|
||||
#[implement(super::Service)]
|
||||
pub async fn verify_event(
|
||||
&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>,
|
||||
&self,
|
||||
event: &CanonicalJsonObject,
|
||||
room_version: Option<&RoomVersionId>,
|
||||
) -> Result<Verified> {
|
||||
let room_version = room_version.unwrap_or(&RoomVersionId::V11);
|
||||
let keys = self.get_event_keys(event, room_version).await?;
|
||||
|
@ -46,7 +58,11 @@ pub async fn verify_event(
|
|||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn verify_json(&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>) -> Result {
|
||||
pub async fn verify_json(
|
||||
&self,
|
||||
event: &CanonicalJsonObject,
|
||||
room_version: Option<&RoomVersionId>,
|
||||
) -> Result {
|
||||
let room_version = room_version.unwrap_or(&RoomVersionId::V11);
|
||||
let keys = self.get_event_keys(event, room_version).await?;
|
||||
ruma::signatures::verify_json(&keys, event.clone()).map_err(Into::into)
|
||||
|
|
|
@ -114,7 +114,9 @@ impl<'a> Args<'a> {
|
|||
/// Create a reference immediately to a service when constructing another
|
||||
/// Service. The other service must be constructed.
|
||||
#[inline]
|
||||
pub(crate) fn require<T: Service>(&'a self, name: &str) -> Arc<T> { require::<T>(self.service, name) }
|
||||
pub(crate) fn require<T: Service>(&'a self, name: &str) -> Arc<T> {
|
||||
require::<T>(self.service, name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Reference a Service by name. Panics if the Service does not exist or was
|
||||
|
|
|
@ -47,7 +47,8 @@ struct Services {
|
|||
struct SlidingSyncCache {
|
||||
lists: BTreeMap<String, SyncRequestList>,
|
||||
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
|
||||
known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>, // For every room, the roomsince number
|
||||
known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>, /* For every room, the
|
||||
* roomsince number */
|
||||
extensions: ExtensionsConfig,
|
||||
}
|
||||
|
||||
|
@ -85,14 +86,24 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
impl Service {
|
||||
pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool {
|
||||
pub fn remembered(
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
conn_id: String,
|
||||
) -> bool {
|
||||
self.connections
|
||||
.lock()
|
||||
.unwrap()
|
||||
.contains_key(&(user_id, device_id, conn_id))
|
||||
}
|
||||
|
||||
pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) {
|
||||
pub fn forget_sync_request_connection(
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
conn_id: String,
|
||||
) {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
|
@ -100,25 +111,26 @@ impl Service {
|
|||
}
|
||||
|
||||
pub fn update_sync_request_with_cache(
|
||||
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request,
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
request: &mut sync_events::v4::Request,
|
||||
) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
|
||||
let Some(conn_id) = request.conn_id.clone() else {
|
||||
return BTreeMap::new();
|
||||
};
|
||||
|
||||
let mut cache = self.connections.lock().expect("locked");
|
||||
let cached = Arc::clone(
|
||||
cache
|
||||
.entry((user_id, device_id, conn_id))
|
||||
.or_insert_with(|| {
|
||||
Arc::new(Mutex::new(SlidingSyncCache {
|
||||
lists: BTreeMap::new(),
|
||||
subscriptions: BTreeMap::new(),
|
||||
known_rooms: BTreeMap::new(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}))
|
||||
}),
|
||||
);
|
||||
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(
|
||||
|| {
|
||||
Arc::new(Mutex::new(SlidingSyncCache {
|
||||
lists: BTreeMap::new(),
|
||||
subscriptions: BTreeMap::new(),
|
||||
known_rooms: BTreeMap::new(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}))
|
||||
},
|
||||
));
|
||||
let cached = &mut cached.lock().expect("locked");
|
||||
drop(cache);
|
||||
|
||||
|
@ -141,13 +153,15 @@ impl Service {
|
|||
.clone()
|
||||
.or_else(|| cached_list.include_old_rooms.clone());
|
||||
match (&mut list.filters, cached_list.filters.clone()) {
|
||||
(Some(list_filters), Some(cached_filters)) => {
|
||||
| (Some(list_filters), Some(cached_filters)) => {
|
||||
list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm);
|
||||
if list_filters.spaces.is_empty() {
|
||||
list_filters.spaces = cached_filters.spaces;
|
||||
}
|
||||
list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted);
|
||||
list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite);
|
||||
list_filters.is_encrypted =
|
||||
list_filters.is_encrypted.or(cached_filters.is_encrypted);
|
||||
list_filters.is_invite =
|
||||
list_filters.is_invite.or(cached_filters.is_invite);
|
||||
if list_filters.room_types.is_empty() {
|
||||
list_filters.room_types = cached_filters.room_types;
|
||||
}
|
||||
|
@ -165,9 +179,9 @@ impl Service {
|
|||
list_filters.not_tags = cached_filters.not_tags;
|
||||
}
|
||||
},
|
||||
(_, Some(cached_filters)) => list.filters = Some(cached_filters),
|
||||
(Some(list_filters), _) => list.filters = Some(list_filters.clone()),
|
||||
(..) => {},
|
||||
| (_, Some(cached_filters)) => list.filters = Some(cached_filters),
|
||||
| (Some(list_filters), _) => list.filters = Some(list_filters.clone()),
|
||||
| (..) => {},
|
||||
}
|
||||
if list.bump_event_types.is_empty() {
|
||||
list.bump_event_types
|
||||
|
@ -220,22 +234,23 @@ impl Service {
|
|||
}
|
||||
|
||||
pub fn update_sync_subscriptions(
|
||||
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String,
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
conn_id: String,
|
||||
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
|
||||
) {
|
||||
let mut cache = self.connections.lock().expect("locked");
|
||||
let cached = Arc::clone(
|
||||
cache
|
||||
.entry((user_id, device_id, conn_id))
|
||||
.or_insert_with(|| {
|
||||
Arc::new(Mutex::new(SlidingSyncCache {
|
||||
lists: BTreeMap::new(),
|
||||
subscriptions: BTreeMap::new(),
|
||||
known_rooms: BTreeMap::new(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}))
|
||||
}),
|
||||
);
|
||||
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(
|
||||
|| {
|
||||
Arc::new(Mutex::new(SlidingSyncCache {
|
||||
lists: BTreeMap::new(),
|
||||
subscriptions: BTreeMap::new(),
|
||||
known_rooms: BTreeMap::new(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}))
|
||||
},
|
||||
));
|
||||
let cached = &mut cached.lock().expect("locked");
|
||||
drop(cache);
|
||||
|
||||
|
@ -243,22 +258,25 @@ impl Service {
|
|||
}
|
||||
|
||||
pub fn update_sync_known_rooms(
|
||||
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String,
|
||||
new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64,
|
||||
&self,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
conn_id: String,
|
||||
list_id: String,
|
||||
new_cached_rooms: BTreeSet<OwnedRoomId>,
|
||||
globalsince: u64,
|
||||
) {
|
||||
let mut cache = self.connections.lock().expect("locked");
|
||||
let cached = Arc::clone(
|
||||
cache
|
||||
.entry((user_id, device_id, conn_id))
|
||||
.or_insert_with(|| {
|
||||
Arc::new(Mutex::new(SlidingSyncCache {
|
||||
lists: BTreeMap::new(),
|
||||
subscriptions: BTreeMap::new(),
|
||||
known_rooms: BTreeMap::new(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}))
|
||||
}),
|
||||
);
|
||||
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(
|
||||
|| {
|
||||
Arc::new(Mutex::new(SlidingSyncCache {
|
||||
lists: BTreeMap::new(),
|
||||
subscriptions: BTreeMap::new(),
|
||||
known_rooms: BTreeMap::new(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}))
|
||||
},
|
||||
));
|
||||
let cached = &mut cached.lock().expect("locked");
|
||||
drop(cache);
|
||||
|
||||
|
|
|
@ -25,7 +25,13 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) {
|
||||
pub fn add_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
data: &[u8],
|
||||
) {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
|
@ -38,7 +44,10 @@ pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id:
|
|||
// If there's no entry, this is a new transaction
|
||||
#[implement(Service)]
|
||||
pub async fn existing_txnid(
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
) -> Result<Handle<'_>> {
|
||||
let key = (user_id, device_id, txn_id);
|
||||
self.db.userdevicetxnid_response.qry(&key).await
|
||||
|
|
|
@ -58,7 +58,13 @@ impl crate::Service for Service {
|
|||
|
||||
/// Creates a new Uiaa session. Make sure the session token is unique.
|
||||
#[implement(Service)]
|
||||
pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) {
|
||||
pub fn create(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
uiaainfo: &UiaaInfo,
|
||||
json_body: &CanonicalJsonValue,
|
||||
) {
|
||||
// TODO: better session error handling (why is uiaainfo.session optional in
|
||||
// ruma?)
|
||||
self.set_uiaa_request(
|
||||
|
@ -78,7 +84,11 @@ pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo
|
|||
|
||||
#[implement(Service)]
|
||||
pub async fn try_auth(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
auth: &AuthData,
|
||||
uiaainfo: &UiaaInfo,
|
||||
) -> Result<(bool, UiaaInfo)> {
|
||||
let mut uiaainfo = if let Some(session) = auth.session() {
|
||||
self.get_uiaa_session(user_id, device_id, session).await?
|
||||
|
@ -92,7 +102,7 @@ pub async fn try_auth(
|
|||
|
||||
match auth {
|
||||
// Find out what the user completed
|
||||
AuthData::Password(Password {
|
||||
| AuthData::Password(Password {
|
||||
identifier,
|
||||
password,
|
||||
#[cfg(feature = "element_hacks")]
|
||||
|
@ -105,17 +115,26 @@ pub async fn try_auth(
|
|||
} else if let Some(username) = user {
|
||||
username
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
));
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "element_hacks"))]
|
||||
let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier
|
||||
else {
|
||||
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
));
|
||||
};
|
||||
|
||||
let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
|
||||
let user_id = UserId::parse_with_server_name(
|
||||
username.clone(),
|
||||
self.services.globals.server_name(),
|
||||
)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
|
||||
|
||||
// Check if password is correct
|
||||
if let Ok(hash) = self.services.users.password_hash(&user_id).await {
|
||||
|
@ -132,7 +151,7 @@ pub async fn try_auth(
|
|||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push(AuthType::Password);
|
||||
},
|
||||
AuthData::RegistrationToken(t) => {
|
||||
| AuthData::RegistrationToken(t) => {
|
||||
if self
|
||||
.services
|
||||
.globals
|
||||
|
@ -149,10 +168,10 @@ pub async fn try_auth(
|
|||
return Ok((false, uiaainfo));
|
||||
}
|
||||
},
|
||||
AuthData::Dummy(_) => {
|
||||
| AuthData::Dummy(_) => {
|
||||
uiaainfo.completed.push(AuthType::Dummy);
|
||||
},
|
||||
k => error!("type not supported: {:?}", k),
|
||||
| k => error!("type not supported: {:?}", k),
|
||||
}
|
||||
|
||||
// Check if a flow now succeeds
|
||||
|
@ -190,7 +209,13 @@ pub async fn try_auth(
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) {
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
request: &CanonicalJsonValue,
|
||||
) {
|
||||
let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.write()
|
||||
|
@ -200,7 +225,10 @@ fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str
|
|||
|
||||
#[implement(Service)]
|
||||
pub fn get_uiaa_request(
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
session: &str,
|
||||
) -> Option<CanonicalJsonValue> {
|
||||
let key = (
|
||||
user_id.to_owned(),
|
||||
|
@ -216,7 +244,13 @@ pub fn get_uiaa_request(
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) {
|
||||
fn update_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
uiaainfo: Option<&UiaaInfo>,
|
||||
) {
|
||||
let key = (user_id, device_id, session);
|
||||
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
|
@ -229,7 +263,12 @@ fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &
|
|||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
|
||||
async fn get_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Result<UiaaInfo> {
|
||||
let key = (user_id, device_id, session);
|
||||
self.db
|
||||
.userdevicesessionid_uiaainfo
|
||||
|
|
|
@ -119,7 +119,8 @@ impl Service {
|
|||
self.services
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::text_markdown(format!(
|
||||
"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",
|
||||
"### the following is a message from the conduwuit puppy\n\nit was sent on \
|
||||
`{}`:\n\n@room: {}",
|
||||
update.date, update.message
|
||||
)))
|
||||
.await
|
||||
|
@ -127,7 +128,9 @@ impl Service {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub fn update_check_for_updates_id(&self, id: u64) { self.db.raw_put(LAST_CHECK_FOR_UPDATES_COUNT, id); }
|
||||
pub fn update_check_for_updates_id(&self, id: u64) {
|
||||
self.db.raw_put(LAST_CHECK_FOR_UPDATES_COUNT, id);
|
||||
}
|
||||
|
||||
pub async fn last_check_for_updates_id(&self) -> u64 {
|
||||
self.db
|
||||
|
|
|
@ -10,10 +10,12 @@ use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
|
|||
use ruma::{
|
||||
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
|
||||
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
||||
events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType},
|
||||
events::{
|
||||
ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType,
|
||||
},
|
||||
serde::Raw,
|
||||
DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedDeviceId,
|
||||
OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId,
|
||||
DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId,
|
||||
OneTimeKeyName, OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
|
@ -65,7 +67,8 @@ impl crate::Service for Service {
|
|||
account_data: args.depend::<account_data::Service>("account_data"),
|
||||
admin: args.depend::<admin::Service>("admin"),
|
||||
globals: args.depend::<globals::Service>("globals"),
|
||||
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
||||
},
|
||||
db: Data {
|
||||
|
@ -114,7 +117,9 @@ impl Service {
|
|||
|
||||
/// Check if a user is an admin
|
||||
#[inline]
|
||||
pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await }
|
||||
pub async fn is_admin(&self, user_id: &UserId) -> bool {
|
||||
self.services.admin.user_is_admin(user_id).await
|
||||
}
|
||||
|
||||
/// Create a new user account on this homeserver.
|
||||
#[inline]
|
||||
|
@ -141,7 +146,9 @@ impl Service {
|
|||
|
||||
/// Check if a user has an account on this homeserver.
|
||||
#[inline]
|
||||
pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.get(user_id).await.is_ok() }
|
||||
pub async fn exists(&self, user_id: &UserId) -> bool {
|
||||
self.db.userid_password.get(user_id).await.is_ok()
|
||||
}
|
||||
|
||||
/// Check if account is deactivated
|
||||
pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
||||
|
@ -154,7 +161,9 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Check if account is active, infallible
|
||||
pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) }
|
||||
pub async fn is_active(&self, user_id: &UserId) -> bool {
|
||||
!self.is_deactivated(user_id).await.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Check if account is active, infallible
|
||||
pub async fn is_active_local(&self, user_id: &UserId) -> bool {
|
||||
|
@ -173,10 +182,14 @@ impl Service {
|
|||
/// Returns an iterator over all users on this homeserver (offered for
|
||||
/// compatibility)
|
||||
#[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)]
|
||||
pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ { self.stream().map(ToOwned::to_owned) }
|
||||
pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ {
|
||||
self.stream().map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
/// Returns an iterator over all users on this homeserver.
|
||||
pub fn stream(&self) -> impl Stream<Item = &UserId> + Send { self.db.userid_password.keys().ignore_err() }
|
||||
pub fn stream(&self) -> impl Stream<Item = &UserId> + Send {
|
||||
self.db.userid_password.keys().ignore_err()
|
||||
}
|
||||
|
||||
/// Returns a list of local users as list of usernames.
|
||||
///
|
||||
|
@ -200,7 +213,9 @@ impl Service {
|
|||
password
|
||||
.map(utils::hash::password)
|
||||
.transpose()
|
||||
.map_err(|e| err!(Request(InvalidParam("Password does not meet the requirements: {e}"))))?
|
||||
.map_err(|e| {
|
||||
err!(Request(InvalidParam("Password does not meet the requirements: {e}")))
|
||||
})?
|
||||
.map_or_else(
|
||||
|| self.db.userid_password.insert(user_id, b""),
|
||||
|hash| self.db.userid_password.insert(user_id, hash),
|
||||
|
@ -254,13 +269,19 @@ impl Service {
|
|||
|
||||
/// Adds a new device to a user.
|
||||
pub async fn create_device(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
initial_device_display_name: Option<String>,
|
||||
client_ip: Option<String>,
|
||||
) -> Result<()> {
|
||||
// This method should never be called for nonexistent users. We shouldn't assert
|
||||
// though...
|
||||
if !self.exists(user_id).await {
|
||||
return Err!(Request(InvalidParam(error!("Called create_device for non-existent {user_id}"))));
|
||||
return Err!(Request(InvalidParam(error!(
|
||||
"Called create_device for non-existent {user_id}"
|
||||
))));
|
||||
}
|
||||
|
||||
let key = (user_id, device_id);
|
||||
|
@ -304,7 +325,10 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Returns an iterator over all device ids of this user.
|
||||
pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &DeviceId> + Send + 'a {
|
||||
pub fn all_device_ids<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = &DeviceId> + Send + 'a {
|
||||
let prefix = (user_id, Interfix);
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
|
@ -319,7 +343,12 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Replaces the access token of one device.
|
||||
pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
||||
pub async fn set_token(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
) -> Result<()> {
|
||||
let key = (user_id, device_id);
|
||||
// should not be None, but we shouldn't assert either lol...
|
||||
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
|
||||
|
@ -344,7 +373,10 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn add_one_time_key(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
|
||||
one_time_key_value: &Raw<OneTimeKey>,
|
||||
) -> Result {
|
||||
// All devices have metadata
|
||||
|
@ -391,7 +423,10 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn take_one_time_key(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &OneTimeKeyAlgorithm,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
key_algorithm: &OneTimeKeyAlgorithm,
|
||||
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
|
||||
let count = self.services.globals.next_count()?.to_be_bytes();
|
||||
self.db.userid_lastonetimekeyupdate.insert(user_id, count);
|
||||
|
@ -435,7 +470,9 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn count_one_time_keys(
|
||||
&self, user_id: &UserId, device_id: &DeviceId,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
|
||||
type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
|
||||
|
||||
|
@ -462,7 +499,12 @@ impl Service {
|
|||
algorithm_counts
|
||||
}
|
||||
|
||||
pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
|
||||
pub async fn add_device_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device_keys: &Raw<DeviceKeys>,
|
||||
) {
|
||||
let key = (user_id, device_id);
|
||||
|
||||
self.db.keyid_key.put(key, Json(device_keys));
|
||||
|
@ -470,8 +512,12 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn add_cross_signing_keys(
|
||||
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
master_key: &Raw<CrossSigningKey>,
|
||||
self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
user_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
notify: bool,
|
||||
) -> Result<()> {
|
||||
// TODO: Check signatures
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
|
@ -495,9 +541,10 @@ impl Service {
|
|||
.keys
|
||||
.into_values();
|
||||
|
||||
let self_signing_key_id = self_signing_key_ids
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
|
||||
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Self signing key contained no key.",
|
||||
))?;
|
||||
|
||||
if self_signing_key_ids.next().is_some() {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -531,7 +578,9 @@ impl Service {
|
|||
.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
|
||||
|
||||
if user_signing_key_ids.next().is_some() {
|
||||
return Err!(Request(InvalidParam("User signing key contained more than one key.")));
|
||||
return Err!(Request(InvalidParam(
|
||||
"User signing key contained more than one key."
|
||||
)));
|
||||
}
|
||||
|
||||
let mut user_signing_key_key = prefix;
|
||||
|
@ -554,7 +603,11 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn sign_key(
|
||||
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
|
||||
&self,
|
||||
target_id: &UserId,
|
||||
key_id: &str,
|
||||
signature: (String, String),
|
||||
sender_id: &UserId,
|
||||
) -> Result<()> {
|
||||
let key = (target_id, key_id);
|
||||
|
||||
|
@ -590,7 +643,10 @@ impl Service {
|
|||
|
||||
#[inline]
|
||||
pub fn keys_changed<'a>(
|
||||
&'a self, user_id: &'a UserId, from: u64, to: Option<u64>,
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
self.keys_changed_user_or_room(user_id.as_str(), from, to)
|
||||
.map(|(user_id, ..)| user_id)
|
||||
|
@ -598,13 +654,19 @@ impl Service {
|
|||
|
||||
#[inline]
|
||||
pub fn room_keys_changed<'a>(
|
||||
&'a self, room_id: &'a RoomId, from: u64, to: Option<u64>,
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
|
||||
self.keys_changed_user_or_room(room_id.as_str(), from, to)
|
||||
}
|
||||
|
||||
fn keys_changed_user_or_room<'a>(
|
||||
&'a self, user_or_room_id: &'a str, from: u64, to: Option<u64>,
|
||||
&'a self,
|
||||
user_or_room_id: &'a str,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
|
||||
type KeyVal<'a> = ((&'a str, u64), &'a UserId);
|
||||
|
||||
|
@ -614,7 +676,9 @@ impl Service {
|
|||
.keychangeid_userid
|
||||
.stream_from(&start)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to)
|
||||
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
|
||||
*prefix == user_or_room_id && *count <= to
|
||||
})
|
||||
.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
|
||||
}
|
||||
|
||||
|
@ -636,13 +700,21 @@ impl Service {
|
|||
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||
}
|
||||
|
||||
pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
|
||||
pub async fn get_device_keys<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Raw<DeviceKeys>> {
|
||||
let key_id = (user_id, device_id);
|
||||
self.db.keyid_key.qry(&key_id).await.deserialized()
|
||||
}
|
||||
|
||||
pub async fn get_key<F>(
|
||||
&self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
|
||||
&self,
|
||||
key_id: &[u8],
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<Raw<CrossSigningKey>>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
|
@ -655,7 +727,10 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn get_master_key<F>(
|
||||
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<Raw<CrossSigningKey>>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
|
@ -667,7 +742,10 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn get_self_signing_key<F>(
|
||||
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<Raw<CrossSigningKey>>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
|
@ -688,7 +766,11 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn add_to_device_event(
|
||||
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
|
||||
&self,
|
||||
sender: &UserId,
|
||||
target_user_id: &UserId,
|
||||
target_device_id: &DeviceId,
|
||||
event_type: &str,
|
||||
content: serde_json::Value,
|
||||
) {
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
|
@ -705,7 +787,9 @@ impl Service {
|
|||
}
|
||||
|
||||
pub fn get_to_device_events<'a>(
|
||||
&'a self, user_id: &'a UserId, device_id: &'a DeviceId,
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
device_id: &'a DeviceId,
|
||||
) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
|
||||
let prefix = (user_id, device_id, Interfix);
|
||||
self.db
|
||||
|
@ -715,7 +799,12 @@ impl Service {
|
|||
.map(|(_, val): (Ignore, Raw<AnyToDeviceEvent>)| val)
|
||||
}
|
||||
|
||||
pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) {
|
||||
pub async fn remove_to_device_events(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
until: u64,
|
||||
) {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
|
@ -742,7 +831,12 @@ impl Service {
|
|||
.await;
|
||||
}
|
||||
|
||||
pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
|
||||
pub async fn update_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
let key = (user_id, device_id);
|
||||
|
@ -752,7 +846,11 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Get device metadata.
|
||||
pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Device> {
|
||||
pub async fn get_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Device> {
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
.qry(&(user_id, device_id))
|
||||
|
@ -768,7 +866,10 @@ impl Service {
|
|||
.deserialized()
|
||||
}
|
||||
|
||||
pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = Device> + Send + 'a {
|
||||
pub fn all_devices_metadata<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = Device> + Send + 'a {
|
||||
let key = (user_id, Interfix);
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
|
@ -787,7 +888,11 @@ impl Service {
|
|||
filter_id
|
||||
}
|
||||
|
||||
pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
|
||||
pub async fn get_filter(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
filter_id: &str,
|
||||
) -> Result<FilterDefinition> {
|
||||
let key = (user_id, filter_id);
|
||||
self.db.userfilterid_filter.qry(&key).await.deserialized()
|
||||
}
|
||||
|
@ -817,11 +922,10 @@ impl Service {
|
|||
};
|
||||
|
||||
let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
|
||||
let expires_at = u64::from_be_bytes(
|
||||
expires_at_bytes
|
||||
.try_into()
|
||||
.map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?,
|
||||
);
|
||||
let expires_at =
|
||||
u64::from_be_bytes(expires_at_bytes.try_into().map_err(|e| {
|
||||
err!(Database("expires_at in openid_userid is invalid u64. {e}"))
|
||||
})?);
|
||||
|
||||
if expires_at < utils::millis_since_unix_epoch() {
|
||||
debug_warn!("OpenID token is expired, removing");
|
||||
|
@ -833,11 +937,16 @@ impl Service {
|
|||
let user_string = utils::string_from_bytes(user_bytes)
|
||||
.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
|
||||
|
||||
UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
|
||||
UserId::parse(user_string)
|
||||
.map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
|
||||
}
|
||||
|
||||
/// Gets a specific user profile key
|
||||
pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<serde_json::Value> {
|
||||
pub async fn profile_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
profile_key: &str,
|
||||
) -> Result<serde_json::Value> {
|
||||
let key = (user_id, profile_key);
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
|
@ -848,7 +957,8 @@ impl Service {
|
|||
|
||||
/// Gets all the user's profile keys and values in an iterator
|
||||
pub fn all_profile_keys<'a>(
|
||||
&'a self, user_id: &'a UserId,
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
|
||||
type KeyVal = ((Ignore, String), serde_json::Value);
|
||||
|
||||
|
@ -861,7 +971,12 @@ impl Service {
|
|||
}
|
||||
|
||||
/// Sets a new profile key value, removes the key if value is None
|
||||
pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
|
||||
pub fn set_profile_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
profile_key: &str,
|
||||
profile_key_value: Option<serde_json::Value>,
|
||||
) {
|
||||
// TODO: insert to the stable MSC4175 key when it's stable
|
||||
let key = (user_id, profile_key);
|
||||
|
||||
|
@ -901,7 +1016,10 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) -> Result<(Vec<u8>, CrossSigningKey)> {
|
||||
pub fn parse_master_key(
|
||||
user_id: &UserId,
|
||||
master_key: &Raw<CrossSigningKey>,
|
||||
) -> Result<(Vec<u8>, CrossSigningKey)> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
|
@ -925,7 +1043,10 @@ pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) ->
|
|||
|
||||
/// Ensure that a user only sees signatures from themselves and the target user
|
||||
fn clean_signatures<F>(
|
||||
mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
|
||||
mut cross_signing_key: serde_json::Value,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<serde_json::Value>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
|
@ -937,9 +1058,11 @@ where
|
|||
// Don't allocate for the full size of the current signatures, but require
|
||||
// at most one resize if nothing is dropped
|
||||
let new_capacity = signatures.len() / 2;
|
||||
for (user, signature) in mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) {
|
||||
let sid =
|
||||
<&UserId>::try_from(user.as_str()).map_err(|_| Error::bad_database("Invalid user ID in database."))?;
|
||||
for (user, signature) in
|
||||
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
|
||||
{
|
||||
let sid = <&UserId>::try_from(user.as_str())
|
||||
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
|
||||
if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
|
||||
signatures.insert(user, signature);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue