add rustfmt.toml, format entire codebase

Signed-off-by: strawberry <strawberry@puppygock.gay>
strawberry 2024-03-05 19:48:54 -05:00 committed by June
parent 9fd521f041
commit f419c64aca
144 changed files with 25573 additions and 31053 deletions

rustfmt.toml (new file, 27 lines)
View file

@@ -0,0 +1,27 @@
edition = "2021"
condense_wildcard_suffixes = true
format_code_in_doc_comments = true
format_macro_bodies = true
format_macro_matchers = true
format_strings = true
hex_literal_case = "Upper"
max_width = 120
tab_spaces = 4
array_width = 80
comment_width = 80
wrap_comments = true
fn_params_layout = "Compressed"
fn_call_width = 80
fn_single_line = true
hard_tabs = true
match_block_trailing_comma = true
imports_granularity = "Crate"
normalize_comments = false
reorder_impl_items = true
reorder_imports = true
group_imports = "StdExternalCrate"
newline_style = "Unix"
use_field_init_shorthand = true
use_small_heuristics = "Off"
use_try_shorthand = true
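
To make the settings above concrete, here is a minimal, hypothetical sketch (not taken from this commit) of how code looks once formatted under them: hard tabs rendered at four columns, a 120-column line limit, compressed function parameters, and field-init shorthand.

use std::time::Duration;

struct RetryPolicy {
	base: Duration,
	cap: Duration,
}

// fn_params_layout = "Compressed" keeps a short signature on one line, up to the
// max_width = 120 limit; hard_tabs = true indents with tabs (tab_spaces = 4).
fn backoff(policy: &RetryPolicy, attempt: u32) -> Duration {
	let delay = policy.base.checked_mul(2_u32.saturating_pow(attempt)).unwrap_or(policy.cap);
	if delay > policy.cap {
		policy.cap
	} else {
		delay
	}
}

fn main() {
	let cap = Duration::from_secs(60);
	// use_field_init_shorthand = true allows `cap` instead of `cap: cap`.
	let policy = RetryPolicy {
		base: Duration::from_millis(250),
		cap,
	};
	println!("third retry waits {:?}", backoff(&policy, 3));
}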

View file

@@ -1,114 +1,94 @@
use std::{fmt::Debug, mem, time::Duration};

use bytes::BytesMut;
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use tracing::warn;

use crate::{services, utils, Error, Result};

/// Sends a request to an appservice
///
/// Only returns None if there is no url specified in the appservice
/// registration file
pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Option<Result<T::IncomingResponse>>
where
	T: OutgoingRequest + Debug,
{
	if let Some(destination) = registration.url {
		let hs_token = registration.hs_token.as_str();

		let mut http_request = request
			.try_into_http_request::<BytesMut>(
				&destination,
				SendAccessToken::IfRequired(hs_token),
				&[MatrixVersion::V1_0],
			)
			.map_err(|e| {
				warn!("Failed to find destination {}: {}", destination, e);
				Error::BadServerResponse("Invalid destination")
			})
			.unwrap()
			.map(bytes::BytesMut::freeze);

		let mut parts = http_request.uri().clone().into_parts();
		let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
		let symbol = if old_path_and_query.contains('?') {
			"&"
		} else {
			"?"
		};

		parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
		*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");

		let mut reqwest_request =
			reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");

		*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));

		let url = reqwest_request.url().clone();

		let mut response = match services().globals.default_client().execute(reqwest_request).await {
			Ok(r) => r,
			Err(e) => {
				warn!(
					"Could not send request to appservice {} at {}: {}",
					registration.id, destination, e
				);
				return Some(Err(e.into()));
			},
		};

		// reqwest::Response -> http::Response conversion
		let status = response.status();
		let mut http_response_builder = http::Response::builder().status(status).version(response.version());
		mem::swap(
			response.headers_mut(),
			http_response_builder.headers_mut().expect("http::response::Builder is usable"),
		);

		let body = response.bytes().await.unwrap_or_else(|e| {
			warn!("server error: {}", e);
			Vec::new().into()
		}); // TODO: handle timeout

		if !status.is_success() {
			warn!(
				"Appservice returned bad response {} {}\n{}\n{:?}",
				destination,
				status,
				url,
				utils::string_from_bytes(&body)
			);
		}

		let response = T::IncomingResponse::try_from_http_response(
			http_response_builder.body(body).expect("reqwest body is valid http body"),
		);

		Some(response.map_err(|_| {
			warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
			Error::BadServerResponse("Server returned bad response.")
		}))
	} else {
		None
	}
}

View file

@@ -1,21 +1,21 @@
use register::RegistrationKind;
use ruma::{
	api::client::{
		account::{
			change_password, deactivate, get_3pids, get_username_availability, register,
			request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, whoami,
			ThirdPartyIdRemovalStatus,
		},
		error::ErrorKind,
		uiaa::{AuthFlow, AuthType, UiaaInfo},
	},
	events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
	push, UserId,
};
use tracing::{info, warn};

use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{api::client_server, services, utils, Error, Result, Ruma};

const RANDOM_USER_ID_LENGTH: usize = 10;
@@ -28,303 +28,266 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
/// - The server name of the user id matches this server
/// - No user or appservice on this server already claimed this username
///
/// Note: This will not reserve the username, so the username might become
/// invalid when trying to register
pub async fn get_register_available_route(
	body: Ruma<get_username_availability::v3::Request>,
) -> Result<get_username_availability::v3::Response> {
	// Validate user id
	let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name())
		.ok()
		.filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name())
		.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;

	// Check if username is creative enough
	if services().users.exists(&user_id)? {
		return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
	}

	if services().globals.forbidden_usernames().is_match(user_id.localpart()) {
		return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
	}

	// TODO add check for appservice namespaces

	// If no if check is true we have an username that's available to be used.
	Ok(get_username_availability::v3::Response {
		available: true,
	})
}

/// # `POST /_matrix/client/v3/register`
///
/// Register an account on this homeserver.
///
/// You can use [`GET
/// /_matrix/client/v3/register/available`](fn.get_register_available_route.
/// html) to check if the user id is valid and available.
///
/// - Only works if registration is enabled
/// - If type is guest: ignores all parameters except
///   initial_device_display_name
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
/// - If type is not guest and no username is given: Always fails after UIAA
///   check
/// - Creates a new account and populates it with default account data
/// - If `inhibit_login` is false: Creates a device and returns device id and
///   access_token
pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> {
	if !services().globals.allow_registration() && !body.from_appservice {
		info!(
			"Registration disabled and request not from known appservice, rejecting registration attempt for username \
			 {:?}",
			body.username
		);
		return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration has been disabled."));
	}

	let is_guest = body.kind == RegistrationKind::Guest;

	if is_guest
		&& (!services().globals.allow_guest_registration()
			|| (services().globals.allow_registration() && services().globals.config.registration_token.is_some()))
	{
		info!(
			"Guest registration disabled / registration enabled with token configured, rejecting guest registration, \
			 initial device name: {:?}",
			body.initial_device_display_name
		);
		return Err(Error::BadRequest(
			ErrorKind::GuestAccessForbidden,
			"Guest registration is disabled.",
		));
	}

	// forbid guests from registering if there is not a real admin user yet. give
	// generic user error.
	if is_guest && services().users.count()? < 2 {
		warn!(
			"Guest account attempted to register before a real admin user has been registered, rejecting \
			 registration. Guest's initial device name: {:?}",
			body.initial_device_display_name
		);
		return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration temporarily disabled."));
	}

	let user_id = match (&body.username, is_guest) {
		(Some(username), false) => {
			let proposed_user_id =
				UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
					.ok()
					.filter(|user_id| {
						!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
					})
					.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;

			if services().users.exists(&proposed_user_id)? {
				return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
			}

			if services().globals.forbidden_usernames().is_match(proposed_user_id.localpart()) {
				return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
			}

			proposed_user_id
		},
		_ => loop {
			let proposed_user_id = UserId::parse_with_server_name(
				utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
				services().globals.server_name(),
			)
			.unwrap();
			if !services().users.exists(&proposed_user_id)? {
				break proposed_user_id;
			}
		},
	};

	// UIAA
	let mut uiaainfo;
	let skip_auth;
	if services().globals.config.registration_token.is_some() {
		// Registration token required
		uiaainfo = UiaaInfo {
			flows: vec![AuthFlow {
				stages: vec![AuthType::RegistrationToken],
			}],
			completed: Vec::new(),
			params: Box::default(),
			session: None,
			auth_error: None,
		};
		skip_auth = body.from_appservice;
	} else {
		// No registration token necessary, but clients must still go through the flow
		uiaainfo = UiaaInfo {
			flows: vec![AuthFlow {
				stages: vec![AuthType::Dummy],
			}],
			completed: Vec::new(),
			params: Box::default(),
			session: None,
			auth_error: None,
		};
		skip_auth = body.from_appservice || is_guest;
	}

	if !skip_auth {
		if let Some(auth) = &body.auth {
			let (worked, uiaainfo) = services().uiaa.try_auth(
				&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
				"".into(),
				auth,
				&uiaainfo,
			)?;
			if !worked {
				return Err(Error::Uiaa(uiaainfo));
			}
		// Success!
		} else if let Some(json) = body.json_body {
			uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
			services().uiaa.create(
				&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
				"".into(),
				&uiaainfo,
				&json,
			)?;
			return Err(Error::Uiaa(uiaainfo));
		} else {
			return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
		}
	}

	let password = if is_guest {
		None
	} else {
		body.password.as_deref()
	};

	// Create user
	services().users.create(&user_id, password)?;

	// Default to pretty displayname
	let mut displayname = user_id.localpart().to_owned();

	// If `new_user_displayname_suffix` is set, registration will push whatever
	// content is set to the user's display name with a space before it
	if !services().globals.new_user_displayname_suffix().is_empty() {
		displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
	}

	services().users.set_displayname(&user_id, Some(displayname.clone())).await?;

	// Initial account data
	services().account_data.update(
		None,
		&user_id,
		GlobalAccountDataEventType::PushRules.to_string().into(),
		&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
			content: ruma::events::push_rules::PushRulesEventContent {
				global: push::Ruleset::server_default(&user_id),
			},
		})
		.expect("to json always works"),
	)?;

	// Inhibit login does not work for guests
	if !is_guest && body.inhibit_login {
		return Ok(register::v3::Response {
			access_token: None,
			user_id,
			device_id: None,
			refresh_token: None,
			expires_in: None,
		});
	}

	// Generate new device id if the user didn't specify one
	let device_id = if is_guest {
		None
	} else {
		body.device_id.clone()
	}
	.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());

	// Generate new token for the device
	let token = utils::random_string(TOKEN_LENGTH);

	// Create device for this account
	services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;

	info!("New user \"{}\" registered on this server.", user_id);

	// log in conduit admin channel if a non-guest user registered
	if !body.from_appservice && !is_guest {
		services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
			"New user \"{user_id}\" registered on this server."
		)));
	}

	// log in conduit admin channel if a guest registered
	if !body.from_appservice && is_guest {
		services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
			"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
			body.initial_device_display_name
		)));
	}

	// If this is the first real user, grant them admin privileges except for guest
	// users Note: the server user, @conduit:servername, is generated first
	if services().users.count()? == 2 && !is_guest {
		services().admin.make_user_admin(&user_id, displayname).await?;

		warn!("Granting {} admin privileges as the first user", user_id);
	}

	Ok(register::v3::Response {
		access_token: Some(token),
		user_id,
		device_id: Some(device_id),
		refresh_token: None,
		expires_in: None,
	})
}

/// # `POST /_matrix/client/r0/account/password`
@@ -333,73 +296,65 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> {
///
/// - Requires UIAA to verify user password
/// - Changes the password of the sender user
/// - The password hash is calculated using argon2 with 32 character salt, the
///   plain password is
///   not saved
///
/// If logout_devices is true it does the following for each device except the
/// sender device:
/// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip,
///   last seen ts)
/// - Forgets to-device events
/// - Triggers device list updates
pub async fn change_password_route(body: Ruma<change_password::v3::Request>) -> Result<change_password::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
	let sender_device = body.sender_device.as_ref().expect("user is authenticated");

	let mut uiaainfo = UiaaInfo {
		flows: vec![AuthFlow {
			stages: vec![AuthType::Password],
		}],
		completed: Vec::new(),
		params: Box::default(),
		session: None,
		auth_error: None,
	};

	if let Some(auth) = &body.auth {
		let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
		if !worked {
			return Err(Error::Uiaa(uiaainfo));
		}
	// Success!
	} else if let Some(json) = body.json_body {
		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
		services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
		return Err(Error::Uiaa(uiaainfo));
	} else {
		return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
	}

	services().users.set_password(sender_user, Some(&body.new_password))?;

	if body.logout_devices {
		// Logout all devices except the current one
		for id in services()
			.users
			.all_device_ids(sender_user)
			.filter_map(std::result::Result::ok)
			.filter(|id| id != sender_device)
		{
			services().users.remove_device(sender_user, &id)?;
		}
	}

	info!("User {} changed their password.", sender_user);
	services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
		"User {sender_user} changed their password."
	)));

	Ok(change_password::v3::Response {})
}

/// # `GET _matrix/client/r0/account/whoami`
@@ -408,14 +363,14 @@ pub async fn change_password_route(
///
/// Note: Also works for Application Services
pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
	let device_id = body.sender_device.clone();

	Ok(whoami::v3::Response {
		user_id: sender_user.clone(),
		device_id,
		is_guest: services().users.is_deactivated(sender_user)? && !body.from_appservice,
	})
}

/// # `POST /_matrix/client/r0/account/deactivate`
@@ -424,61 +379,53 @@ pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> {
///
/// - Leaves all rooms and rejects all invitations
/// - Invalidates all access tokens
/// - Deletes all device metadata (device id, device display name, last seen ip,
///   last seen ts)
/// - Forgets all to-device events
/// - Triggers device list updates
/// - Removes ability to log in again
pub async fn deactivate_route(body: Ruma<deactivate::v3::Request>) -> Result<deactivate::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
	let sender_device = body.sender_device.as_ref().expect("user is authenticated");

	let mut uiaainfo = UiaaInfo {
		flows: vec![AuthFlow {
			stages: vec![AuthType::Password],
		}],
		completed: Vec::new(),
		params: Box::default(),
		session: None,
		auth_error: None,
	};

	if let Some(auth) = &body.auth {
		let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
		if !worked {
			return Err(Error::Uiaa(uiaainfo));
		}
	// Success!
	} else if let Some(json) = body.json_body {
		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
		services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
		return Err(Error::Uiaa(uiaainfo));
	} else {
		return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
	}

	// Make the user leave all rooms before deactivation
	client_server::leave_all_rooms(sender_user).await?;

	// Remove devices and mark account as deactivated
	services().users.deactivate_account(sender_user)?;

	info!("User {} deactivated their account.", sender_user);
	services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
		"User {sender_user} deactivated their account."
	)));

	Ok(deactivate::v3::Response {
		id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
	})
}

/// # `GET _matrix/client/v3/account/3pid`
@@ -486,38 +433,40 @@ pub async fn deactivate_route(
/// Get a list of third party identifiers associated with this account.
///
/// - Currently always returns empty list
pub async fn third_party_route(body: Ruma<get_3pids::v3::Request>) -> Result<get_3pids::v3::Response> {
	let _sender_user = body.sender_user.as_ref().expect("user is authenticated");

	Ok(get_3pids::v3::Response::new(Vec::new()))
}

/// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
///
/// "This API should be used to request validation tokens when adding an email
/// address to an account"
///
/// - 403 signals that The homeserver does not allow the third party identifier
///   as a contact option.
pub async fn request_3pid_management_token_via_email_route(
	_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
) -> Result<request_3pid_management_token_via_email::v3::Response> {
	Err(Error::BadRequest(
		ErrorKind::ThreepidDenied,
		"Third party identifier is not allowed",
	))
}

/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
///
/// "This API should be used to request validation tokens when adding an phone
/// number to an account"
///
/// - 403 signals that The homeserver does not allow the third party identifier
///   as a contact option.
pub async fn request_3pid_management_token_via_msisdn_route(
	_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {
	Err(Error::BadRequest(
		ErrorKind::ThreepidDenied,
		"Third party identifier is not allowed",
	))
}

View file

@@ -1,64 +1,43 @@
use rand::seq::SliceRandom;
use regex::Regex;
use ruma::{
	api::{
		appservice,
		client::{
			alias::{create_alias, delete_alias, get_alias},
			error::ErrorKind,
		},
		federation,
	},
	OwnedRoomAliasId, OwnedServerName,
};

use crate::{services, Error, Result, Ruma};

/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
///
/// Creates a new room alias on this server.
pub async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> {
	if body.room_alias.server_name() != services().globals.server_name() {
		return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
	}

	if services().globals.forbidden_room_names().is_match(body.room_alias.alias()) {
		return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden."));
	}

	if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
		return Err(Error::Conflict("Alias already exists."));
	}

	if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() {
		return Err(Error::BadRequest(
			ErrorKind::InvalidParam,
			"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
		));
	};

	Ok(create_alias::v3::Response::new())
}

/// # `DELETE /_matrix/client/v3/directory/room/{roomAlias}`
@@ -67,183 +46,137 @@ pub async fn create_alias_route(
///
/// - TODO: additional access control checks
/// - TODO: Update canonical alias event
pub async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> {
	if body.room_alias.server_name() != services().globals.server_name() {
		return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
	}

	if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() {
		return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
	}

	if services().rooms.alias.remove_alias(&body.room_alias).is_err() {
		return Err(Error::BadRequest(
			ErrorKind::InvalidParam,
			"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
		));
	};

	// TODO: update alt_aliases?

	Ok(delete_alias::v3::Response::new())
}

/// # `GET /_matrix/client/v3/directory/room/{roomAlias}`
///
/// Resolve an alias locally or over federation.
pub async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Result<get_alias::v3::Response> {
	get_alias_helper(body.body.room_alias).await
}

pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get_alias::v3::Response> {
	if room_alias.server_name() != services().globals.server_name() {
		let response = services()
			.sending
			.send_federation_request(
				room_alias.server_name(),
				federation::query::get_room_information::v1::Request {
					room_alias: room_alias.clone(),
				},
			)
			.await?;

		let room_id = response.room_id;
		let mut servers = response.servers;

		// find active servers in room state cache to suggest
		for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
			servers.push(extra_servers);
		}

		// insert our server as the very first choice if in list
		if let Some(server_index) =
			servers.clone().into_iter().position(|server| server == services().globals.server_name())
		{
			servers.remove(server_index);
			servers.insert(0, services().globals.server_name().to_owned());
		}

		servers.sort_unstable();
		servers.dedup();

		// shuffle list of servers randomly after sort and dedupe
		servers.shuffle(&mut rand::thread_rng());

		return Ok(get_alias::v3::Response::new(room_id, servers));
	}

	let mut room_id = None;
	match services().rooms.alias.resolve_local_alias(&room_alias)? {
		Some(r) => room_id = Some(r),
		None => {
			for (_id, registration) in services().appservice.all()? {
				let aliases = registration
					.namespaces
					.aliases
					.iter()
					.filter_map(|alias| Regex::new(alias.regex.as_str()).ok())
					.collect::<Vec<_>>();

				if aliases.iter().any(|aliases| aliases.is_match(room_alias.as_str()))
					&& if let Some(opt_result) = services()
						.sending
						.send_appservice_request(
							registration,
							appservice::query::query_room_alias::v1::Request {
								room_alias: room_alias.clone(),
							},
						)
						.await
					{
						opt_result.is_ok()
					} else {
						false
					} {
					room_id = Some(
						services()
							.rooms
							.alias
							.resolve_local_alias(&room_alias)?
							.ok_or_else(|| Error::bad_config("Appservice lied to us. Room does not exist."))?,
					);
					break;
				}
			}
		},
	};

	let room_id = match room_id {
		Some(room_id) => room_id,
		None => return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")),
	};

	let mut servers: Vec<OwnedServerName> = Vec::new();

	// find active servers in room state cache to suggest
	for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
		servers.push(extra_servers);
	}

	// insert our server as the very first choice if in list
	if let Some(server_index) =
		servers.clone().into_iter().position(|server| server == services().globals.server_name())
	{
		servers.remove(server_index);
		servers.insert(0, services().globals.server_name().to_owned());
	}

	servers.sort_unstable();
	servers.dedup();

	// shuffle list of servers randomly after sort and dedupe
	servers.shuffle(&mut rand::thread_rng());

	Ok(get_alias::v3::Response::new(room_id, servers))
}

View file

@@ -1,362 +1,275 @@
use ruma::api::client::{
	backup::{
		add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
		delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
		get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
		get_latest_backup_info, update_backup_version,
	},
	error::ErrorKind,
};

use crate::{services, Error, Result, Ruma};

/// # `POST /_matrix/client/r0/room_keys/version`
///
/// Creates a new backup.
pub async fn create_backup_version_route(
	body: Ruma<create_backup_version::v3::Request>,
) -> Result<create_backup_version::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
	let version = services().key_backups.create_backup(sender_user, &body.algorithm)?;

	Ok(create_backup_version::v3::Response {
		version,
	})
}

/// # `PUT /_matrix/client/r0/room_keys/version/{version}`
///
/// Update information about an existing backup. Only `auth_data` can be
/// modified.
pub async fn update_backup_version_route(
	body: Ruma<update_backup_version::v3::Request>,
) -> Result<update_backup_version::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
	services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?;

	Ok(update_backup_version::v3::Response {})
}

/// # `GET /_matrix/client/r0/room_keys/version`
///
/// Get information about the latest backup version.
pub async fn get_latest_backup_info_route(
	body: Ruma<get_latest_backup_info::v3::Request>,
) -> Result<get_latest_backup_info::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

	let (version, algorithm) = services()
		.key_backups
		.get_latest_backup(sender_user)?
		.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;

	Ok(get_latest_backup_info::v3::Response {
		algorithm,
		count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(),
		etag: services().key_backups.get_etag(sender_user, &version)?,
		version,
	})
}

/// # `GET /_matrix/client/r0/room_keys/version`
///
/// Get information about an existing backup.
pub async fn get_backup_info_route(body: Ruma<get_backup_info::v3::Request>) -> Result<get_backup_info::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
	let algorithm = services()
		.key_backups
		.get_backup(sender_user, &body.version)?
		.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;

	Ok(get_backup_info::v3::Response {
		algorithm,
		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
		etag: services().key_backups.get_etag(sender_user, &body.version)?,
		version: body.version.clone(),
	})
}

/// # `DELETE /_matrix/client/r0/room_keys/version/{version}`
///
/// Delete an existing key backup.
///
/// - Deletes both information about the backup, as well as all key data related
///   to the backup
pub async fn delete_backup_version_route(
	body: Ruma<delete_backup_version::v3::Request>,
) -> Result<delete_backup_version::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

	services().key_backups.delete_backup(sender_user, &body.version)?;

	Ok(delete_backup_version::v3::Response {})
}

/// # `PUT /_matrix/client/r0/room_keys/keys`
///
/// Add the received backup keys to the database.
///
/// - Only manipulating the most recently created version of the backup is
///   allowed
/// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag
pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) -> Result<add_backup_keys::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

	if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
		return Err(Error::BadRequest(
			ErrorKind::InvalidParam,
			"You may only manipulate the most recently created version of the backup.",
		));
	}

	for (room_id, room) in &body.rooms {
		for (session_id, key_data) in &room.sessions {
			services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
		}
	}

	Ok(add_backup_keys::v3::Response {
		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
		etag: services().key_backups.get_etag(sender_user, &body.version)?,
	})
}

/// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}`
///
/// Add the received backup keys to the database.
///
/// - Only manipulating the most recently created version of the backup is
///   allowed
/// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag
pub async fn add_backup_keys_for_room_route(
	body: Ruma<add_backup_keys_for_room::v3::Request>,
) -> Result<add_backup_keys_for_room::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

	if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
		return Err(Error::BadRequest(
			ErrorKind::InvalidParam,
			"You may only manipulate the most recently created version of the backup.",
		));
	}

	for (session_id, key_data) in &body.sessions {
		services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
	}

	Ok(add_backup_keys_for_room::v3::Response {
		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
		etag: services().key_backups.get_etag(sender_user, &body.version)?,
	})
}

/// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
///
/// Add the received backup key to the database.
///
/// - Only manipulating the most recently created version of the backup is
///   allowed
/// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag
pub async fn add_backup_keys_for_session_route(
	body: Ruma<add_backup_keys_for_session::v3::Request>,
) -> Result<add_backup_keys_for_session::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

	if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
		return Err(Error::BadRequest(
			ErrorKind::InvalidParam,
			"You may only manipulate the most recently created version of the backup.",
		));
	}

	services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;

	Ok(add_backup_keys_for_session::v3::Response {
		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
		etag: services().key_backups.get_etag(sender_user, &body.version)?,
	})
}

/// # `GET /_matrix/client/r0/room_keys/keys`
///
/// Retrieves all keys from the backup.
pub async fn get_backup_keys_route(body: Ruma<get_backup_keys::v3::Request>) -> Result<get_backup_keys::v3::Response> {
	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

	let rooms = services().key_backups.get_all(sender_user, &body.version)?;

	Ok(get_backup_keys::v3::Response {
		rooms,
	})
}

/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
/// ///
/// Retrieves all keys from the backup for a given room. /// Retrieves all keys from the backup for a given room.
pub async fn get_backup_keys_for_room_route( pub async fn get_backup_keys_for_room_route(
body: Ruma<get_backup_keys_for_room::v3::Request>, body: Ruma<get_backup_keys_for_room::v3::Request>,
) -> Result<get_backup_keys_for_room::v3::Response> { ) -> Result<get_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = services() let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?;
.key_backups
.get_room(sender_user, &body.version, &body.room_id)?;
Ok(get_backup_keys_for_room::v3::Response { sessions }) Ok(get_backup_keys_for_room::v3::Response {
sessions,
})
} }
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
/// ///
/// Retrieves a key from the backup. /// Retrieves a key from the backup.
pub async fn get_backup_keys_for_session_route( pub async fn get_backup_keys_for_session_route(
body: Ruma<get_backup_keys_for_session::v3::Request>, body: Ruma<get_backup_keys_for_session::v3::Request>,
) -> Result<get_backup_keys_for_session::v3::Response> { ) -> Result<get_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let key_data = services() let key_data =
.key_backups services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or(
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)? Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."),
.ok_or(Error::BadRequest( )?;
ErrorKind::NotFound,
"Backup key not found for this user's session.",
))?;
Ok(get_backup_keys_for_session::v3::Response { key_data }) Ok(get_backup_keys_for_session::v3::Response {
key_data,
})
} }
/// # `DELETE /_matrix/client/r0/room_keys/keys` /// # `DELETE /_matrix/client/r0/room_keys/keys`
/// ///
/// Delete the keys from the backup. /// Delete the keys from the backup.
pub async fn delete_backup_keys_route( pub async fn delete_backup_keys_route(
body: Ruma<delete_backup_keys::v3::Request>, body: Ruma<delete_backup_keys::v3::Request>,
) -> Result<delete_backup_keys::v3::Response> { ) -> Result<delete_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().key_backups.delete_all_keys(sender_user, &body.version)?;
.key_backups
.delete_all_keys(sender_user, &body.version)?;
Ok(delete_backup_keys::v3::Response { Ok(delete_backup_keys::v3::Response {
count: (services() count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
.key_backups etag: services().key_backups.get_etag(sender_user, &body.version)?,
.count_keys(sender_user, &body.version)? as u32) })
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
})
} }
/// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}` /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}`
/// ///
/// Delete the keys from the backup for a given room. /// Delete the keys from the backup for a given room.
pub async fn delete_backup_keys_for_room_route( pub async fn delete_backup_keys_for_room_route(
body: Ruma<delete_backup_keys_for_room::v3::Request>, body: Ruma<delete_backup_keys_for_room::v3::Request>,
) -> Result<delete_backup_keys_for_room::v3::Response> { ) -> Result<delete_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?;
.key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
Ok(delete_backup_keys_for_room::v3::Response { Ok(delete_backup_keys_for_room::v3::Response {
count: (services() count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
.key_backups etag: services().key_backups.get_etag(sender_user, &body.version)?,
.count_keys(sender_user, &body.version)? as u32) })
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
})
} }
/// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
/// ///
/// Delete a key from the backup. /// Delete a key from the backup.
pub async fn delete_backup_keys_for_session_route( pub async fn delete_backup_keys_for_session_route(
body: Ruma<delete_backup_keys_for_session::v3::Request>, body: Ruma<delete_backup_keys_for_session::v3::Request>,
) -> Result<delete_backup_keys_for_session::v3::Response> { ) -> Result<delete_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().key_backups.delete_room_key( services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
sender_user,
&body.version,
&body.room_id,
&body.session_id,
)?;
Ok(delete_backup_keys_for_session::v3::Response { Ok(delete_backup_keys_for_session::v3::Response {
count: (services() count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
.key_backups etag: services().key_backups.get_etag(sender_user, &body.version)?,
.count_keys(sender_user, &body.version)? as u32) })
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
})
} }


@ -1,28 +1,33 @@
use crate::{services, Result, Ruma};
use ruma::api::client::discovery::get_capabilities::{
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use ruma::api::client::discovery::get_capabilities::{
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
};
use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/r0/capabilities` /// # `GET /_matrix/client/r0/capabilities`
/// ///
/// Get information on the supported feature set and other relevant capabilities of this server. /// Get information on the supported feature set and other relevant capabilities
/// of this server.
pub async fn get_capabilities_route( pub async fn get_capabilities_route(
_body: Ruma<get_capabilities::v3::Request>, _body: Ruma<get_capabilities::v3::Request>,
) -> Result<get_capabilities::v3::Response> { ) -> Result<get_capabilities::v3::Response> {
let mut available = BTreeMap::new(); let mut available = BTreeMap::new();
for room_version in &services().globals.unstable_room_versions { for room_version in &services().globals.unstable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Unstable); available.insert(room_version.clone(), RoomVersionStability::Unstable);
} }
for room_version in &services().globals.stable_room_versions { for room_version in &services().globals.stable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Stable); available.insert(room_version.clone(), RoomVersionStability::Stable);
} }
let mut capabilities = Capabilities::new(); let mut capabilities = Capabilities::new();
capabilities.room_versions = RoomVersionsCapability { capabilities.room_versions = RoomVersionsCapability {
default: services().globals.default_room_version(), default: services().globals.default_room_version(),
available, available,
}; };
Ok(get_capabilities::v3::Response { capabilities }) Ok(get_capabilities::v3::Response {
capabilities,
})
} }
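
A rough sketch of the response shape this route produces, assuming the `m.room_versions` capability format; the concrete default and available versions depend on server configuration, so the values below are placeholders.

fn example_capabilities_body() -> serde_json::Value {
	// Hypothetical response body: the route above fills `available` from the
	// configured stable/unstable room versions and `default` from the
	// configured default room version.
	serde_json::json!({
		"capabilities": {
			"m.room_versions": {
				"default": "10",
				"available": { "9": "stable", "10": "stable" }
			}
		}
	})
}
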


@ -1,116 +1,118 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
config::{ config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
get_global_account_data, get_room_account_data, set_global_account_data, error::ErrorKind,
set_room_account_data, },
}, events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
error::ErrorKind, serde::Raw,
},
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
serde::Raw,
}; };
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, value::RawValue as RawJsonValue}; use serde_json::{json, value::RawValue as RawJsonValue};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}` /// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
/// ///
/// Sets some account data for the sender user. /// Sets some account data for the sender user.
pub async fn set_global_account_data_route( pub async fn set_global_account_data_route(
body: Ruma<set_global_account_data::v3::Request>, body: Ruma<set_global_account_data::v3::Request>,
) -> Result<set_global_account_data::v3::Response> { ) -> Result<set_global_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = serde_json::from_str(body.data.json().get()) let data: serde_json::Value = serde_json::from_str(body.data.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
services().account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
event_type.clone().into(), event_type.clone().into(),
&json!({ &json!({
"type": event_type, "type": event_type,
"content": data, "content": data,
}), }),
)?; )?;
Ok(set_global_account_data::v3::Response {}) Ok(set_global_account_data::v3::Response {})
} }
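
A small sketch of the envelope the route above persists: the client sends only the content object, and the server wraps it with its event type before storing it. The event type and content here are placeholders.

fn example_stored_account_data() -> serde_json::Value {
	// Mirrors the json!({ "type": ..., "content": ... }) call in the route:
	// a hypothetical custom account-data event as it would be stored.
	let event_type = "com.example.client_settings"; // placeholder event type
	let data = serde_json::json!({ "show_previews": true }); // placeholder content
	serde_json::json!({ "type": event_type, "content": data })
}
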
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}` /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
/// ///
/// Sets some room account data for the sender user. /// Sets some room account data for the sender user.
pub async fn set_room_account_data_route( pub async fn set_room_account_data_route(
body: Ruma<set_room_account_data::v3::Request>, body: Ruma<set_room_account_data::v3::Request>,
) -> Result<set_room_account_data::v3::Response> { ) -> Result<set_room_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = serde_json::from_str(body.data.json().get()) let data: serde_json::Value = serde_json::from_str(body.data.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
services().account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
event_type.clone().into(), event_type.clone().into(),
&json!({ &json!({
"type": event_type, "type": event_type,
"content": data, "content": data,
}), }),
)?; )?;
Ok(set_room_account_data::v3::Response {}) Ok(set_room_account_data::v3::Response {})
} }
/// # `GET /_matrix/client/r0/user/{userId}/account_data/{type}` /// # `GET /_matrix/client/r0/user/{userId}/account_data/{type}`
/// ///
/// Gets some account data for the sender user. /// Gets some account data for the sender user.
pub async fn get_global_account_data_route( pub async fn get_global_account_data_route(
body: Ruma<get_global_account_data::v3::Request>, body: Ruma<get_global_account_data::v3::Request>,
) -> Result<get_global_account_data::v3::Response> { ) -> Result<get_global_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: Box<RawJsonValue> = services() let event: Box<RawJsonValue> = services()
.account_data .account_data
.get(None, sender_user, body.event_type.to_string().into())? .get(None, sender_user, body.event_type.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get()) let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content; .content;
Ok(get_global_account_data::v3::Response { account_data }) Ok(get_global_account_data::v3::Response {
account_data,
})
} }
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}` /// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
/// ///
/// Gets some room account data for the sender user. /// Gets some room account data for the sender user.
pub async fn get_room_account_data_route( pub async fn get_room_account_data_route(
body: Ruma<get_room_account_data::v3::Request>, body: Ruma<get_room_account_data::v3::Request>,
) -> Result<get_room_account_data::v3::Response> { ) -> Result<get_room_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: Box<RawJsonValue> = services() let event: Box<RawJsonValue> = services()
.account_data .account_data
.get(Some(&body.room_id), sender_user, body.event_type.clone())? .get(Some(&body.room_id), sender_user, body.event_type.clone())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get()) let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content; .content;
Ok(get_room_account_data::v3::Response { account_data }) Ok(get_room_account_data::v3::Response {
account_data,
})
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct ExtractRoomEventContent { struct ExtractRoomEventContent {
content: Raw<AnyRoomAccountDataEventContent>, content: Raw<AnyRoomAccountDataEventContent>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct ExtractGlobalEventContent { struct ExtractGlobalEventContent {
content: Raw<AnyGlobalAccountDataEventContent>, content: Raw<AnyGlobalAccountDataEventContent>,
} }
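
The two Extract*EventContent structs above exist only to pull the `content` field back out of that stored envelope; here is a generic sketch of the same idea, using a plain serde_json value instead of the typed Raw wrappers.

fn example_extract_content(stored: &str) -> Option<serde_json::Value> {
	// Deserialize the stored { "type": ..., "content": ... } envelope and
	// keep only the content, as the GET routes above do with typed structs.
	#[derive(serde::Deserialize)]
	struct Envelope {
		content: serde_json::Value,
	}
	serde_json::from_str::<Envelope>(stored).ok().map(|e| e.content)
}
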


@ -1,209 +1,177 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
events::StateEventType,
};
use std::collections::HashSet; use std::collections::HashSet;
use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
events::StateEventType,
};
use tracing::error; use tracing::error;
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// # `GET /_matrix/client/r0/rooms/{roomId}/context`
/// ///
/// Allows loading room history around an event. /// Allows loading room history around an event.
/// ///
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was /// - Only works if the user is joined (TODO: always allow, but only show events
/// if the user was
/// joined, depending on history_visibility) /// joined, depending on history_visibility)
pub async fn get_context_route( pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<get_context::v3::Response> {
body: Ruma<get_context::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_context::v3::Response> { let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options {
LazyLoadOptions::Enabled { LazyLoadOptions::Enabled {
include_redundant_members, include_redundant_members,
} => (true, *include_redundant_members), } => (true, *include_redundant_members),
LazyLoadOptions::Disabled => (false, false), LazyLoadOptions::Disabled => (false, false),
}; };
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let base_token = services() let base_token = services()
.rooms .rooms
.timeline .timeline
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?;
ErrorKind::NotFound,
"Base event id not found.",
))?;
let base_event = let base_event = services()
services() .rooms
.rooms .timeline
.timeline .get_pdu(&body.event_id)?
.get_pdu(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?;
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Base event not found.",
))?;
let room_id = base_event.room_id.clone(); let room_id = base_event.room_id.clone();
if !services() if !services().rooms.state_accessor.user_can_see_event(sender_user, &room_id, &body.event_id)? {
.rooms return Err(Error::BadRequest(
.state_accessor ErrorKind::Forbidden,
.user_can_see_event(sender_user, &room_id, &body.event_id)? "You don't have permission to view this event.",
{ ));
return Err(Error::BadRequest( }
ErrorKind::Forbidden,
"You don't have permission to view this event.",
));
}
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services().rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
&base_event.sender, &base_event.sender,
)? || lazy_load_send_redundant )? || lazy_load_send_redundant
{ {
lazy_loaded.insert(base_event.sender.as_str().to_owned()); lazy_loaded.insert(base_event.sender.as_str().to_owned());
} }
// Use limit with maximum 100 // Use limit with maximum 100
let limit = u64::from(body.limit).min(100) as usize; let limit = u64::from(body.limit).min(100) as usize;
let base_event = base_event.to_room_event(); let base_event = base_event.to_room_event();
let events_before: Vec<_> = services() let events_before: Vec<_> = services()
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, &room_id, base_token)? .pdus_until(sender_user, &room_id, base_token)?
.take(limit / 2) .take(limit / 2)
.filter_map(std::result::Result::ok) // Remove buggy events .filter_map(std::result::Result::ok) // Remove buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id) .user_can_see_event(sender_user, &room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.collect(); .collect();
for (_, event) in &events_before { for (_, event) in &events_before {
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services().rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
&event.sender, &event.sender,
)? || lazy_load_send_redundant )? || lazy_load_send_redundant
{ {
lazy_loaded.insert(event.sender.as_str().to_owned()); lazy_loaded.insert(event.sender.as_str().to_owned());
} }
} }
let start_token = events_before let start_token =
.last() events_before.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
.map(|(count, _)| count.stringify())
.unwrap_or_else(|| base_token.stringify());
let events_before: Vec<_> = events_before let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_after: Vec<_> = services() let events_after: Vec<_> = services()
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &room_id, base_token)? .pdus_after(sender_user, &room_id, base_token)?
.take(limit / 2) .take(limit / 2)
.filter_map(std::result::Result::ok) // Remove buggy events .filter_map(std::result::Result::ok) // Remove buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id) .user_can_see_event(sender_user, &room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.collect(); .collect();
for (_, event) in &events_after { for (_, event) in &events_after {
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services().rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
&event.sender, &event.sender,
)? || lazy_load_send_redundant )? || lazy_load_send_redundant
{ {
lazy_loaded.insert(event.sender.as_str().to_owned()); lazy_loaded.insert(event.sender.as_str().to_owned());
} }
} }
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash( let shortstatehash = match services()
events_after .rooms
.last() .state_accessor
.map_or(&*body.event_id, |(_, e)| &*e.event_id), .pdu_shortstatehash(events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id))?
)? { {
Some(s) => s, Some(s) => s,
None => services() None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"),
.rooms };
.state
.get_room_shortstatehash(&room_id)?
.expect("All rooms have state"),
};
let state_ids = services() let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?;
let end_token = events_after let end_token = events_after.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
.last()
.map(|(count, _)| count.stringify())
.unwrap_or_else(|| base_token.stringify());
let events_after: Vec<_> = events_after let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let mut state = Vec::new(); let mut state = Vec::new();
for (shortstatekey, id) in state_ids { for (shortstatekey, id) in state_ids {
let (event_type, state_key) = services() let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?;
.rooms
.short
.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let pdu = match services().rooms.timeline.get_pdu(&id)? { let pdu = match services().rooms.timeline.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
} },
}; };
state.push(pdu.to_state_event()); state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
let pdu = match services().rooms.timeline.get_pdu(&id)? { let pdu = match services().rooms.timeline.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
} },
}; };
state.push(pdu.to_state_event()); state.push(pdu.to_state_event());
} }
} }
let resp = get_context::v3::Response { let resp = get_context::v3::Response {
start: Some(start_token), start: Some(start_token),
end: Some(end_token), end: Some(end_token),
events_before, events_before,
event: Some(base_event), event: Some(base_event),
events_after, events_after,
state, state,
}; };
Ok(resp) Ok(resp)
} }
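
One detail worth spelling out: the context window above is budgeted by clamping the client-supplied limit to 100 and splitting it evenly between events before and after the base event. A minimal sketch of that arithmetic:

fn example_context_budget(requested_limit: u64) -> (usize, usize) {
	// At most 100 events total, half fetched before the base event and half after,
	// matching the `min(100)` and `take(limit / 2)` calls in the route above.
	let limit = requested_limit.min(100) as usize;
	(limit / 2, limit / 2)
}
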


@ -1,65 +1,61 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{ use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
error::ErrorKind, error::ErrorKind,
uiaa::{AuthFlow, AuthType, UiaaInfo}, uiaa::{AuthFlow, AuthType, UiaaInfo},
}; };
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/devices` /// # `GET /_matrix/client/r0/devices`
/// ///
/// Get metadata on all devices of the sender user. /// Get metadata on all devices of the sender user.
pub async fn get_devices_route( pub async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> Result<get_devices::v3::Response> {
body: Ruma<get_devices::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let devices: Vec<device::Device> = services() let devices: Vec<device::Device> = services()
.users .users
.all_devices_metadata(sender_user) .all_devices_metadata(sender_user)
.filter_map(std::result::Result::ok) // Filter out buggy devices .filter_map(std::result::Result::ok) // Filter out buggy devices
.collect(); .collect();
Ok(get_devices::v3::Response { devices }) Ok(get_devices::v3::Response {
devices,
})
} }
/// # `GET /_matrix/client/r0/devices/{deviceId}` /// # `GET /_matrix/client/r0/devices/{deviceId}`
/// ///
/// Get metadata on a single device of the sender user. /// Get metadata on a single device of the sender user.
pub async fn get_device_route( pub async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Result<get_device::v3::Response> {
body: Ruma<get_device::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let device = services() let device = services()
.users .users
.get_device_metadata(sender_user, &body.body.device_id)? .get_device_metadata(sender_user, &body.body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
Ok(get_device::v3::Response { device }) Ok(get_device::v3::Response {
device,
})
} }
/// # `PUT /_matrix/client/r0/devices/{deviceId}` /// # `PUT /_matrix/client/r0/devices/{deviceId}`
/// ///
/// Updates the metadata on a given device of the sender user. /// Updates the metadata on a given device of the sender user.
pub async fn update_device_route( pub async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Result<update_device::v3::Response> {
body: Ruma<update_device::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<update_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut device = services() let mut device = services()
.users .users
.get_device_metadata(sender_user, &body.device_id)? .get_device_metadata(sender_user, &body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
device.display_name = body.display_name.clone(); device.display_name = body.display_name.clone();
services() services().users.update_device_metadata(sender_user, &body.device_id, &device)?;
.users
.update_device_metadata(sender_user, &body.device_id, &device)?;
Ok(update_device::v3::Response {}) Ok(update_device::v3::Response {})
} }
/// # `DELETE /_matrix/client/r0/devices/{deviceId}` /// # `DELETE /_matrix/client/r0/devices/{deviceId}`
@ -68,50 +64,42 @@ pub async fn update_device_route(
/// ///
/// - Requires UIAA to verify user password /// - Requires UIAA to verify user password
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn delete_device_route( pub async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Result<delete_device::v3::Response> {
body: Ruma<delete_device::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<delete_device::v3::Response> { let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
// UIAA // UIAA
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
stages: vec![AuthType::Password], stages: vec![AuthType::Password],
}], }],
completed: Vec::new(), completed: Vec::new(),
params: Box::default(), params: Box::default(),
session: None, session: None,
auth_error: None, auth_error: None,
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
services() if !worked {
.uiaa return Err(Error::Uiaa(uiaainfo));
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; }
if !worked { // Success!
return Err(Error::Uiaa(uiaainfo)); } else if let Some(json) = body.json_body {
} uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
// Success! services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
} else if let Some(json) = body.json_body { return Err(Error::Uiaa(uiaainfo));
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); } else {
services() return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
.uiaa }
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
services() services().users.remove_device(sender_user, &body.device_id)?;
.users
.remove_device(sender_user, &body.device_id)?;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
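
For context, the UIAA block above means the first call without an `auth` object is answered with a challenge carrying a fresh session ID, and the client retries with that session plus a completed `m.login.password` stage. A sketch of such a retry payload, with placeholder credentials:

fn example_uiaa_retry_body() -> serde_json::Value {
	// Hypothetical second attempt after the route returned a UIAA challenge.
	// The session echoes the value from the challenge; user and password are placeholders.
	serde_json::json!({
		"auth": {
			"type": "m.login.password",
			"session": "opaque-session-id",
			"identifier": { "type": "m.id.user", "user": "@alice:example.org" },
			"password": "correct horse battery staple"
		}
	})
}
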
/// # `PUT /_matrix/client/r0/devices/{deviceId}` /// # `PUT /_matrix/client/r0/devices/{deviceId}`
@ -122,48 +110,42 @@ pub async fn delete_device_route(
/// ///
/// For each device: /// For each device:
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn delete_devices_route( pub async fn delete_devices_route(body: Ruma<delete_devices::v3::Request>) -> Result<delete_devices::v3::Response> {
body: Ruma<delete_devices::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<delete_devices::v3::Response> { let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
// UIAA // UIAA
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
stages: vec![AuthType::Password], stages: vec![AuthType::Password],
}], }],
completed: Vec::new(), completed: Vec::new(),
params: Box::default(), params: Box::default(),
session: None, session: None,
auth_error: None, auth_error: None,
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
services() if !worked {
.uiaa return Err(Error::Uiaa(uiaainfo));
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; }
if !worked { // Success!
return Err(Error::Uiaa(uiaainfo)); } else if let Some(json) = body.json_body {
} uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
// Success! services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
} else if let Some(json) = body.json_body { return Err(Error::Uiaa(uiaainfo));
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); } else {
services() return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
.uiaa }
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
for device_id in &body.devices { for device_id in &body.devices {
services().users.remove_device(sender_user, device_id)?; services().users.remove_device(sender_user, device_id)?;
} }
Ok(delete_devices::v3::Response {}) Ok(delete_devices::v3::Response {})
} }


@ -1,57 +1,51 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
directory::{ directory::{get_public_rooms, get_public_rooms_filtered, get_room_visibility, set_room_visibility},
get_public_rooms, get_public_rooms_filtered, get_room_visibility, error::ErrorKind,
set_room_visibility, room,
}, },
error::ErrorKind, federation,
room, },
}, directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork},
federation, events::{
}, room::{
directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork}, avatar::RoomAvatarEventContent,
events::{ canonical_alias::RoomCanonicalAliasEventContent,
room::{ create::RoomCreateEventContent,
avatar::RoomAvatarEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent},
canonical_alias::RoomCanonicalAliasEventContent, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
create::RoomCreateEventContent, join_rules::{JoinRule, RoomJoinRulesEventContent},
guest_access::{GuestAccess, RoomGuestAccessEventContent}, topic::RoomTopicEventContent,
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, },
join_rules::{JoinRule, RoomJoinRulesEventContent}, StateEventType,
topic::RoomTopicEventContent, },
}, ServerName, UInt,
StateEventType,
},
ServerName, UInt,
}; };
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/v3/publicRooms` /// # `POST /_matrix/client/v3/publicRooms`
/// ///
/// Lists the public rooms on this server. /// Lists the public rooms on this server.
/// ///
/// - Rooms are ordered by the number of joined members /// - Rooms are ordered by the number of joined members
pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_filtered_route(
body: Ruma<get_public_rooms_filtered::v3::Request>, body: Ruma<get_public_rooms_filtered::v3::Request>,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if !services() if !services().globals.config.allow_public_room_directory_without_auth {
.globals let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
.config }
.allow_public_room_directory_without_auth
{
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
}
get_public_rooms_filtered_helper( get_public_rooms_filtered_helper(
body.server.as_deref(), body.server.as_deref(),
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
&body.filter, &body.filter,
&body.room_network, &body.room_network,
) )
.await .await
} }
/// # `GET /_matrix/client/v3/publicRooms` /// # `GET /_matrix/client/v3/publicRooms`
@ -60,31 +54,27 @@ pub async fn get_public_rooms_filtered_route(
/// ///
/// - Rooms are ordered by the number of joined members /// - Rooms are ordered by the number of joined members
pub async fn get_public_rooms_route( pub async fn get_public_rooms_route(
body: Ruma<get_public_rooms::v3::Request>, body: Ruma<get_public_rooms::v3::Request>,
) -> Result<get_public_rooms::v3::Response> { ) -> Result<get_public_rooms::v3::Response> {
if !services() if !services().globals.config.allow_public_room_directory_without_auth {
.globals let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
.config }
.allow_public_room_directory_without_auth
{
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
}
let response = get_public_rooms_filtered_helper( let response = get_public_rooms_filtered_helper(
body.server.as_deref(), body.server.as_deref(),
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
&Filter::default(), &Filter::default(),
&RoomNetwork::Matrix, &RoomNetwork::Matrix,
) )
.await?; .await?;
Ok(get_public_rooms::v3::Response { Ok(get_public_rooms::v3::Response {
chunk: response.chunk, chunk: response.chunk,
prev_batch: response.prev_batch, prev_batch: response.prev_batch,
next_batch: response.next_batch, next_batch: response.next_batch,
total_room_count_estimate: response.total_room_count_estimate, total_room_count_estimate: response.total_room_count_estimate,
}) })
} }
/// # `PUT /_matrix/client/r0/directory/list/room/{roomId}` /// # `PUT /_matrix/client/r0/directory/list/room/{roomId}`
@ -93,294 +83,261 @@ pub async fn get_public_rooms_route(
/// ///
/// - TODO: Access control checks /// - TODO: Access control checks
pub async fn set_room_visibility_route( pub async fn set_room_visibility_route(
body: Ruma<set_room_visibility::v3::Request>, body: Ruma<set_room_visibility::v3::Request>,
) -> Result<set_room_visibility::v3::Response> { ) -> Result<set_room_visibility::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.metadata.exists(&body.room_id)? { if !services().rooms.metadata.exists(&body.room_id)? {
// Return 404 if the room doesn't exist // Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
} }
match &body.visibility { match &body.visibility {
room::Visibility::Public => { room::Visibility::Public => {
services().rooms.directory.set_public(&body.room_id)?; services().rooms.directory.set_public(&body.room_id)?;
info!("{} made {} public", sender_user, body.room_id); info!("{} made {} public", sender_user, body.room_id);
} },
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Room visibility type is not supported.", "Room visibility type is not supported.",
)); ));
} },
} }
Ok(set_room_visibility::v3::Response {}) Ok(set_room_visibility::v3::Response {})
} }
/// # `GET /_matrix/client/r0/directory/list/room/{roomId}` /// # `GET /_matrix/client/r0/directory/list/room/{roomId}`
/// ///
/// Gets the visibility of a given room in the room directory. /// Gets the visibility of a given room in the room directory.
pub async fn get_room_visibility_route( pub async fn get_room_visibility_route(
body: Ruma<get_room_visibility::v3::Request>, body: Ruma<get_room_visibility::v3::Request>,
) -> Result<get_room_visibility::v3::Response> { ) -> Result<get_room_visibility::v3::Response> {
if !services().rooms.metadata.exists(&body.room_id)? { if !services().rooms.metadata.exists(&body.room_id)? {
// Return 404 if the room doesn't exist // Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
} }
Ok(get_room_visibility::v3::Response { Ok(get_room_visibility::v3::Response {
visibility: if services().rooms.directory.is_public_room(&body.room_id)? { visibility: if services().rooms.directory.is_public_room(&body.room_id)? {
room::Visibility::Public room::Visibility::Public
} else { } else {
room::Visibility::Private room::Visibility::Private
}, },
}) })
} }
pub(crate) async fn get_public_rooms_filtered_helper( pub(crate) async fn get_public_rooms_filtered_helper(
server: Option<&ServerName>, server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork,
limit: Option<UInt>,
since: Option<&str>,
filter: &Filter,
_network: &RoomNetwork,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if let Some(other_server) = if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) {
server.filter(|server| *server != services().globals.server_name().as_str()) let response = services()
{ .sending
let response = services() .send_federation_request(
.sending other_server,
.send_federation_request( federation::directory::get_public_rooms_filtered::v1::Request {
other_server, limit,
federation::directory::get_public_rooms_filtered::v1::Request { since: since.map(ToOwned::to_owned),
limit, filter: Filter {
since: since.map(ToOwned::to_owned), generic_search_term: filter.generic_search_term.clone(),
filter: Filter { room_types: filter.room_types.clone(),
generic_search_term: filter.generic_search_term.clone(), },
room_types: filter.room_types.clone(), room_network: RoomNetwork::Matrix,
}, },
room_network: RoomNetwork::Matrix, )
}, .await?;
)
.await?;
return Ok(get_public_rooms_filtered::v3::Response { return Ok(get_public_rooms_filtered::v3::Response {
chunk: response.chunk, chunk: response.chunk,
prev_batch: response.prev_batch, prev_batch: response.prev_batch,
next_batch: response.next_batch, next_batch: response.next_batch,
total_room_count_estimate: response.total_room_count_estimate, total_room_count_estimate: response.total_room_count_estimate,
}); });
} }
let limit = limit.map_or(10, u64::from); let limit = limit.map_or(10, u64::from);
let mut num_since = 0_u64; let mut num_since = 0_u64;
if let Some(s) = &since { if let Some(s) = &since {
let mut characters = s.chars(); let mut characters = s.chars();
let backwards = match characters.next() { let backwards = match characters.next() {
Some('n') => false, Some('n') => false,
Some('p') => true, Some('p') => true,
_ => { _ => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token")),
return Err(Error::BadRequest( };
ErrorKind::InvalidParam,
"Invalid `since` token",
))
}
};
num_since = characters num_since = characters
.collect::<String>() .collect::<String>()
.parse() .parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?;
if backwards { if backwards {
num_since = num_since.saturating_sub(limit); num_since = num_since.saturating_sub(limit);
} }
} }
let mut all_rooms: Vec<_> = services() let mut all_rooms: Vec<_> = services()
.rooms .rooms
.directory .directory
.public_rooms() .public_rooms()
.map(|room_id| { .map(|room_id| {
let room_id = room_id?; let room_id = room_id?;
let chunk = PublicRoomsChunk { let chunk = PublicRoomsChunk {
canonical_alias: services() canonical_alias: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
.map_or(Ok(None), |s| { .map_or(Ok(None), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias) .map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| { .map_err(|_| Error::bad_database("Invalid canonical alias event in database."))
Error::bad_database("Invalid canonical alias event in database.") })?,
}) name: services().rooms.state_accessor.get_name(&room_id)?,
})?, num_joined_members: services()
name: services().rooms.state_accessor.get_name(&room_id)?, .rooms
num_joined_members: services() .state_cache
.rooms .room_joined_count(&room_id)?
.state_cache .unwrap_or_else(|| {
.room_joined_count(&room_id)? warn!("Room {} has no member count", room_id);
.unwrap_or_else(|| { 0
warn!("Room {} has no member count", room_id); })
0 .try_into()
}) .expect("user count should not be that big"),
.try_into() topic: services()
.expect("user count should not be that big"), .rooms
topic: services() .state_accessor
.rooms .room_state_get(&room_id, &StateEventType::RoomTopic, "")?
.state_accessor .map_or(Ok(None), |s| {
.room_state_get(&room_id, &StateEventType::RoomTopic, "")? serde_json::from_str(s.content.get())
.map_or(Ok(None), |s| { .map(|c: RoomTopicEventContent| Some(c.topic))
serde_json::from_str(s.content.get()) .map_err(|_| {
.map(|c: RoomTopicEventContent| Some(c.topic)) error!("Invalid room topic event in database for room {}", room_id);
.map_err(|_| { Error::bad_database("Invalid room topic event in database.")
error!("Invalid room topic event in database for room {}", room_id); })
Error::bad_database("Invalid room topic event in database.") })
}) .unwrap_or(None),
}) world_readable: services()
.unwrap_or(None), .rooms
world_readable: services() .state_accessor
.rooms .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
.state_accessor .map_or(Ok(false), |s| {
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? serde_json::from_str(s.content.get())
.map_or(Ok(false), |s| { .map(|c: RoomHistoryVisibilityEventContent| {
serde_json::from_str(s.content.get()) c.history_visibility == HistoryVisibility::WorldReadable
.map(|c: RoomHistoryVisibilityEventContent| { })
c.history_visibility == HistoryVisibility::WorldReadable .map_err(|_| Error::bad_database("Invalid room history visibility event in database."))
}) })?,
.map_err(|_| { guest_can_join: services()
Error::bad_database( .rooms
"Invalid room history visibility event in database.", .state_accessor
) .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
}) .map_or(Ok(false), |s| {
})?, serde_json::from_str(s.content.get())
guest_can_join: services() .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
.rooms .map_err(|_| Error::bad_database("Invalid room guest access event in database."))
.state_accessor })?,
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? avatar_url: services()
.map_or(Ok(false), |s| { .rooms
serde_json::from_str(s.content.get()) .state_accessor
.map(|c: RoomGuestAccessEventContent| { .room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
c.guest_access == GuestAccess::CanJoin .map(|s| {
}) serde_json::from_str(s.content.get())
.map_err(|_| { .map(|c: RoomAvatarEventContent| c.url)
Error::bad_database("Invalid room guest access event in database.") .map_err(|_| Error::bad_database("Invalid room avatar event in database."))
}) })
})?, .transpose()?
avatar_url: services() // url is now an Option<String> so we must flatten
.rooms .flatten(),
.state_accessor join_rule: services()
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")? .rooms
.map(|s| { .state_accessor
serde_json::from_str(s.content.get()) .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
.map(|c: RoomAvatarEventContent| c.url) .map(|s| {
.map_err(|_| { serde_json::from_str(s.content.get())
Error::bad_database("Invalid room avatar event in database.") .map(|c: RoomJoinRulesEventContent| match c.join_rule {
}) JoinRule::Public => Some(PublicRoomJoinRule::Public),
}) JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
.transpose()? _ => None,
// url is now an Option<String> so we must flatten })
.flatten(), .map_err(|e| {
join_rule: services() error!("Invalid room join rule event in database: {}", e);
.rooms Error::BadDatabase("Invalid room join rule event in database.")
.state_accessor })
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? })
.map(|s| { .transpose()?
serde_json::from_str(s.content.get()) .flatten()
.map(|c: RoomJoinRulesEventContent| match c.join_rule { .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
JoinRule::Public => Some(PublicRoomJoinRule::Public), room_type: services()
JoinRule::Knock => Some(PublicRoomJoinRule::Knock), .rooms
_ => None, .state_accessor
}) .room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.map_err(|e| { .map(|s| {
error!("Invalid room join rule event in database: {}", e); serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(|e| {
Error::BadDatabase("Invalid room join rule event in database.") error!("Invalid room create event in database: {}", e);
}) Error::BadDatabase("Invalid room create event in database.")
}) })
.transpose()? })
.flatten() .transpose()?
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, .and_then(|e| e.room_type),
room_type: services() room_id,
.rooms };
.state_accessor Ok(chunk)
.room_state_get(&room_id, &StateEventType::RoomCreate, "")? })
.map(|s| { .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err( .filter(|chunk| {
|e| { if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) {
error!("Invalid room create event in database: {}", e); if let Some(name) = &chunk.name {
Error::BadDatabase("Invalid room create event in database.") if name.as_str().to_lowercase().contains(&query) {
}, return true;
) }
}) }
.transpose()?
.and_then(|e| e.room_type),
room_id,
};
Ok(chunk)
})
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
.filter(|chunk| {
if let Some(query) = filter
.generic_search_term
.as_ref()
.map(|q| q.to_lowercase())
{
if let Some(name) = &chunk.name {
if name.as_str().to_lowercase().contains(&query) {
return true;
}
}
if let Some(topic) = &chunk.topic { if let Some(topic) = &chunk.topic {
if topic.to_lowercase().contains(&query) { if topic.to_lowercase().contains(&query) {
return true; return true;
} }
} }
if let Some(canonical_alias) = &chunk.canonical_alias { if let Some(canonical_alias) = &chunk.canonical_alias {
if canonical_alias.as_str().to_lowercase().contains(&query) { if canonical_alias.as_str().to_lowercase().contains(&query) {
return true; return true;
} }
} }
false false
} else { } else {
// No search term // No search term
true true
} }
}) })
// We need to collect all, so we can sort by member count // We need to collect all, so we can sort by member count
.collect(); .collect();
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
let total_room_count_estimate = (all_rooms.len() as u32).into(); let total_room_count_estimate = (all_rooms.len() as u32).into();
let chunk: Vec<_> = all_rooms let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect();
.into_iter()
.skip(num_since as usize)
.take(limit as usize)
.collect();
let prev_batch = if num_since == 0 { let prev_batch = if num_since == 0 {
None None
} else { } else {
Some(format!("p{num_since}")) Some(format!("p{num_since}"))
}; };
let next_batch = if chunk.len() < limit as usize { let next_batch = if chunk.len() < limit as usize {
None None
} else { } else {
Some(format!("n{}", num_since + limit)) Some(format!("n{}", num_since + limit))
}; };
Ok(get_public_rooms_filtered::v3::Response { Ok(get_public_rooms_filtered::v3::Response {
chunk, chunk,
prev_batch, prev_batch,
next_batch, next_batch,
total_room_count_estimate: Some(total_room_count_estimate), total_room_count_estimate: Some(total_room_count_estimate),
}) })
} }
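
The directory pagination above round-trips through opaque `n{offset}` / `p{offset}` tokens: `n` means continue forward from that offset, `p` means the previous page, which steps the offset back by the page limit. A self-contained sketch of the same parsing:

fn example_parse_since_token(token: &str, limit: u64) -> Option<u64> {
	// 'n' tokens carry the next offset directly; 'p' tokens step back by `limit`,
	// mirroring the handling of `since` in the helper above.
	let mut chars = token.chars();
	let backwards = match chars.next() {
		Some('n') => false,
		Some('p') => true,
		_ => return None,
	};
	let mut offset: u64 = chars.as_str().parse().ok()?;
	if backwards {
		offset = offset.saturating_sub(limit);
	}
	Some(offset)
}
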


@ -1,34 +1,31 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
filter::{create_filter, get_filter}, filter::{create_filter, get_filter},
}; };
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
/// ///
/// Loads a filter that was previously created. /// Loads a filter that was previously created.
/// ///
/// - A user can only access their own filters /// - A user can only access their own filters
pub async fn get_filter_route( pub async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Result<get_filter::v3::Response> {
body: Ruma<get_filter::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_filter::v3::Response> { let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Some(filter) => filter,
let filter = match services().users.get_filter(sender_user, &body.filter_id)? { None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
Some(filter) => filter, };
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
};
Ok(get_filter::v3::Response::new(filter)) Ok(get_filter::v3::Response::new(filter))
} }
/// # `PUT /_matrix/client/r0/user/{userId}/filter` /// # `PUT /_matrix/client/r0/user/{userId}/filter`
/// ///
/// Creates a new filter to be used by other endpoints. /// Creates a new filter to be used by other endpoints.
pub async fn create_filter_route( pub async fn create_filter_route(body: Ruma<create_filter::v3::Request>) -> Result<create_filter::v3::Response> {
body: Ruma<create_filter::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<create_filter::v3::Response> { Ok(create_filter::v3::Response::new(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services().users.create_filter(sender_user, &body.filter)?,
Ok(create_filter::v3::Response::new( ))
services().users.create_filter(sender_user, &body.filter)?,
))
} }


@ -1,65 +1,53 @@
use super::SESSION_ID_LENGTH; use std::{
use crate::{services, utils, Error, Result, Ruma}; collections::{hash_map, BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
error::ErrorKind, error::ErrorKind,
keys::{ keys::{claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, upload_signing_keys},
claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, uiaa::{AuthFlow, AuthType, UiaaInfo},
upload_signing_keys, },
}, federation,
uiaa::{AuthFlow, AuthType, UiaaInfo}, },
}, serde::Raw,
federation, DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
},
serde::Raw,
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
}; };
use serde_json::json; use serde_json::json;
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
use tracing::{debug, error}; use tracing::{debug, error};
use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/keys/upload` /// # `POST /_matrix/client/r0/keys/upload`
/// ///
/// Publish end-to-end encryption keys for the sender device. /// Publish end-to-end encryption keys for the sender device.
/// ///
/// - Adds one time keys /// - Adds one time keys
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) /// - If there are no device keys yet: Adds device keys (TODO: merge with
pub async fn upload_keys_route( /// existing keys?)
body: Ruma<upload_keys::v3::Request>, pub async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<upload_keys::v3::Response> {
) -> Result<upload_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
services() services().users.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
.users }
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
}
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
// TODO: merge this and the existing event? // TODO: merge this and the existing event?
// This check is needed to assure that signatures are kept // This check is needed to assure that signatures are kept
if services() if services().users.get_device_keys(sender_user, sender_device)?.is_none() {
.users services().users.add_device_keys(sender_user, sender_device, device_keys)?;
.get_device_keys(sender_user, sender_device)? }
.is_none() }
{
services()
.users
.add_device_keys(sender_user, sender_device, device_keys)?;
}
}
Ok(upload_keys::v3::Response { Ok(upload_keys::v3::Response {
one_time_key_counts: services() one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?,
.users })
.count_one_time_keys(sender_user, sender_device)?,
})
} }
/// # `POST /_matrix/client/r0/keys/query` /// # `POST /_matrix/client/r0/keys/query`
@ -68,30 +56,29 @@ pub async fn upload_keys_route(
/// ///
/// - Always fetches users from other servers over federation /// - Always fetches users from other servers over federation
/// - Gets master keys, self-signing keys, user signing keys and device keys. /// - Gets master keys, self-signing keys, user signing keys and device keys.
/// - The master and self-signing keys contain signatures that the user is allowed to see /// - The master and self-signing keys contain signatures that the user is
/// allowed to see
pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> { pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let response = get_keys_helper( let response = get_keys_helper(
Some(sender_user), Some(sender_user),
&body.device_keys, &body.device_keys,
|u| u == sender_user, |u| u == sender_user,
true, // Always allow local users to see device names of other local users true, // Always allow local users to see device names of other local users
) )
.await?; .await?;
Ok(response) Ok(response)
} }
/// # `POST /_matrix/client/r0/keys/claim` /// # `POST /_matrix/client/r0/keys/claim`
/// ///
/// Claims one-time keys /// Claims one-time keys
pub async fn claim_keys_route( pub async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Result<claim_keys::v3::Response> {
body: Ruma<claim_keys::v3::Request>, let response = claim_keys_helper(&body.one_time_keys).await?;
) -> Result<claim_keys::v3::Response> {
let response = claim_keys_helper(&body.one_time_keys).await?;
Ok(response) Ok(response)
} }
/// # `POST /_matrix/client/r0/keys/device_signing/upload` /// # `POST /_matrix/client/r0/keys/device_signing/upload`
@ -100,452 +87,373 @@ pub async fn claim_keys_route(
/// ///
/// - Requires UIAA to verify password /// - Requires UIAA to verify password
pub async fn upload_signing_keys_route( pub async fn upload_signing_keys_route(
body: Ruma<upload_signing_keys::v3::Request>, body: Ruma<upload_signing_keys::v3::Request>,
) -> Result<upload_signing_keys::v3::Response> { ) -> Result<upload_signing_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
// UIAA // UIAA
let mut uiaainfo = UiaaInfo { let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
stages: vec![AuthType::Password], stages: vec![AuthType::Password],
}], }],
completed: Vec::new(), completed: Vec::new(),
params: Box::default(), params: Box::default(),
session: None, session: None,
auth_error: None, auth_error: None,
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
services() if !worked {
.uiaa return Err(Error::Uiaa(uiaainfo));
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; }
if !worked { // Success!
return Err(Error::Uiaa(uiaainfo)); } else if let Some(json) = body.json_body {
} uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
// Success! services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
} else if let Some(json) = body.json_body { return Err(Error::Uiaa(uiaainfo));
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); } else {
services() return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
.uiaa }
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
if let Some(master_key) = &body.master_key { if let Some(master_key) = &body.master_key {
services().users.add_cross_signing_keys( services().users.add_cross_signing_keys(
sender_user, sender_user,
master_key, master_key,
&body.self_signing_key, &body.self_signing_key,
&body.user_signing_key, &body.user_signing_key,
true, // notify so that other users see the new keys true, // notify so that other users see the new keys
)?; )?;
} }
Ok(upload_signing_keys::v3::Response {}) Ok(upload_signing_keys::v3::Response {})
} }
/// # `POST /_matrix/client/r0/keys/signatures/upload` /// # `POST /_matrix/client/r0/keys/signatures/upload`
/// ///
/// Uploads end-to-end key signatures from the sender user. /// Uploads end-to-end key signatures from the sender user.
pub async fn upload_signatures_route( pub async fn upload_signatures_route(
body: Ruma<upload_signatures::v3::Request>, body: Ruma<upload_signatures::v3::Request>,
) -> Result<upload_signatures::v3::Response> { ) -> Result<upload_signatures::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for (user_id, keys) in &body.signed_keys { for (user_id, keys) in &body.signed_keys {
for (key_id, key) in keys { for (key_id, key) in keys {
let key = serde_json::to_value(key) let key = serde_json::to_value(key)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?;
for signature in key for signature in key
.get("signatures") .get("signatures")
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Missing signatures field."))?
ErrorKind::InvalidParam, .get(sender_user.to_string())
"Missing signatures field.", .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid user in signatures field."))?
))? .as_object()
.get(sender_user.to_string()) .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature."))?
.ok_or(Error::BadRequest( .clone()
ErrorKind::InvalidParam, .into_iter()
"Invalid user in signatures field.", {
))? // Signature validation?
.as_object() let signature = (
.ok_or(Error::BadRequest( signature.0,
ErrorKind::InvalidParam, signature
"Invalid signature.", .1
))? .as_str()
.clone() .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
.into_iter() .to_owned(),
{ );
// Signature validation? services().users.sign_key(user_id, key_id, signature, sender_user)?;
let signature = ( }
signature.0, }
signature }
.1
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid signature value.",
))?
.to_owned(),
);
services()
.users
.sign_key(user_id, key_id, signature, sender_user)?;
}
}
}
Ok(upload_signatures::v3::Response { Ok(upload_signatures::v3::Response {
failures: BTreeMap::new(), // TODO: integrate failures: BTreeMap::new(), // TODO: integrate
}) })
} }
/// # `POST /_matrix/client/r0/keys/changes` /// # `POST /_matrix/client/r0/keys/changes`
/// ///
/// Gets a list of users who have updated their device identity keys since the previous sync token. /// Gets a list of users who have updated their device identity keys since the
/// previous sync token.
/// ///
/// - TODO: left users /// - TODO: left users
pub async fn get_key_changes_route( pub async fn get_key_changes_route(body: Ruma<get_key_changes::v3::Request>) -> Result<get_key_changes::v3::Response> {
body: Ruma<get_key_changes::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_key_changes::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut device_list_updates = HashSet::new(); let mut device_list_updates = HashSet::new();
device_list_updates.extend( device_list_updates.extend(
services() services()
.users .users
.keys_changed( .keys_changed(
sender_user.as_str(), sender_user.as_str(),
body.from body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
.parse() Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, )
Some( .filter_map(std::result::Result::ok),
body.to );
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
)
.filter_map(std::result::Result::ok),
);
for room_id in services() for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok) {
.rooms device_list_updates.extend(
.state_cache services()
.rooms_joined(sender_user) .users
.filter_map(std::result::Result::ok) .keys_changed(
{ room_id.as_ref(),
device_list_updates.extend( body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
services() Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
.users )
.keys_changed( .filter_map(std::result::Result::ok),
room_id.as_ref(), );
body.from.parse().map_err(|_| { }
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.") Ok(get_key_changes::v3::Response {
})?, changed: device_list_updates.into_iter().collect(),
Some(body.to.parse().map_err(|_| { left: Vec::new(), // TODO
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.") })
})?),
)
.filter_map(std::result::Result::ok),
);
}
Ok(get_key_changes::v3::Response {
changed: device_list_updates.into_iter().collect(),
left: Vec::new(), // TODO
})
} }
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
sender_user: Option<&UserId>, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F,
device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, include_display_names: bool,
allowed_signatures: F,
include_display_names: bool,
) -> Result<get_keys::v3::Response> { ) -> Result<get_keys::v3::Response> {
let mut master_keys = BTreeMap::new(); let mut master_keys = BTreeMap::new();
let mut self_signing_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new();
let mut user_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new();
let mut device_keys = BTreeMap::new(); let mut device_keys = BTreeMap::new();
let mut get_over_federation = HashMap::new(); let mut get_over_federation = HashMap::new();
for (user_id, device_ids) in device_keys_input { for (user_id, device_ids) in device_keys_input {
let user_id: &UserId = user_id; let user_id: &UserId = user_id;
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
get_over_federation get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids));
.entry(user_id.server_name()) continue;
.or_insert_with(Vec::new) }
.push((user_id, device_ids));
continue;
}
if device_ids.is_empty() { if device_ids.is_empty() {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for device_id in services().users.all_device_ids(user_id) { for device_id in services().users.all_device_ids(user_id) {
let device_id = device_id?; let device_id = device_id?;
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? {
let metadata = services() let metadata = services()
.users .users
.get_device_metadata(user_id, &device_id)? .get_device_metadata(user_id, &device_id)?
.ok_or_else(|| { .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?;
Error::bad_database("all_device_keys contained nonexistent device.")
})?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names) add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| Error::bad_database("invalid device keys in database"))?;
container.insert(device_id, keys); container.insert(device_id, keys);
} }
} }
device_keys.insert(user_id.to_owned(), container); device_keys.insert(user_id.to_owned(), container);
} else { } else {
for device_id in device_ids { for device_id in device_ids {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
let metadata = services() let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or(
.users Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."),
.get_device_metadata(user_id, device_id)? )?;
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Tried to get keys for nonexistent device.",
))?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names) add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| Error::bad_database("invalid device keys in database"))?;
container.insert(device_id.to_owned(), keys); container.insert(device_id.to_owned(), keys);
} }
device_keys.insert(user_id.to_owned(), container); device_keys.insert(user_id.to_owned(), container);
} }
} }
if let Some(master_key) = if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? {
services() master_keys.insert(user_id.to_owned(), master_key);
.users }
.get_master_key(sender_user, user_id, &allowed_signatures)? if let Some(self_signing_key) =
{ services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
master_keys.insert(user_id.to_owned(), master_key); {
} self_signing_keys.insert(user_id.to_owned(), self_signing_key);
if let Some(self_signing_key) = }
services() if Some(user_id) == sender_user {
.users if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
.get_self_signing_key(sender_user, user_id, &allowed_signatures)? user_signing_keys.insert(user_id.to_owned(), user_signing_key);
{ }
self_signing_keys.insert(user_id.to_owned(), self_signing_key); }
} }
if Some(user_id) == sender_user {
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
user_signing_keys.insert(user_id.to_owned(), user_signing_key);
}
}
}
let mut failures = BTreeMap::new(); let mut failures = BTreeMap::new();
let back_off = |id| match services() let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) {
.globals hash_map::Entry::Vacant(e) => {
.bad_query_ratelimiter e.insert((Instant::now(), 1));
.write() },
.unwrap() hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
.entry(id) };
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
let mut futures: FuturesUnordered<_> = get_over_federation let mut futures: FuturesUnordered<_> = get_over_federation
.into_iter() .into_iter()
.map(|(server, vec)| async move { .map(|(server, vec)| async move {
if let Some((time, tries)) = services() if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) {
.globals // Exponential backoff
.bad_query_ratelimiter let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
.read() if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
.unwrap() min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
.get(server) }
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration { if time.elapsed() < min_elapsed_duration {
debug!("Backing off query from {:?}", server); debug!("Backing off query from {:?}", server);
return ( return (server, Err(Error::BadServerResponse("bad query, still backing off")));
server, }
Err(Error::BadServerResponse("bad query, still backing off")), }
);
}
}
let mut device_keys_input_fed = BTreeMap::new(); let mut device_keys_input_fed = BTreeMap::new();
for (user_id, keys) in vec { for (user_id, keys) in vec {
device_keys_input_fed.insert(user_id.to_owned(), keys.clone()); device_keys_input_fed.insert(user_id.to_owned(), keys.clone());
} }
( (
server, server,
tokio::time::timeout( tokio::time::timeout(
Duration::from_secs(50), Duration::from_secs(50),
services().sending.send_federation_request( services().sending.send_federation_request(
server, server,
federation::keys::get_keys::v1::Request { federation::keys::get_keys::v1::Request {
device_keys: device_keys_input_fed, device_keys: device_keys_input_fed,
}, },
), ),
) )
.await .await
.map_err(|e| { .map_err(|e| {
error!("get_keys_helper query took too long: {}", e); error!("get_keys_helper query took too long: {}", e);
Error::BadServerResponse("get_keys_helper query took too long") Error::BadServerResponse("get_keys_helper query took too long")
}), }),
) )
}) })
.collect(); .collect();
while let Some((server, response)) = futures.next().await { while let Some((server, response)) = futures.next().await {
match response { match response {
Ok(Ok(response)) => { Ok(Ok(response)) => {
for (user, masterkey) in response.master_keys { for (user, masterkey) in response.master_keys {
let (master_key_id, mut master_key) = let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?;
services().users.parse_master_key(&user, &masterkey)?;
if let Some(our_master_key) = services().users.get_key( if let Some(our_master_key) =
&master_key_id, services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
sender_user, {
&user, let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?;
&allowed_signatures, master_key.signatures.extend(our_master_key.signatures);
)? { }
let (_, our_master_key) = let json = serde_json::to_value(master_key).expect("to_value always works");
services().users.parse_master_key(&user, &our_master_key)?; let raw = serde_json::from_value(json).expect("Raw::from_value always works");
master_key.signatures.extend(our_master_key.signatures); services().users.add_cross_signing_keys(
} &user, &raw, &None, &None,
let json = serde_json::to_value(master_key).expect("to_value always works"); false, /* Don't notify. A notification would trigger another key request resulting in an
let raw = serde_json::from_value(json).expect("Raw::from_value always works"); * endless loop */
services().users.add_cross_signing_keys( )?;
&user, &raw, &None, &None, master_keys.insert(user, raw);
false, // Don't notify. A notification would trigger another key request resulting in an endless loop )?;
)?;
master_keys.insert(user, raw);
}
self_signing_keys.extend(response.self_signing_keys); self_signing_keys.extend(response.self_signing_keys);
device_keys.extend(response.device_keys); device_keys.extend(response.device_keys);
} },
_ => { _ => {
back_off(server.to_owned()); back_off(server.to_owned());
failures.insert(server.to_string(), json!({})); failures.insert(server.to_string(), json!({}));
} },
} }
} }
Ok(get_keys::v3::Response { Ok(get_keys::v3::Response {
master_keys, master_keys,
self_signing_keys, self_signing_keys,
user_signing_keys, user_signing_keys,
device_keys, device_keys,
failures, failures,
}) })
} }
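The federation branch of get_keys_helper backs off servers whose previous key queries failed, using five minutes times the square of the failure count, capped at one day. A standalone sketch of just that calculation (helper name and test are illustrative only, not part of this commit):

use std::time::Duration;

/// Sketch of the backoff used for bad federation key queries:
/// 5 minutes multiplied by tries², capped at 24 hours.
fn min_elapsed(tries: u32) -> Duration {
	let backoff = Duration::from_secs(5 * 60) * tries * tries;
	backoff.min(Duration::from_secs(60 * 60 * 24))
}

#[test]
fn backoff_sketch() {
	assert_eq!(min_elapsed(1), Duration::from_secs(300));
	assert_eq!(min_elapsed(2), Duration::from_secs(1200));
	// 20 failures would be ~33 hours uncapped, so the one-day cap applies.
	assert_eq!(min_elapsed(20), Duration::from_secs(86_400));
}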
fn add_unsigned_device_display_name( fn add_unsigned_device_display_name(
keys: &mut Raw<ruma::encryption::DeviceKeys>, keys: &mut Raw<ruma::encryption::DeviceKeys>, metadata: ruma::api::client::device::Device,
metadata: ruma::api::client::device::Device, include_display_names: bool,
include_display_names: bool,
) -> serde_json::Result<()> { ) -> serde_json::Result<()> {
if let Some(display_name) = metadata.display_name { if let Some(display_name) = metadata.display_name {
let mut object = keys.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?; let mut object = keys.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?;
let unsigned = object.entry("unsigned").or_insert_with(|| json!({})); let unsigned = object.entry("unsigned").or_insert_with(|| json!({}));
if let serde_json::Value::Object(unsigned_object) = unsigned { if let serde_json::Value::Object(unsigned_object) = unsigned {
if include_display_names { if include_display_names {
unsigned_object.insert("device_display_name".to_owned(), display_name.into()); unsigned_object.insert("device_display_name".to_owned(), display_name.into());
} else { } else {
unsigned_object.insert( unsigned_object.insert(
"device_display_name".to_owned(), "device_display_name".to_owned(),
Some(metadata.device_id.as_str().to_owned()).into(), Some(metadata.device_id.as_str().to_owned()).into(),
); );
} }
} }
*keys = Raw::from_json(serde_json::value::to_raw_value(&object)?); *keys = Raw::from_json(serde_json::value::to_raw_value(&object)?);
} }
Ok(()) Ok(())
} }
pub(crate) async fn claim_keys_helper( pub(crate) async fn claim_keys_helper(
one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>,
) -> Result<claim_keys::v3::Response> { ) -> Result<claim_keys::v3::Response> {
let mut one_time_keys = BTreeMap::new(); let mut one_time_keys = BTreeMap::new();
let mut get_over_federation = BTreeMap::new(); let mut get_over_federation = BTreeMap::new();
for (user_id, map) in one_time_keys_input { for (user_id, map) in one_time_keys_input {
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
get_over_federation get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map));
.entry(user_id.server_name()) }
.or_insert_with(Vec::new)
.push((user_id, map));
}
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? {
services() let mut c = BTreeMap::new();
.users c.insert(one_time_keys.0, one_time_keys.1);
.take_one_time_key(user_id, device_id, key_algorithm)? container.insert(device_id.clone(), c);
{ }
let mut c = BTreeMap::new(); }
c.insert(one_time_keys.0, one_time_keys.1); one_time_keys.insert(user_id.clone(), container);
container.insert(device_id.clone(), c); }
}
}
one_time_keys.insert(user_id.clone(), container);
}
let mut failures = BTreeMap::new(); let mut failures = BTreeMap::new();
let mut futures: FuturesUnordered<_> = get_over_federation let mut futures: FuturesUnordered<_> = get_over_federation
.into_iter() .into_iter()
.map(|(server, vec)| async move { .map(|(server, vec)| async move {
let mut one_time_keys_input_fed = BTreeMap::new(); let mut one_time_keys_input_fed = BTreeMap::new();
for (user_id, keys) in vec { for (user_id, keys) in vec {
one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
} }
( (
server, server,
services() services()
.sending .sending
.send_federation_request( .send_federation_request(
server, server,
federation::keys::claim_keys::v1::Request { federation::keys::claim_keys::v1::Request {
one_time_keys: one_time_keys_input_fed, one_time_keys: one_time_keys_input_fed,
}, },
) )
.await, .await,
) )
}) })
.collect(); .collect();
while let Some((server, response)) = futures.next().await { while let Some((server, response)) = futures.next().await {
match response { match response {
Ok(keys) => { Ok(keys) => {
one_time_keys.extend(keys.one_time_keys); one_time_keys.extend(keys.one_time_keys);
} },
Err(_e) => { Err(_e) => {
failures.insert(server.to_string(), json!({})); failures.insert(server.to_string(), json!({}));
} },
} }
} }
Ok(claim_keys::v3::Response { Ok(claim_keys::v3::Response {
failures, failures,
one_time_keys, one_time_keys,
}) })
} }
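Both get_keys_helper and claim_keys_helper fan out to remote servers concurrently with FuturesUnordered and record per-server failures instead of aborting the whole request. A minimal self-contained sketch of that pattern, with a stand-in in place of the real federation call (assumes the tokio and futures-util crates the project already depends on, with tokio's macro feature enabled):

use std::collections::BTreeMap;

use futures_util::{stream::FuturesUnordered, StreamExt};

#[tokio::main]
async fn main() {
	let servers = ["matrix.org", "example.com", "down.invalid"];

	// One future per server; each yields (server, Result<_, _>).
	let mut futures: FuturesUnordered<_> = servers
		.iter()
		.map(|server| async move {
			// Stand-in for a federation request.
			let response: Result<&str, &str> = if server.ends_with(".invalid") {
				Err("unreachable")
			} else {
				Ok("ok")
			};
			(*server, response)
		})
		.collect();

	// Successes and failures are collected separately, as in the routes above.
	let mut ok = BTreeMap::new();
	let mut failures = BTreeMap::new();
	while let Some((server, response)) = futures.next().await {
		match response {
			Ok(body) => {
				ok.insert(server, body);
			},
			Err(e) => {
				failures.insert(server, e);
			},
		}
	}
	println!("ok: {ok:?}, failures: {failures:?}");
}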


@ -1,22 +1,22 @@
use std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration}; use std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration};
use crate::{
service::media::{FileMeta, UrlPreviewData},
services, utils, Error, Result, Ruma,
};
use image::io::Reader as ImgReader; use image::io::Reader as ImgReader;
use reqwest::Url; use reqwest::Url;
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
media::{ media::{
create_content, get_content, get_content_as_filename, get_content_thumbnail, create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
get_media_config, get_media_preview, get_media_preview,
}, },
}; };
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use webpage::HTML; use webpage::HTML;
use crate::{
service::media::{FileMeta, UrlPreviewData},
services, utils, Error, Result, Ruma,
};
/// generated MXC ID (`media-id`) length /// generated MXC ID (`media-id`) length
const MXC_LENGTH: usize = 32; const MXC_LENGTH: usize = 32;
@ -24,48 +24,39 @@ const MXC_LENGTH: usize = 32;
/// ///
/// Returns max upload size. /// Returns max upload size.
pub async fn get_media_config_route( pub async fn get_media_config_route(
_body: Ruma<get_media_config::v3::Request>, _body: Ruma<get_media_config::v3::Request>,
) -> Result<get_media_config::v3::Response> { ) -> Result<get_media_config::v3::Response> {
Ok(get_media_config::v3::Response { Ok(get_media_config::v3::Response {
upload_size: services().globals.max_request_size().into(), upload_size: services().globals.max_request_size().into(),
}) })
} }
/// # `GET /_matrix/media/v3/preview_url` /// # `GET /_matrix/media/v3/preview_url`
/// ///
/// Returns URL preview. /// Returns URL preview.
pub async fn get_media_preview_route( pub async fn get_media_preview_route(
body: Ruma<get_media_preview::v3::Request>, body: Ruma<get_media_preview::v3::Request>,
) -> Result<get_media_preview::v3::Response> { ) -> Result<get_media_preview::v3::Response> {
let url = &body.url; let url = &body.url;
if !url_preview_allowed(url) { if !url_preview_allowed(url) {
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Forbidden, "URL is not allowed to be previewed"));
ErrorKind::Forbidden, }
"URL is not allowed to be previewed",
));
}
if let Ok(preview) = get_url_preview(url).await { if let Ok(preview) = get_url_preview(url).await {
let res = serde_json::value::to_raw_value(&preview).map_err(|e| { let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
error!( error!("Failed to convert UrlPreviewData into a serde json value: {}", e);
"Failed to convert UrlPreviewData into a serde json value: {}", Error::BadRequest(ErrorKind::Unknown, "Unknown error occurred parsing URL preview")
e })?;
);
Error::BadRequest(
ErrorKind::Unknown,
"Unknown error occurred parsing URL preview",
)
})?;
return Ok(get_media_preview::v3::Response::from_raw_value(res)); return Ok(get_media_preview::v3::Response::from_raw_value(res));
} }
Err(Error::BadRequest( Err(Error::BadRequest(
ErrorKind::LimitExceeded { ErrorKind::LimitExceeded {
retry_after_ms: Some(Duration::from_secs(5)), retry_after_ms: Some(Duration::from_secs(5)),
}, },
"Retry later", "Retry later",
)) ))
} }
/// # `POST /_matrix/media/v3/upload` /// # `POST /_matrix/media/v3/upload`
@ -74,80 +65,70 @@ pub async fn get_media_preview_route(
/// ///
/// - Some metadata will be saved in the database /// - Some metadata will be saved in the database
/// - Media will be saved in the media/ directory /// - Media will be saved in the media/ directory
pub async fn create_content_route( pub async fn create_content_route(body: Ruma<create_content::v3::Request>) -> Result<create_content::v3::Response> {
body: Ruma<create_content::v3::Request>, let mxc = format!(
) -> Result<create_content::v3::Response> { "mxc://{}/{}",
let mxc = format!( services().globals.server_name(),
"mxc://{}/{}", utils::random_string(MXC_LENGTH)
services().globals.server_name(), );
utils::random_string(MXC_LENGTH)
);
services() services()
.media .media
.create( .create(
mxc.clone(), mxc.clone(),
body.filename body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(),
.as_ref() body.content_type.as_deref(),
.map(|filename| "inline; filename=".to_owned() + filename) &body.file,
.as_deref(), )
body.content_type.as_deref(), .await?;
&body.file,
)
.await?;
let content_uri = mxc.into(); let content_uri = mxc.into();
Ok(create_content::v3::Response { Ok(create_content::v3::Response {
content_uri, content_uri,
blurhash: None, blurhash: None,
}) })
} }
/// helper method to fetch remote media from other servers over federation /// helper method to fetch remote media from other servers over federation
pub async fn get_remote_content( pub async fn get_remote_content(
mxc: &str, mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
server_name: &ruma::ServerName,
media_id: String,
allow_redirect: bool,
timeout_ms: Duration,
) -> Result<get_content::v3::Response, Error> { ) -> Result<get_content::v3::Response, Error> {
// we'll lie to the client and say the blocked server's media was not found and log. // we'll lie to the client and say the blocked server's media was not found and
// the client has no way of telling anyways so this is a security bonus. // log. the client has no way of telling anyways so this is a security bonus.
if services() if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) {
.globals info!(
.prevent_media_downloads_from() "Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
.contains(&server_name.to_owned()) mxc
{ );
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc); return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); }
}
let content_response = services() let content_response = services()
.sending .sending
.send_federation_request( .send_federation_request(
server_name, server_name,
get_content::v3::Request { get_content::v3::Request {
allow_remote: true, allow_remote: true,
server_name: server_name.to_owned(), server_name: server_name.to_owned(),
media_id, media_id,
timeout_ms, timeout_ms,
allow_redirect, allow_redirect,
}, },
) )
.await?; .await?;
services() services()
.media .media
.create( .create(
mxc.to_owned(), mxc.to_owned(),
content_response.content_disposition.as_deref(), content_response.content_disposition.as_deref(),
content_response.content_type.as_deref(), content_response.content_type.as_deref(),
&content_response.file, &content_response.file,
) )
.await?; .await?;
Ok(content_response) Ok(content_response)
} }
/// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}` /// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}`
@ -156,37 +137,36 @@ pub async fn get_remote_content(
/// ///
/// - Only allows federation if `allow_remote` is true /// - Only allows federation if `allow_remote` is true
/// - Only redirects if `allow_redirect` is true /// - Only redirects if `allow_redirect` is true
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds /// - Uses client-provided `timeout_ms` if available, else defaults to 20
pub async fn get_content_route( /// seconds
body: Ruma<get_content::v3::Request>, pub async fn get_content_route(body: Ruma<get_content::v3::Request>) -> Result<get_content::v3::Response> {
) -> Result<get_content::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file, file,
}) = services().media.get(mxc.clone()).await? }) = services().media.get(mxc.clone()).await?
{ {
Ok(get_content::v3::Response { Ok(get_content::v3::Response {
file, file,
content_type, content_type,
content_disposition, content_disposition,
cross_origin_resource_policy: Some("cross-origin".to_owned()), cross_origin_resource_policy: Some("cross-origin".to_owned()),
}) })
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
let remote_content_response = get_remote_content( let remote_content_response = get_remote_content(
&mxc, &mxc,
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
body.allow_redirect, body.allow_redirect,
body.timeout_ms, body.timeout_ms,
) )
.await?; .await?;
Ok(remote_content_response) Ok(remote_content_response)
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
} }
} }
/// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}/{fileName}` /// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}/{fileName}`
@ -195,41 +175,44 @@ pub async fn get_content_route(
/// ///
/// - Only allows federation if `allow_remote` is true /// - Only allows federation if `allow_remote` is true
/// - Only redirects if `allow_redirect` is true /// - Only redirects if `allow_redirect` is true
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds /// - Uses client-provided `timeout_ms` if available, else defaults to 20
/// seconds
pub async fn get_content_as_filename_route( pub async fn get_content_as_filename_route(
body: Ruma<get_content_as_filename::v3::Request>, body: Ruma<get_content_as_filename::v3::Request>,
) -> Result<get_content_as_filename::v3::Response> { ) -> Result<get_content_as_filename::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_type, file, .. content_type,
}) = services().media.get(mxc.clone()).await? file,
{ ..
Ok(get_content_as_filename::v3::Response { }) = services().media.get(mxc.clone()).await?
file, {
content_type, Ok(get_content_as_filename::v3::Response {
content_disposition: Some(format!("inline; filename={}", body.filename)), file,
cross_origin_resource_policy: Some("cross-origin".to_owned()), content_type,
}) content_disposition: Some(format!("inline; filename={}", body.filename)),
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { cross_origin_resource_policy: Some("cross-origin".to_owned()),
let remote_content_response = get_remote_content( })
&mxc, } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
&body.server_name, let remote_content_response = get_remote_content(
body.media_id.clone(), &mxc,
body.allow_redirect, &body.server_name,
body.timeout_ms, body.media_id.clone(),
) body.allow_redirect,
.await?; body.timeout_ms,
)
.await?;
Ok(get_content_as_filename::v3::Response { Ok(get_content_as_filename::v3::Response {
content_disposition: Some(format!("inline: filename={}", body.filename)), content_disposition: Some(format!("inline: filename={}", body.filename)),
content_type: remote_content_response.content_type, content_type: remote_content_response.content_type,
file: remote_content_response.file, file: remote_content_response.file,
cross_origin_resource_policy: Some("cross-origin".to_owned()), cross_origin_resource_policy: Some("cross-origin".to_owned()),
}) })
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
} }
} }
/// # `GET /_matrix/media/v3/thumbnail/{serverName}/{mediaId}` /// # `GET /_matrix/media/v3/thumbnail/{serverName}/{mediaId}`
@ -238,157 +221,152 @@ pub async fn get_content_as_filename_route(
/// ///
/// - Only allows federation if `allow_remote` is true /// - Only allows federation if `allow_remote` is true
/// - Only redirects if `allow_redirect` is true /// - Only redirects if `allow_redirect` is true
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds /// - Uses client-provided `timeout_ms` if available, else defaults to 20
/// seconds
pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_route(
body: Ruma<get_content_thumbnail::v3::Request>, body: Ruma<get_content_thumbnail::v3::Request>,
) -> Result<get_content_thumbnail::v3::Response> { ) -> Result<get_content_thumbnail::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_type, file, .. content_type,
}) = services() file,
.media ..
.get_thumbnail( }) = services()
mxc.clone(), .media
body.width .get_thumbnail(
.try_into() mxc.clone(),
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
body.height body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
.try_into() )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, .await?
) {
.await? Ok(get_content_thumbnail::v3::Response {
{ file,
Ok(get_content_thumbnail::v3::Response { content_type,
file, cross_origin_resource_policy: Some("cross-origin".to_owned()),
content_type, })
cross_origin_resource_policy: Some("cross-origin".to_owned()), } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
}) // we'll lie to the client and say the blocked server's media was not found and
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { // log. the client has no way of telling anyways so this is a security bonus.
// we'll lie to the client and say the blocked server's media was not found and log. if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) {
// the client has no way of telling anyways so this is a security bonus. info!(
if services() "Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
.globals mxc
.prevent_media_downloads_from() );
.contains(&body.server_name.clone()) return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
{ }
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
}
let get_thumbnail_response = services() let get_thumbnail_response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&body.server_name, &body.server_name,
get_content_thumbnail::v3::Request { get_content_thumbnail::v3::Request {
allow_remote: body.allow_remote, allow_remote: body.allow_remote,
height: body.height, height: body.height,
width: body.width, width: body.width,
method: body.method.clone(), method: body.method.clone(),
server_name: body.server_name.clone(), server_name: body.server_name.clone(),
media_id: body.media_id.clone(), media_id: body.media_id.clone(),
timeout_ms: body.timeout_ms, timeout_ms: body.timeout_ms,
allow_redirect: body.allow_redirect, allow_redirect: body.allow_redirect,
}, },
) )
.await?; .await?;
services() services()
.media .media
.upload_thumbnail( .upload_thumbnail(
mxc, mxc,
None, None,
get_thumbnail_response.content_type.as_deref(), get_thumbnail_response.content_type.as_deref(),
body.width.try_into().expect("all UInts are valid u32s"), body.width.try_into().expect("all UInts are valid u32s"),
body.height.try_into().expect("all UInts are valid u32s"), body.height.try_into().expect("all UInts are valid u32s"),
&get_thumbnail_response.file, &get_thumbnail_response.file,
) )
.await?; .await?;
Ok(get_thumbnail_response) Ok(get_thumbnail_response)
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
} }
} }
async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
let image = client.get(url).send().await?.bytes().await?; let image = client.get(url).send().await?.bytes().await?;
let mxc = format!( let mxc = format!(
"mxc://{}/{}", "mxc://{}/{}",
services().globals.server_name(), services().globals.server_name(),
utils::random_string(MXC_LENGTH) utils::random_string(MXC_LENGTH)
); );
services() services().media.create(mxc.clone(), None, None, &image).await?;
.media
.create(mxc.clone(), None, None, &image)
.await?;
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() { let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
Err(_) => (None, None), Err(_) => (None, None),
Ok(reader) => match reader.into_dimensions() { Ok(reader) => match reader.into_dimensions() {
Err(_) => (None, None), Err(_) => (None, None),
Ok((width, height)) => (Some(width), Some(height)), Ok((width, height)) => (Some(width), Some(height)),
}, },
}; };
Ok(UrlPreviewData { Ok(UrlPreviewData {
image: Some(mxc), image: Some(mxc),
image_size: Some(image.len()), image_size: Some(image.len()),
image_width: width, image_width: width,
image_height: height, image_height: height,
..Default::default() ..Default::default()
}) })
} }
async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
let mut response = client.get(url).send().await?; let mut response = client.get(url).send().await?;
let mut bytes: Vec<u8> = Vec::new(); let mut bytes: Vec<u8> = Vec::new();
while let Some(chunk) = response.chunk().await? { while let Some(chunk) = response.chunk().await? {
bytes.extend_from_slice(&chunk); bytes.extend_from_slice(&chunk);
if bytes.len() > services().globals.url_preview_max_spider_size() { if bytes.len() > services().globals.url_preview_max_spider_size() {
debug!("Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the response body and assuming our necessary data is in this range.", url, services().globals.url_preview_max_spider_size()); debug!(
break; "Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
} response body and assuming our necessary data is in this range.",
} url,
let body = String::from_utf8_lossy(&bytes); services().globals.url_preview_max_spider_size()
let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) { );
Ok(html) => html, break;
Err(_) => { }
return Err(Error::BadRequest( }
ErrorKind::Unknown, let body = String::from_utf8_lossy(&bytes);
"Failed to parse HTML", let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) {
)) Ok(html) => html,
} Err(_) => return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML")),
}; };
let mut data = match html.opengraph.images.first() { let mut data = match html.opengraph.images.first() {
None => UrlPreviewData::default(), None => UrlPreviewData::default(),
Some(obj) => download_image(client, &obj.url).await?, Some(obj) => download_image(client, &obj.url).await?,
}; };
let props = html.opengraph.properties; let props = html.opengraph.properties;
/* use OpenGraph title/description, but fall back to HTML if not available */ /* use OpenGraph title/description, but fall back to HTML if not available */
data.title = props.get("title").cloned().or(html.title); data.title = props.get("title").cloned().or(html.title);
data.description = props.get("description").cloned().or(html.description); data.description = props.get("description").cloned().or(html.description);
Ok(data) Ok(data)
} }
fn url_request_allowed(addr: &IpAddr) -> bool { fn url_request_allowed(addr: &IpAddr) -> bool {
// TODO: make this check ip_range_denylist // TODO: make this check ip_range_denylist
// could be implemented with reqwest when it supports IP filtering: // could be implemented with reqwest when it supports IP filtering:
// https://github.com/seanmonstar/reqwest/issues/1515 // https://github.com/seanmonstar/reqwest/issues/1515
// These checks have been taken from the Rust core/net/ipaddr.rs crate, // These checks have been taken from the Rust core/net/ipaddr.rs crate,
// IpAddr::V4.is_global() and IpAddr::V6.is_global(), as .is_global is not // IpAddr::V4.is_global() and IpAddr::V6.is_global(), as .is_global is not
// yet stabilized. TODO: Once this is stable, this match can be simplified. // yet stabilized. TODO: Once this is stable, this match can be simplified.
match addr { match addr {
IpAddr::V4(ip4) => { IpAddr::V4(ip4) => {
!(ip4.octets()[0] == 0 // "This network" !(ip4.octets()[0] == 0 // "This network"
|| ip4.is_private() || ip4.is_private()
|| (ip4.octets()[0] == 100 && (ip4.octets()[1] & 0b1100_0000 == 0b0100_0000)) // is_shared() || (ip4.octets()[0] == 100 && (ip4.octets()[1] & 0b1100_0000 == 0b0100_0000)) // is_shared()
|| ip4.is_loopback() || ip4.is_loopback()
@ -399,9 +377,9 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|| (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking() || (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking()
|| (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved() || (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved()
|| ip4.is_broadcast()) || ip4.is_broadcast())
} },
IpAddr::V6(ip6) => { IpAddr::V6(ip6) => {
!(ip6.is_unspecified() !(ip6.is_unspecified()
|| ip6.is_loopback() || ip6.is_loopback()
// IPv4-mapped Address (`::ffff:0:0/96`) // IPv4-mapped Address (`::ffff:0:0/96`)
|| matches!(ip6.segments(), [0, 0, 0, 0, 0, 0xffff, _, _]) || matches!(ip6.segments(), [0, 0, 0, 0, 0, 0xffff, _, _])
@ -426,178 +404,127 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|| ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation() || ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation()
|| ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local() || ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local()
|| ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local || ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local
} },
} }
} }
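Because Ipv4Addr::is_global() and friends are unstable, the checks above spell out several address ranges with octet masks; the shared-address test, for instance, matches 100.64.0.0/10 (RFC 6598, carrier-grade NAT). A tiny standalone sketch of that one predicate (function name illustrative, not part of this commit):

use std::net::Ipv4Addr;

/// Sketch of the is_shared() test above: 100.64.0.0/10, i.e. first octet 100
/// and the top two bits of the second octet equal to 01.
fn is_shared(ip: Ipv4Addr) -> bool {
	ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000) == 0b0100_0000
}

#[test]
fn shared_range_sketch() {
	assert!(is_shared(Ipv4Addr::new(100, 64, 0, 1)));
	assert!(is_shared(Ipv4Addr::new(100, 127, 255, 254)));
	assert!(!is_shared(Ipv4Addr::new(100, 128, 0, 1))); // just past the /10
	assert!(!is_shared(Ipv4Addr::new(100, 63, 255, 255))); // just below it
}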
async fn request_url_preview(url: &str) -> Result<UrlPreviewData> { async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
let client = services().globals.url_preview_client(); let client = services().globals.url_preview_client();
let response = client.head(url).send().await?; let response = client.head(url).send().await?;
if !response if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) {
.remote_addr() return Err(Error::BadRequest(
.map_or(false, |a| url_request_allowed(&a.ip())) ErrorKind::Forbidden,
{ "Requesting from this address is forbidden",
return Err(Error::BadRequest( ));
ErrorKind::Forbidden, }
"Requesting from this address is forbidden",
));
}
let content_type = match response let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) {
.headers() Some(ct) => ct,
.get(reqwest::header::CONTENT_TYPE) None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")),
.and_then(|x| x.to_str().ok()) };
{ let data = match content_type {
Some(ct) => ct, html if html.starts_with("text/html") => download_html(&client, url).await?,
None => { img if img.starts_with("image/") => download_image(&client, url).await?,
return Err(Error::BadRequest( _ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
ErrorKind::Unknown, };
"Unknown Content-Type",
))
}
};
let data = match content_type {
html if html.starts_with("text/html") => download_html(&client, url).await?,
img if img.starts_with("image/") => download_image(&client, url).await?,
_ => {
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Unsupported Content-Type",
))
}
};
services().media.set_url_preview(url, &data).await?; services().media.set_url_preview(url, &data).await?;
Ok(data) Ok(data)
} }
async fn get_url_preview(url: &str) -> Result<UrlPreviewData> { async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
if let Some(preview) = services().media.get_url_preview(url).await { if let Some(preview) = services().media.get_url_preview(url).await {
return Ok(preview); return Ok(preview);
} }
// ensure that only one request is made per URL // ensure that only one request is made per URL
let mutex_request = Arc::clone( let mutex_request =
services() Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default());
.media let _request_lock = mutex_request.lock().await;
.url_preview_mutex
.write()
.unwrap()
.entry(url.to_owned())
.or_default(),
);
let _request_lock = mutex_request.lock().await;
match services().media.get_url_preview(url).await { match services().media.get_url_preview(url).await {
Some(preview) => Ok(preview), Some(preview) => Ok(preview),
None => request_url_preview(url).await, None => request_url_preview(url).await,
} }
} }
fn url_preview_allowed(url_str: &str) -> bool { fn url_preview_allowed(url_str: &str) -> bool {
let url: Url = match Url::parse(url_str) { let url: Url = match Url::parse(url_str) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
warn!("Failed to parse URL from a str: {}", e); warn!("Failed to parse URL from a str: {}", e);
return false; return false;
} },
}; };
if ["http", "https"] if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) {
.iter() debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
.all(|&scheme| scheme != url.scheme().to_lowercase()) return false;
{ }
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
return false;
}
let host = match url.host_str() { let host = match url.host_str() {
None => { None => {
debug!( debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
"Ignoring URL preview for a URL that does not have a host (?): {}", return false;
url },
); Some(h) => h.to_owned(),
return false; };
}
Some(h) => h.to_owned(),
};
let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist(); let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist();
let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist(); let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist();
let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist(); let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist();
if allowlist_domain_contains.contains(&"*".to_owned()) if allowlist_domain_contains.contains(&"*".to_owned())
|| allowlist_domain_explicit.contains(&"*".to_owned()) || allowlist_domain_explicit.contains(&"*".to_owned())
|| allowlist_url_contains.contains(&"*".to_owned()) || allowlist_url_contains.contains(&"*".to_owned())
{ {
debug!( debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
"Config key contains * which is allowing all URL previews. Allowing URL {}", return true;
url }
);
return true;
}
if !host.is_empty() { if !host.is_empty() {
if allowlist_domain_explicit.contains(&host) { if allowlist_domain_explicit.contains(&host) {
debug!( debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", &host);
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", return true;
&host }
);
return true;
}
if allowlist_domain_contains if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) {
.iter() debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host);
.any(|domain_s| domain_s.contains(&host.clone())) return true;
{ }
debug!(
"Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
&host
);
return true;
}
if allowlist_url_contains if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) {
.iter() debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host);
.any(|url_s| url.to_string().contains(&url_s.to_string())) return true;
{ }
debug!(
"URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)",
&host
);
return true;
}
// check root domain if available and if user has root domain checks // check root domain if available and if user has root domain checks
if services().globals.url_preview_check_root_domain() { if services().globals.url_preview_check_root_domain() {
debug!("Checking root domain"); debug!("Checking root domain");
match host.split_once('.') { match host.split_once('.') {
None => return false, None => return false,
Some((_, root_domain)) => { Some((_, root_domain)) => {
if allowlist_domain_explicit.contains(&root_domain.to_owned()) { if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
debug!( debug!(
"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", "Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)",
&root_domain &root_domain
); );
return true; return true;
} }
if allowlist_domain_contains if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) {
.iter() debug!(
.any(|domain_s| domain_s.contains(&root_domain.to_owned())) "Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
{ &root_domain
debug!( );
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", return true;
&root_domain }
); },
return true; }
} }
} }
}
}
}
false false
} }
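
// Stand-alone sketch of the three allowlist checks performed by
// url_preview_allowed above (explicit host match, "domain contains" match,
// "url contains" match) plus the `*` wildcard short-circuit. The slices stand
// in for the configured allowlists; the matching directions deliberately
// mirror the handler, including check 2/3 testing whether the allowlist entry
// contains the host.
fn preview_allowed(
	host: &str,
	url: &str,
	explicit: &[String],
	domain_contains: &[String],
	url_contains: &[String],
) -> bool {
	let wildcard = "*".to_owned();
	if explicit.contains(&wildcard) || domain_contains.contains(&wildcard) || url_contains.contains(&wildcard) {
		return true; // a `*` entry allows every URL preview
	}
	explicit.iter().any(|h| h == host)                          // check 1/3
		|| domain_contains.iter().any(|d| d.contains(host))     // check 2/3
		|| url_contains.iter().any(|u| url.contains(u.as_str())) // check 3/3
}
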

File diff suppressed because it is too large


@@ -1,316 +1,284 @@
use crate::{ use std::{
service::{pdu::PduBuilder, rooms::timeline::PduCount}, collections::{BTreeMap, HashSet},
services, utils, Error, Result, Ruma, sync::Arc,
}; };
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
message::{get_message_events, send_message_event}, message::{get_message_events, send_message_event},
}, },
events::{StateEventType, TimelineEventType}, events::{StateEventType, TimelineEventType},
}; };
use serde_json::from_str; use serde_json::from_str;
use std::{
collections::{BTreeMap, HashSet}, use crate::{
sync::Arc, service::{pdu::PduBuilder, rooms::timeline::PduCount},
services, utils, Error, Result, Ruma,
}; };
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
/// ///
/// Send a message event into the room. /// Send a message event into the room.
/// ///
/// - Is a NOOP if the txn id was already used before and returns the same event id again /// - Is a NOOP if the txn id was already used before and returns the same event
/// id again
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
pub async fn send_message_event_route( pub async fn send_message_event_route(
body: Ruma<send_message_event::v3::Request>, body: Ruma<send_message_event::v3::Request>,
) -> Result<send_message_event::v3::Response> { ) -> Result<send_message_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
let mutex_state = Arc::clone( let mutex_state =
services() Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
.globals let state_lock = mutex_state.lock().await;
.roomid_mutex_state
.write()
.unwrap()
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Forbid m.room.encrypted if encryption is disabled // Forbid m.room.encrypted if encryption is disabled
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() && !services().globals.allow_encryption()
&& !services().globals.allow_encryption() {
{ return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
return Err(Error::BadRequest( }
ErrorKind::Forbidden,
"Encryption has been disabled",
));
}
// certain event types require certain fields to be valid in request bodies. // certain event types require certain fields to be valid in request bodies.
// this helps prevent attempting to handle events that we can't deserialise later so don't waste resources on it. // this helps prevent attempting to handle events that we can't deserialise
// // later so don't waste resources on it.
// see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type. //
match body.event_type.to_string().into() { // see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type.
TimelineEventType::RoomMessage => { match body.event_type.to_string().into() {
let body_field = body.body.body.get_field::<String>("body"); TimelineEventType::RoomMessage => {
let msgtype_field = body.body.body.get_field::<String>("msgtype"); let body_field = body.body.body.get_field::<String>("body");
let msgtype_field = body.body.body.get_field::<String>("msgtype");
if body_field.is_err() { if body_field.is_err() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"'body' field in JSON request is invalid", "'body' field in JSON request is invalid",
)); ));
} }
if msgtype_field.is_err() { if msgtype_field.is_err() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"'msgtype' field in JSON request is invalid", "'msgtype' field in JSON request is invalid",
)); ));
} }
} },
TimelineEventType::RoomName => { TimelineEventType::RoomName => {
let name_field = body.body.body.get_field::<String>("name"); let name_field = body.body.body.get_field::<String>("name");
if name_field.is_err() { if name_field.is_err() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"'name' field in JSON request is invalid", "'name' field in JSON request is invalid",
)); ));
} }
} },
TimelineEventType::RoomTopic => { TimelineEventType::RoomTopic => {
let topic_field = body.body.body.get_field::<String>("topic"); let topic_field = body.body.body.get_field::<String>("topic");
if topic_field.is_err() { if topic_field.is_err() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"'topic' field in JSON request is invalid", "'topic' field in JSON request is invalid",
)); ));
} }
} },
_ => {} // event may be custom/experimental or can be empty don't do anything with it _ => {}, // event may be custom/experimental or can be empty don't do anything with it
}; };
// Check if this is a new transaction id // Check if this is a new transaction id
if let Some(response) = if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? {
services() // The client might have sent a txnid of the /sendToDevice endpoint
.transaction_ids // This txnid has no response associated with it
.existing_txnid(sender_user, sender_device, &body.txn_id)? if response.is_empty() {
{ return Err(Error::BadRequest(
// The client might have sent a txnid of the /sendToDevice endpoint ErrorKind::InvalidParam,
// This txnid has no response associated with it "Tried to use txn id already used for an incompatible endpoint.",
if response.is_empty() { ));
return Err(Error::BadRequest( }
ErrorKind::InvalidParam,
"Tried to use txn id already used for an incompatible endpoint.",
));
}
let event_id = utils::string_from_bytes(&response) let event_id = utils::string_from_bytes(&response)
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? .map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
.try_into() .try_into()
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; .map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
return Ok(send_message_event::v3::Response { event_id }); return Ok(send_message_event::v3::Response {
} event_id,
});
}
let mut unsigned = BTreeMap::new(); let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let event_id = services() let event_id = services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: body.event_type.to_string().into(), event_type: body.event_type.to_string().into(),
content: from_str(body.body.body.json().get()) content: from_str(body.body.body.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
unsigned: Some(unsigned), unsigned: Some(unsigned),
state_key: None, state_key: None,
redacts: None, redacts: None,
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&state_lock, &state_lock,
) )
.await?; .await?;
services().transaction_ids.add_txnid( services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
sender_user,
sender_device,
&body.txn_id,
event_id.as_bytes(),
)?;
drop(state_lock); drop(state_lock);
Ok(send_message_event::v3::Response::new( Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
(*event_id).to_owned(),
))
} }
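
// Sketch of the per-event-type validation performed in send_message_event_route,
// using plain serde_json instead of the Ruma request body helper: an
// m.room.message event must carry string "body" and "msgtype" fields before we
// try to build a PDU from it. The error strings match the handler; the helper
// itself is hypothetical.
use serde_json::Value;

fn validate_room_message(content: &Value) -> Result<(), &'static str> {
	if !content.get("body").map_or(false, Value::is_string) {
		return Err("'body' field in JSON request is invalid");
	}
	if !content.get("msgtype").map_or(false, Value::is_string) {
		return Err("'msgtype' field in JSON request is invalid");
	}
	Ok(())
}
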
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
/// ///
/// Allows paginating through room history. /// Allows paginating through room history.
/// ///
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was /// - Only works if the user is joined (TODO: always allow, but only show events
/// where the user was
/// joined, depending on history_visibility) /// joined, depending on history_visibility)
pub async fn get_message_events_route( pub async fn get_message_events_route(
body: Ruma<get_message_events::v3::Request>, body: Ruma<get_message_events::v3::Request>,
) -> Result<get_message_events::v3::Response> { ) -> Result<get_message_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let from = match body.from.clone() { let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?, Some(from) => PduCount::try_from_string(&from)?,
None => match body.dir { None => match body.dir {
ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(), ruma::api::Direction::Backward => PduCount::max(),
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
services().rooms.lazy_loading.lazy_load_confirm_delivery( services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
sender_user,
sender_device,
&body.room_id,
from,
)?;
let limit = u64::from(body.limit).min(100) as usize; let limit = u64::from(body.limit).min(100) as usize;
let next_token; let next_token;
let mut resp = get_message_events::v3::Response::new(); let mut resp = get_message_events::v3::Response::new();
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
match body.dir { match body.dir {
ruma::api::Direction::Forward => { ruma::api::Direction::Forward => {
let events_after: Vec<_> = services() let events_after: Vec<_> = services()
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &body.room_id, from)? .pdus_after(sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(std::result::Result::ok) // Filter out buggy events .filter_map(std::result::Result::ok) // Filter out buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.collect(); .collect();
for (_, event) in &events_after { for (_, event) in &events_after {
/* TODO: Remove this when these are resolved: /* TODO: Remove this when these are resolved:
* https://github.com/vector-im/element-android/issues/3417 * https://github.com/vector-im/element-android/issues/3417
* https://github.com/vector-im/element-web/issues/21034 * https://github.com/vector-im/element-web/issues/21034
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services().rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
&event.sender, &event.sender,
)? { )? {
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
*/ */
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
next_token = events_after.last().map(|(count, _)| count).copied(); next_token = events_after.last().map(|(count, _)| count).copied();
let events_after: Vec<_> = events_after let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
resp.start = from.stringify(); resp.start = from.stringify();
resp.end = next_token.map(|count| count.stringify()); resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_after; resp.chunk = events_after;
} },
ruma::api::Direction::Backward => { ruma::api::Direction::Backward => {
services() services().rooms.timeline.backfill_if_required(&body.room_id, from).await?;
.rooms let events_before: Vec<_> = services()
.timeline .rooms
.backfill_if_required(&body.room_id, from) .timeline
.await?; .pdus_until(sender_user, &body.room_id, from)?
let events_before: Vec<_> = services() .take(limit)
.rooms .filter_map(std::result::Result::ok) // Filter out buggy events
.timeline .filter(|(_, pdu)| {
.pdus_until(sender_user, &body.room_id, from)? services()
.take(limit) .rooms
.filter_map(std::result::Result::ok) // Filter out buggy events .state_accessor
.filter(|(_, pdu)| { .user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
services() .unwrap_or(false)
.rooms })
.state_accessor .take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .collect();
.unwrap_or(false)
})
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.collect();
for (_, event) in &events_before { for (_, event) in &events_before {
/* TODO: Remove this when these are resolved: /* TODO: Remove this when these are resolved:
* https://github.com/vector-im/element-android/issues/3417 * https://github.com/vector-im/element-android/issues/3417
* https://github.com/vector-im/element-web/issues/21034 * https://github.com/vector-im/element-web/issues/21034
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services().rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
&event.sender, &event.sender,
)? { )? {
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
*/ */
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
next_token = events_before.last().map(|(count, _)| count).copied(); next_token = events_before.last().map(|(count, _)| count).copied();
let events_before: Vec<_> = events_before let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
resp.start = from.stringify(); resp.start = from.stringify();
resp.end = next_token.map(|count| count.stringify()); resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_before; resp.chunk = events_before;
} },
} }
resp.state = Vec::new(); resp.state = Vec::new();
for ll_id in &lazy_loaded { for ll_id in &lazy_loaded {
if let Some(member_event) = services().rooms.state_accessor.room_state_get( if let Some(member_event) = services().rooms.state_accessor.room_state_get(
&body.room_id, &body.room_id,
&StateEventType::RoomMember, &StateEventType::RoomMember,
ll_id.as_str(), ll_id.as_str(),
)? { )? {
resp.state.push(member_event.to_state_event()); resp.state.push(member_event.to_state_event());
} }
} }
// TODO: enable again when we are sure clients can handle it // TODO: enable again when we are sure clients can handle it
/* /*
if let Some(next_token) = next_token { if let Some(next_token) = next_token {
services().rooms.lazy_loading.lazy_load_mark_sent( services().rooms.lazy_loading.lazy_load_mark_sent(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
lazy_loaded, lazy_loaded,
next_token, next_token,
); );
} }
*/ */
Ok(resp) Ok(resp)
} }
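
// Sketch of the pagination defaults used by get_message_events_route: when no
// `from` token is supplied, forward pagination starts at the smallest counter
// and backward pagination at the largest, and the requested limit is capped at
// 100 events. Plain u64 counters stand in for the real PduCount tokens.
#[derive(Clone, Copy)]
enum Dir {
	Forward,
	Backward,
}

fn resolve_from(from: Option<u64>, dir: Dir) -> u64 {
	from.unwrap_or(match dir {
		Dir::Forward => u64::MIN,
		Dir::Backward => u64::MAX,
	})
}

fn clamp_limit(requested: u64) -> usize {
	requested.min(100) as usize
}
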


@@ -1,38 +1,35 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{
error::ErrorKind,
presence::{get_presence, set_presence},
};
use std::time::Duration; use std::time::Duration;
use ruma::api::client::{
error::ErrorKind,
presence::{get_presence, set_presence},
};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/presence/{userId}/status` /// # `PUT /_matrix/client/r0/presence/{userId}/status`
/// ///
/// Sets the presence state of the sender user. /// Sets the presence state of the sender user.
pub async fn set_presence_route( pub async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> Result<set_presence::v3::Response> {
body: Ruma<set_presence::v3::Request>, if !services().globals.allow_local_presence() {
) -> Result<set_presence::v3::Response> { return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
if !services().globals.allow_local_presence() { }
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Presence is disabled on this server",
));
}
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for room_id in services().rooms.state_cache.rooms_joined(sender_user) { for room_id in services().rooms.state_cache.rooms_joined(sender_user) {
let room_id = room_id?; let room_id = room_id?;
services().rooms.edus.presence.set_presence( services().rooms.edus.presence.set_presence(
&room_id, &room_id,
sender_user, sender_user,
body.presence.clone(), body.presence.clone(),
None, None,
None, None,
body.status_msg.clone(), body.status_msg.clone(),
)?; )?;
} }
Ok(set_presence::v3::Response {}) Ok(set_presence::v3::Response {})
} }
/// # `GET /_matrix/client/r0/presence/{userId}/status` /// # `GET /_matrix/client/r0/presence/{userId}/status`
@@ -40,53 +37,36 @@ pub async fn set_presence_route(
/// Gets the presence state of the given user. /// Gets the presence state of the given user.
/// ///
/// - Only works if you share a room with the user /// - Only works if you share a room with the user
pub async fn get_presence_route( pub async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result<get_presence::v3::Response> {
body: Ruma<get_presence::v3::Request>, if !services().globals.allow_local_presence() {
) -> Result<get_presence::v3::Response> { return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
if !services().globals.allow_local_presence() { }
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Presence is disabled on this server",
));
}
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut presence_event = None; let mut presence_event = None;
for room_id in services() for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? {
.rooms let room_id = room_id?;
.user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{
let room_id = room_id?;
if let Some(presence) = services() if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? {
.rooms presence_event = Some(presence);
.edus break;
.presence }
.get_presence(&room_id, sender_user)? }
{
presence_event = Some(presence);
break;
}
}
if let Some(presence) = presence_event { if let Some(presence) = presence_event {
Ok(get_presence::v3::Response { Ok(get_presence::v3::Response {
// TODO: Should ruma just use the presenceeventcontent type here? // TODO: Should ruma just use the presenceeventcontent type here?
status_msg: presence.content.status_msg, status_msg: presence.content.status_msg,
currently_active: presence.content.currently_active, currently_active: presence.content.currently_active,
last_active_ago: presence last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())),
.content presence: presence.content.presence,
.last_active_ago })
.map(|millis| Duration::from_millis(millis.into())), } else {
presence: presence.content.presence, Err(Error::BadRequest(
}) ErrorKind::NotFound,
} else { "Presence state for this user was not found",
Err(Error::BadRequest( ))
ErrorKind::NotFound, }
"Presence state for this user was not found",
))
}
} }
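
// Sketch of the lookup strategy in get_presence_route: walk the rooms shared
// with the target user and return the first presence event found. The closure
// stands in for the presence service and the room ids are plain strings here,
// so this is an illustration of the control flow rather than the real API.
fn first_shared_presence<T>(
	shared_rooms: impl IntoIterator<Item = String>,
	presence_in: impl Fn(&str) -> Option<T>,
) -> Option<T> {
	shared_rooms.into_iter().find_map(|room_id| presence_in(&room_id))
}
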


@@ -1,17 +1,15 @@
use std::sync::Arc; use std::sync::Arc;
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
error::ErrorKind, error::ErrorKind,
profile::{ profile::{get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name},
get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, },
}, federation,
}, },
federation, events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType},
}, presence::PresenceState,
events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType},
presence::PresenceState,
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
@@ -23,87 +21,62 @@ use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma};
/// ///
/// - Also makes sure other users receive the update using presence EDUs /// - Also makes sure other users receive the update using presence EDUs
pub async fn set_displayname_route( pub async fn set_displayname_route(
body: Ruma<set_display_name::v3::Request>, body: Ruma<set_display_name::v3::Request>,
) -> Result<set_display_name::v3::Response> { ) -> Result<set_display_name::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().users.set_displayname(sender_user, body.displayname.clone()).await?;
.users
.set_displayname(sender_user, body.displayname.clone())
.await?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_rooms_joined: Vec<_> = services() let all_rooms_joined: Vec<_> = services()
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.map(|room_id| { .map(|room_id| {
Ok::<_, Error>(( Ok::<_, Error>((
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
displayname: body.displayname.clone(), displayname: body.displayname.clone(),
..serde_json::from_str( ..serde_json::from_str(
services() services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get( .room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
&room_id, .ok_or_else(|| {
&StateEventType::RoomMember, Error::bad_database("Tried to send displayname update for user not in the room.")
sender_user.as_str(), })?
)? .content
.ok_or_else(|| { .get(),
Error::bad_database( )
"Tried to send displayname update for user not in the \ .map_err(|_| Error::bad_database("Database contains invalid PDU."))?
room.", })
) .expect("event is valid, we just created it"),
})? unsigned: None,
.content state_key: Some(sender_user.to_string()),
.get(), redacts: None,
) },
.map_err(|_| Error::bad_database("Database contains invalid PDU."))? room_id,
}) ))
.expect("event is valid, we just created it"), })
unsigned: None, .filter_map(std::result::Result::ok)
state_key: Some(sender_user.to_string()), .collect();
redacts: None,
},
room_id,
))
})
.filter_map(std::result::Result::ok)
.collect();
for (pdu_builder, room_id) in all_rooms_joined { for (pdu_builder, room_id) in all_rooms_joined {
let mutex_state = Arc::clone( let mutex_state =
services() Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
.globals let state_lock = mutex_state.lock().await;
.roomid_mutex_state
.write()
.unwrap()
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let _ = services() let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
.rooms }
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.await;
}
if services().globals.allow_local_presence() { if services().globals.allow_local_presence() {
// Presence update // Presence update
services() services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
.rooms }
.edus
.presence
.ping_presence(sender_user, PresenceState::Online)?;
}
Ok(set_display_name::v3::Response {}) Ok(set_display_name::v3::Response {})
} }
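
// Sketch of the "overwrite one field, keep the rest" pattern used when
// set_displayname_route rebuilds the m.room.member content for every joined
// room: the existing state event's content is deserialised and only the
// displayname is replaced via struct update syntax. The trimmed-down content
// type below is an assumption (the real code uses ruma's RoomMemberEventContent)
// and serde's derive feature is assumed to be available.
use serde::Deserialize;

#[derive(Deserialize)]
struct MemberContent {
	membership: String,
	displayname: Option<String>,
	avatar_url: Option<String>,
}

fn with_new_displayname(existing: &str, displayname: Option<String>) -> serde_json::Result<MemberContent> {
	Ok(MemberContent {
		displayname,
		..serde_json::from_str(existing)?
	})
}
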
/// # `GET /_matrix/client/v3/profile/{userId}/displayname` /// # `GET /_matrix/client/v3/profile/{userId}/displayname`
@@ -113,55 +86,44 @@ pub async fn set_displayname_route(
/// - If user is on another server and we do not have a local copy already /// - If user is on another server and we do not have a local copy already
/// fetch displayname over federation /// fetch displayname over federation
pub async fn get_displayname_route( pub async fn get_displayname_route(
body: Ruma<get_display_name::v3::Request>, body: Ruma<get_display_name::v3::Request>,
) -> Result<get_display_name::v3::Response> { ) -> Result<get_display_name::v3::Response> {
if body.user_id.server_name() != services().globals.server_name() { if body.user_id.server_name() != services().globals.server_name() {
// Create and update our local copy of the user // Create and update our local copy of the user
if let Ok(response) = services() if let Ok(response) = services()
.sending .sending
.send_federation_request( .send_federation_request(
body.user_id.server_name(), body.user_id.server_name(),
federation::query::get_profile_information::v1::Request { federation::query::get_profile_information::v1::Request {
user_id: body.user_id.clone(), user_id: body.user_id.clone(),
field: None, // we want the full user's profile to update locally too field: None, // we want the full user's profile to update locally too
}, },
) )
.await .await
{ {
if !services().users.exists(&body.user_id)? { if !services().users.exists(&body.user_id)? {
services().users.create(&body.user_id, None)?; services().users.create(&body.user_id, None)?;
} }
services() services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
.users services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
.set_displayname(&body.user_id, response.displayname.clone()) services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
.await?;
services()
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
services()
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
return Ok(get_display_name::v3::Response { return Ok(get_display_name::v3::Response {
displayname: response.displayname, displayname: response.displayname,
}); });
} }
} }
if !services().users.exists(&body.user_id)? { if !services().users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist and we couldn't fetch it over federation // Return 404 if this user doesn't exist and we couldn't fetch it over
return Err(Error::BadRequest( // federation
ErrorKind::NotFound, return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
"Profile was not found.", }
));
}
Ok(get_display_name::v3::Response { Ok(get_display_name::v3::Response {
displayname: services().users.displayname(&body.user_id)?, displayname: services().users.displayname(&body.user_id)?,
}) })
} }
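
// Sketch of the decision flow in get_displayname_route (the avatar and profile
// routes follow the same shape): remote users are looked up over federation and
// cached locally, local users that do not exist yield a 404, otherwise the
// locally stored displayname is returned. The closures stand in for the
// federation client and the user database; none of them are the real service
// calls.
fn lookup_displayname(
	is_remote: bool,
	fetch_over_federation: impl Fn() -> Option<Option<String>>, // successful response -> Some(displayname field)
	exists_locally: impl Fn() -> bool,
	local_displayname: impl Fn() -> Option<String>,
) -> Result<Option<String>, &'static str> {
	if is_remote {
		// Remote user: prefer the freshly fetched (and locally cached) profile.
		if let Some(displayname) = fetch_over_federation() {
			return Ok(displayname);
		}
	}
	if !exists_locally() {
		// Neither stored locally nor fetchable over federation.
		return Err("Profile was not found.");
	}
	Ok(local_displayname())
}
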
/// # `PUT /_matrix/client/r0/profile/{userId}/avatar_url` /// # `PUT /_matrix/client/r0/profile/{userId}/avatar_url`
@@ -169,93 +131,63 @@ pub async fn get_displayname_route(
/// Updates the avatar_url and blurhash. /// Updates the avatar_url and blurhash.
/// ///
/// - Also makes sure other users receive the update using presence EDUs /// - Also makes sure other users receive the update using presence EDUs
pub async fn set_avatar_url_route( pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Result<set_avatar_url::v3::Response> {
body: Ruma<set_avatar_url::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<set_avatar_url::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?;
.users
.set_avatar_url(sender_user, body.avatar_url.clone())
.await?;
services() services().users.set_blurhash(sender_user, body.blurhash.clone()).await?;
.users
.set_blurhash(sender_user, body.blurhash.clone())
.await?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_joined_rooms: Vec<_> = services() let all_joined_rooms: Vec<_> = services()
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.map(|room_id| { .map(|room_id| {
Ok::<_, Error>(( Ok::<_, Error>((
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
avatar_url: body.avatar_url.clone(), avatar_url: body.avatar_url.clone(),
..serde_json::from_str( ..serde_json::from_str(
services() services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get( .room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
&room_id, .ok_or_else(|| {
&StateEventType::RoomMember, Error::bad_database("Tried to send displayname update for user not in the room.")
sender_user.as_str(), })?
)? .content
.ok_or_else(|| { .get(),
Error::bad_database( )
"Tried to send displayname update for user not in the \ .map_err(|_| Error::bad_database("Database contains invalid PDU."))?
room.", })
) .expect("event is valid, we just created it"),
})? unsigned: None,
.content state_key: Some(sender_user.to_string()),
.get(), redacts: None,
) },
.map_err(|_| Error::bad_database("Database contains invalid PDU."))? room_id,
}) ))
.expect("event is valid, we just created it"), })
unsigned: None, .filter_map(std::result::Result::ok)
state_key: Some(sender_user.to_string()), .collect();
redacts: None,
},
room_id,
))
})
.filter_map(std::result::Result::ok)
.collect();
for (pdu_builder, room_id) in all_joined_rooms { for (pdu_builder, room_id) in all_joined_rooms {
let mutex_state = Arc::clone( let mutex_state =
services() Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
.globals let state_lock = mutex_state.lock().await;
.roomid_mutex_state
.write()
.unwrap()
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let _ = services() let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
.rooms }
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.await;
}
if services().globals.allow_local_presence() { if services().globals.allow_local_presence() {
// Presence update // Presence update
services() services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
.rooms }
.edus
.presence
.ping_presence(sender_user, PresenceState::Online)?;
}
Ok(set_avatar_url::v3::Response {}) Ok(set_avatar_url::v3::Response {})
} }
/// # `GET /_matrix/client/v3/profile/{userId}/avatar_url` /// # `GET /_matrix/client/v3/profile/{userId}/avatar_url`
@@ -264,58 +196,45 @@ pub async fn set_avatar_url_route(
/// ///
/// - If user is on another server and we do not have a local copy already /// - If user is on another server and we do not have a local copy already
/// fetch avatar_url and blurhash over federation /// fetch avatar_url and blurhash over federation
pub async fn get_avatar_url_route( pub async fn get_avatar_url_route(body: Ruma<get_avatar_url::v3::Request>) -> Result<get_avatar_url::v3::Response> {
body: Ruma<get_avatar_url::v3::Request>, if body.user_id.server_name() != services().globals.server_name() {
) -> Result<get_avatar_url::v3::Response> { // Create and update our local copy of the user
if body.user_id.server_name() != services().globals.server_name() { if let Ok(response) = services()
// Create and update our local copy of the user .sending
if let Ok(response) = services() .send_federation_request(
.sending body.user_id.server_name(),
.send_federation_request( federation::query::get_profile_information::v1::Request {
body.user_id.server_name(), user_id: body.user_id.clone(),
federation::query::get_profile_information::v1::Request { field: None, // we want the full user's profile to update locally as well
user_id: body.user_id.clone(), },
field: None, // we want the full user's profile to update locally as well )
}, .await
) {
.await if !services().users.exists(&body.user_id)? {
{ services().users.create(&body.user_id, None)?;
if !services().users.exists(&body.user_id)? { }
services().users.create(&body.user_id, None)?;
}
services() services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
.users services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
.set_displayname(&body.user_id, response.displayname.clone()) services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
.await?;
services()
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
services()
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
return Ok(get_avatar_url::v3::Response { return Ok(get_avatar_url::v3::Response {
avatar_url: response.avatar_url, avatar_url: response.avatar_url,
blurhash: response.blurhash, blurhash: response.blurhash,
}); });
} }
} }
if !services().users.exists(&body.user_id)? { if !services().users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist and we couldn't fetch it over federation // Return 404 if this user doesn't exist and we couldn't fetch it over
return Err(Error::BadRequest( // federation
ErrorKind::NotFound, return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
"Profile was not found.", }
));
}
Ok(get_avatar_url::v3::Response { Ok(get_avatar_url::v3::Response {
avatar_url: services().users.avatar_url(&body.user_id)?, avatar_url: services().users.avatar_url(&body.user_id)?,
blurhash: services().users.blurhash(&body.user_id)?, blurhash: services().users.blurhash(&body.user_id)?,
}) })
} }
/// # `GET /_matrix/client/v3/profile/{userId}` /// # `GET /_matrix/client/v3/profile/{userId}`
@@ -324,58 +243,45 @@ pub async fn get_avatar_url_route(
/// ///
/// - If user is on another server and we do not have a local copy already, /// - If user is on another server and we do not have a local copy already,
/// fetch profile over federation. /// fetch profile over federation.
pub async fn get_profile_route( pub async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<get_profile::v3::Response> {
body: Ruma<get_profile::v3::Request>, if body.user_id.server_name() != services().globals.server_name() {
) -> Result<get_profile::v3::Response> { // Create and update our local copy of the user
if body.user_id.server_name() != services().globals.server_name() { if let Ok(response) = services()
// Create and update our local copy of the user .sending
if let Ok(response) = services() .send_federation_request(
.sending body.user_id.server_name(),
.send_federation_request( federation::query::get_profile_information::v1::Request {
body.user_id.server_name(), user_id: body.user_id.clone(),
federation::query::get_profile_information::v1::Request { field: None,
user_id: body.user_id.clone(), },
field: None, )
}, .await
) {
.await if !services().users.exists(&body.user_id)? {
{ services().users.create(&body.user_id, None)?;
if !services().users.exists(&body.user_id)? { }
services().users.create(&body.user_id, None)?;
}
services() services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
.users services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
.set_displayname(&body.user_id, response.displayname.clone()) services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
.await?;
services()
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
services()
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
return Ok(get_profile::v3::Response { return Ok(get_profile::v3::Response {
displayname: response.displayname, displayname: response.displayname,
avatar_url: response.avatar_url, avatar_url: response.avatar_url,
blurhash: response.blurhash, blurhash: response.blurhash,
}); });
} }
} }
if !services().users.exists(&body.user_id)? { if !services().users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist and we couldn't fetch it over federation // Return 404 if this user doesn't exist and we couldn't fetch it over
return Err(Error::BadRequest( // federation
ErrorKind::NotFound, return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
"Profile was not found.", }
));
}
Ok(get_profile::v3::Response { Ok(get_profile::v3::Response {
avatar_url: services().users.avatar_url(&body.user_id)?, avatar_url: services().users.avatar_url(&body.user_id)?,
blurhash: services().users.blurhash(&body.user_id)?, blurhash: services().users.blurhash(&body.user_id)?,
displayname: services().users.displayname(&body.user_id)?, displayname: services().users.displayname(&body.user_id)?,
}) })
} }


@@ -1,417 +1,320 @@
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
push::{ push::{
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all,
get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions, set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope,
set_pushrule_enabled, RuleScope, },
}, },
}, events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, push::{InsertPushRuleError, RemovePushRuleError},
push::{InsertPushRuleError, RemovePushRuleError},
}; };
use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/pushrules` /// # `GET /_matrix/client/r0/pushrules`
/// ///
/// Retrieves the push rules event for this user. /// Retrieves the push rules event for this user.
pub async fn get_pushrules_all_route( pub async fn get_pushrules_all_route(
body: Ruma<get_pushrules_all::v3::Request>, body: Ruma<get_pushrules_all::v3::Request>,
) -> Result<get_pushrules_all::v3::Response> { ) -> Result<get_pushrules_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content; .content;
Ok(get_pushrules_all::v3::Response { Ok(get_pushrules_all::v3::Response {
global: account_data.global, global: account_data.global,
}) })
} }
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
/// ///
/// Retrieves a single specified push rule for this user. /// Retrieves a single specified push rule for this user.
pub async fn get_pushrule_route( pub async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result<get_pushrule::v3::Response> {
body: Ruma<get_pushrule::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content; .content;
let rule = account_data let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into);
.global
.get(body.kind.clone(), &body.rule_id)
.map(Into::into);
if let Some(rule) = rule { if let Some(rule) = rule {
Ok(get_pushrule::v3::Response { rule }) Ok(get_pushrule::v3::Response {
} else { rule,
Err(Error::BadRequest( })
ErrorKind::NotFound, } else {
"Push rule not found.", Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))
)) }
}
} }
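
// Sketch of the account-data lookup shared by the push rule routes above: fetch
// the raw m.push_rules event, deserialise it, and turn a missing event,
// undeserialisable JSON, or missing rule into the corresponding error message.
// serde_json::Value stands in for ruma's typed PushRulesEvent, and the
// content/global/<kind> layout assumed here follows the Matrix push rules wire
// format rather than anything shown in this diff.
use serde_json::Value;

fn load_rule(raw_event: Option<&str>, kind: &str, rule_id: &str) -> Result<Value, &'static str> {
	let raw = raw_event.ok_or("PushRules event not found.")?;
	let event: Value = serde_json::from_str(raw).map_err(|_| "Invalid account data event in db.")?;
	event
		.pointer(&format!("/content/global/{kind}"))
		.and_then(|rules| rules.as_array())
		.and_then(|rules| rules.iter().find(|rule| rule.get("rule_id").and_then(Value::as_str) == Some(rule_id)))
		.cloned()
		.ok_or("Push rule not found.")
}
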
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
/// ///
/// Creates a single specified push rule for this user. /// Creates a single specified push rule for this user.
pub async fn set_pushrule_route( pub async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result<set_pushrule::v3::Response> {
body: Ruma<set_pushrule::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<set_pushrule::v3::Response> { let body = body.body;
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body;
if body.scope != RuleScope::Global { if body.scope != RuleScope::Global {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Scopes other than 'global' are not supported.", "Scopes other than 'global' are not supported.",
)); ));
} }
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if let Err(error) = account_data.content.global.insert( if let Err(error) =
body.rule.clone(), account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref())
body.after.as_deref(), {
body.before.as_deref(), let err = match error {
) { InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
let err = match error { ErrorKind::InvalidParam,
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( "Rule IDs starting with a dot are reserved for server-default rules.",
ErrorKind::InvalidParam, ),
"Rule IDs starting with a dot are reserved for server-default rules.", InsertPushRuleError::InvalidRuleId => {
), Error::BadRequest(ErrorKind::InvalidParam, "Rule ID containing invalid characters.")
InsertPushRuleError::InvalidRuleId => Error::BadRequest( },
ErrorKind::InvalidParam, InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
"Rule ID containing invalid characters.", ErrorKind::InvalidParam,
), "Can't place a push rule relatively to a server-default rule.",
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest( ),
ErrorKind::InvalidParam, InsertPushRuleError::UnknownRuleId => {
"Can't place a push rule relatively to a server-default rule.", Error::BadRequest(ErrorKind::NotFound, "The before or after rule could not be found.")
), },
InsertPushRuleError::UnknownRuleId => Error::BadRequest( InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
ErrorKind::NotFound, ErrorKind::InvalidParam,
"The before or after rule could not be found.", "The before rule has a higher priority than the after rule.",
), ),
InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest( _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
ErrorKind::InvalidParam, };
"The before rule has a higher priority than the after rule.",
),
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
};
return Err(err); return Err(err);
} }
services().account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )?;
Ok(set_pushrule::v3::Response {}) Ok(set_pushrule::v3::Response {})
} }
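
// Sketch of the error translation in set_pushrule_route: each
// InsertPushRuleError variant matched above is reduced to the client message
// returned to the caller. The enum below is a local stand-in mirroring the
// variants handled in the route, not ruma's actual type, and plain strings
// replace the Error/ErrorKind pair.
enum InsertError {
	ServerDefaultRuleId,
	InvalidRuleId,
	RelativeToServerDefaultRule,
	UnknownRuleId,
	BeforeHigherThanAfter,
	Other,
}

fn insert_error_message(error: InsertError) -> &'static str {
	match error {
		InsertError::ServerDefaultRuleId => "Rule IDs starting with a dot are reserved for server-default rules.",
		InsertError::InvalidRuleId => "Rule ID containing invalid characters.",
		InsertError::RelativeToServerDefaultRule => "Can't place a push rule relatively to a server-default rule.",
		InsertError::UnknownRuleId => "The before or after rule could not be found.",
		InsertError::BeforeHigherThanAfter => "The before rule has a higher priority than the after rule.",
		InsertError::Other => "Invalid data.",
	}
}
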
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
/// ///
/// Gets the actions of a single specified push rule for this user. /// Gets the actions of a single specified push rule for this user.
pub async fn get_pushrule_actions_route( pub async fn get_pushrule_actions_route(
body: Ruma<get_pushrule_actions::v3::Request>, body: Ruma<get_pushrule_actions::v3::Request>,
) -> Result<get_pushrule_actions::v3::Response> { ) -> Result<get_pushrule_actions::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if body.scope != RuleScope::Global { if body.scope != RuleScope::Global {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Scopes other than 'global' are not supported.", "Scopes other than 'global' are not supported.",
)); ));
} }
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content; .content;
let global = account_data.global; let global = account_data.global;
let actions = global let actions = global
.get(body.kind.clone(), &body.rule_id) .get(body.kind.clone(), &body.rule_id)
.map(|rule| rule.actions().to_owned()) .map(|rule| rule.actions().to_owned())
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
ErrorKind::NotFound,
"Push rule not found.",
))?;
Ok(get_pushrule_actions::v3::Response { actions }) Ok(get_pushrule_actions::v3::Response {
actions,
})
} }
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
/// ///
/// Sets the actions of a single specified push rule for this user. /// Sets the actions of a single specified push rule for this user.
pub async fn set_pushrule_actions_route( pub async fn set_pushrule_actions_route(
body: Ruma<set_pushrule_actions::v3::Request>, body: Ruma<set_pushrule_actions::v3::Request>,
) -> Result<set_pushrule_actions::v3::Response> { ) -> Result<set_pushrule_actions::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if body.scope != RuleScope::Global { if body.scope != RuleScope::Global {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Scopes other than 'global' are not supported.", "Scopes other than 'global' are not supported.",
)); ));
} }
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if account_data if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() {
.content return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
.global }
.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone())
.is_err()
{
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Push rule not found.",
));
}
services().account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )?;
Ok(set_pushrule_actions::v3::Response {}) Ok(set_pushrule_actions::v3::Response {})
} }
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
/// ///
/// Gets the enabled status of a single specified push rule for this user. /// Gets the enabled status of a single specified push rule for this user.
pub async fn get_pushrule_enabled_route( pub async fn get_pushrule_enabled_route(
body: Ruma<get_pushrule_enabled::v3::Request>, body: Ruma<get_pushrule_enabled::v3::Request>,
) -> Result<get_pushrule_enabled::v3::Response> { ) -> Result<get_pushrule_enabled::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if body.scope != RuleScope::Global { if body.scope != RuleScope::Global {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Scopes other than 'global' are not supported.", "Scopes other than 'global' are not supported.",
)); ));
} }
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
let global = account_data.content.global; let global = account_data.content.global;
let enabled = global let enabled = global
.get(body.kind.clone(), &body.rule_id) .get(body.kind.clone(), &body.rule_id)
.map(ruma::push::AnyPushRuleRef::enabled) .map(ruma::push::AnyPushRuleRef::enabled)
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
ErrorKind::NotFound,
"Push rule not found.",
))?;
Ok(get_pushrule_enabled::v3::Response { enabled }) Ok(get_pushrule_enabled::v3::Response {
enabled,
})
} }
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
/// ///
/// Sets the enabled status of a single specified push rule for this user. /// Sets the enabled status of a single specified push rule for this user.
pub async fn set_pushrule_enabled_route( pub async fn set_pushrule_enabled_route(
body: Ruma<set_pushrule_enabled::v3::Request>, body: Ruma<set_pushrule_enabled::v3::Request>,
) -> Result<set_pushrule_enabled::v3::Response> { ) -> Result<set_pushrule_enabled::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if body.scope != RuleScope::Global { if body.scope != RuleScope::Global {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Scopes other than 'global' are not supported.", "Scopes other than 'global' are not supported.",
)); ));
} }
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if account_data if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() {
.content return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
.global }
.set_enabled(body.kind.clone(), &body.rule_id, body.enabled)
.is_err()
{
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Push rule not found.",
));
}
services().account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )?;
Ok(set_pushrule_enabled::v3::Response {}) Ok(set_pushrule_enabled::v3::Response {})
} }
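// Illustrative sketch, not part of this commit: how a client might call the
// endpoint above. Assumes the `reqwest` crate (with its `json` feature) and
// `serde_json`; the homeserver URL, access token and rule id are placeholders,
// and the path follows the doc comment above. Drive the future with any async
// runtime (e.g. tokio).
async fn example_set_pushrule_enabled() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	// Disable the server-default rule that silences m.notice messages.
	let url = format!("{homeserver}/_matrix/client/r0/pushrules/global/override/.m.rule.suppress_notices/enabled");
	let status = reqwest::Client::new()
		.put(url)
		.bearer_auth(access_token)
		.json(&serde_json::json!({ "enabled": false }))
		.send()
		.await?
		.status();
	println!("set_pushrule_enabled -> {status}");
	Ok(())
}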
/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
/// ///
/// Deletes a single specified push rule for this user. /// Deletes a single specified push rule for this user.
pub async fn delete_pushrule_route( pub async fn delete_pushrule_route(body: Ruma<delete_pushrule::v3::Request>) -> Result<delete_pushrule::v3::Response> {
body: Ruma<delete_pushrule::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<delete_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if body.scope != RuleScope::Global { if body.scope != RuleScope::Global {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Scopes other than 'global' are not supported.", "Scopes other than 'global' are not supported.",
)); ));
} }
let event = services() let event = services()
.account_data .account_data
.get( .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
None, .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
)?
.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"PushRules event not found.",
))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if let Err(error) = account_data if let Err(error) = account_data.content.global.remove(body.kind.clone(), &body.rule_id) {
.content let err = match error {
.global RemovePushRuleError::ServerDefault => {
.remove(body.kind.clone(), &body.rule_id) Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.")
{ },
let err = match error { RemovePushRuleError::NotFound => Error::BadRequest(ErrorKind::NotFound, "Push rule not found."),
RemovePushRuleError::ServerDefault => Error::BadRequest( _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
ErrorKind::InvalidParam, };
"Cannot delete a server-default pushrule.",
),
RemovePushRuleError::NotFound => {
Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")
}
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
};
return Err(err); return Err(err);
} }
services().account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )?;
Ok(delete_pushrule::v3::Response {}) Ok(delete_pushrule::v3::Response {})
} }
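// Illustrative sketch, not part of this commit: deleting a user-defined push
// rule over the endpoint above (server-default rules are rejected, as the
// route's error handling shows). Assumes `reqwest`; the rule id
// "com.example.mute_pizza", URL and token are placeholders.
async fn example_delete_pushrule() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let url = format!("{homeserver}/_matrix/client/r0/pushrules/global/content/com.example.mute_pizza");
	let status = reqwest::Client::new()
		.delete(url)
		.bearer_auth(access_token)
		.send()
		.await?
		.status();
	println!("delete_pushrule -> {status}");
	Ok(())
}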
/// # `GET /_matrix/client/r0/pushers` /// # `GET /_matrix/client/r0/pushers`
/// ///
/// Gets all currently active pushers for the sender user. /// Gets all currently active pushers for the sender user.
pub async fn get_pushers_route( pub async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<get_pushers::v3::Response> {
body: Ruma<get_pushers::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_pushers::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_pushers::v3::Response { Ok(get_pushers::v3::Response {
pushers: services().pusher.get_pushers(sender_user)?, pushers: services().pusher.get_pushers(sender_user)?,
}) })
} }
/// # `POST /_matrix/client/r0/pushers/set` /// # `POST /_matrix/client/r0/pushers/set`
@@ -419,14 +322,10 @@ pub async fn get_pushers_route(
/// Adds a pusher for the sender user. /// Adds a pusher for the sender user.
/// ///
/// - TODO: Handle `append` /// - TODO: Handle `append`
pub async fn set_pushers_route( pub async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> {
body: Ruma<set_pusher::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services().pusher.set_pusher(sender_user, body.action.clone())?;
.pusher
.set_pusher(sender_user, body.action.clone())?;
Ok(set_pusher::v3::Response::default()) Ok(set_pusher::v3::Response::default())
} }
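// Illustrative sketch, not part of this commit: registering an HTTP pusher via
// the endpoint above. The app id, push gateway URL and pushkey are placeholders
// that only mirror the shape the spec describes; assumes `reqwest` and
// `serde_json`.
async fn example_set_pusher() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let body = serde_json::json!({
		"kind": "http",
		"app_id": "org.example.client",
		"app_display_name": "Example Client",
		"device_display_name": "Alice's phone",
		"pushkey": "example-push-key", // token issued by the push provider
		"lang": "en",
		"data": {
			"url": "https://push.example.org/_matrix/push/v1/notify",
			"format": "event_id_only"
		}
	});
	let status = reqwest::Client::new()
		.post(format!("{homeserver}/_matrix/client/r0/pushers/set"))
		.bearer_auth(access_token)
		.json(&body)
		.send()
		.await?
		.status();
	println!("set_pusher -> {status}");
	Ok(())
}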
@@ -1,182 +1,161 @@
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
events::{
receipt::{ReceiptThread, ReceiptType},
RoomAccountDataEventType,
},
MilliSecondsSinceUnixEpoch,
};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
events::{
receipt::{ReceiptThread, ReceiptType},
RoomAccountDataEventType,
},
MilliSecondsSinceUnixEpoch,
};
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
/// ///
/// Sets different types of read markers. /// Sets different types of read markers.
/// ///
/// - Updates fully-read account data event to `fully_read` /// - Updates fully-read account data event to `fully_read`
/// - If `read_receipt` is set: Update private marker and public read receipt EDU /// - If `read_receipt` is set: Update private marker and public read receipt
pub async fn set_read_marker_route( /// EDU
body: Ruma<set_read_marker::v3::Request>, pub async fn set_read_marker_route(body: Ruma<set_read_marker::v3::Request>) -> Result<set_read_marker::v3::Response> {
) -> Result<set_read_marker::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if let Some(fully_read) = &body.fully_read { if let Some(fully_read) = &body.fully_read {
let fully_read_event = ruma::events::fully_read::FullyReadEvent { let fully_read_event = ruma::events::fully_read::FullyReadEvent {
content: ruma::events::fully_read::FullyReadEventContent { content: ruma::events::fully_read::FullyReadEventContent {
event_id: fully_read.clone(), event_id: fully_read.clone(),
}, },
}; };
services().account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"), &serde_json::to_value(fully_read_event).expect("to json value always works"),
)?; )?;
} }
if body.private_read_receipt.is_some() || body.read_receipt.is_some() { if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
services() services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
.rooms }
.user
.reset_notification_counts(sender_user, &body.room_id)?;
}
if let Some(event) = &body.private_read_receipt { if let Some(event) = &body.private_read_receipt {
let count = services() let count = services()
.rooms .rooms
.timeline .timeline
.get_pdu_count(event)? .get_pdu_count(event)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
ErrorKind::InvalidParam, let count = match count {
"Event does not exist.", PduCount::Backfilled(_) => {
))?; return Err(Error::BadRequest(
let count = match count { ErrorKind::InvalidParam,
PduCount::Backfilled(_) => { "Read receipt is in backfilled timeline",
return Err(Error::BadRequest( ))
ErrorKind::InvalidParam, },
"Read receipt is in backfilled timeline", PduCount::Normal(c) => c,
)) };
} services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
PduCount::Normal(c) => c, }
};
services()
.rooms
.edus
.read_receipt
.private_read_set(&body.room_id, sender_user, count)?;
}
if let Some(event) = &body.read_receipt { if let Some(event) = &body.read_receipt {
let mut user_receipts = BTreeMap::new(); let mut user_receipts = BTreeMap::new();
user_receipts.insert( user_receipts.insert(
sender_user.clone(), sender_user.clone(),
ruma::events::receipt::Receipt { ruma::events::receipt::Receipt {
ts: Some(MilliSecondsSinceUnixEpoch::now()), ts: Some(MilliSecondsSinceUnixEpoch::now()),
thread: ReceiptThread::Unthreaded, thread: ReceiptThread::Unthreaded,
}, },
); );
let mut receipts = BTreeMap::new(); let mut receipts = BTreeMap::new();
receipts.insert(ReceiptType::Read, user_receipts); receipts.insert(ReceiptType::Read, user_receipts);
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(event.to_owned(), receipts); receipt_content.insert(event.to_owned(), receipts);
services().rooms.edus.read_receipt.readreceipt_update( services().rooms.edus.read_receipt.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
ruma::events::receipt::ReceiptEvent { ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
)?; )?;
} }
Ok(set_read_marker::v3::Response {}) Ok(set_read_marker::v3::Response {})
} }
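// Illustrative sketch, not part of this commit: the request body the route
// above consumes, marking an event as fully read and sending both a public and
// a private read receipt for it. Room and event ids are placeholders (real
// clients should percent-encode them in the path); assumes `reqwest` and
// `serde_json`.
async fn example_set_read_marker() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let room_id = "!abcdefgh:example.org"; // hypothetical
	let event_id = "$someevent:example.org"; // hypothetical
	let body = serde_json::json!({
		"m.fully_read": event_id,
		"m.read": event_id,
		"m.read.private": event_id
	});
	let status = reqwest::Client::new()
		.post(format!("{homeserver}/_matrix/client/r0/rooms/{room_id}/read_markers"))
		.bearer_auth(access_token)
		.json(&body)
		.send()
		.await?
		.status();
	println!("set_read_marker -> {status}");
	Ok(())
}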
/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}` /// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}`
/// ///
/// Sets private read marker and public read receipt EDU. /// Sets private read marker and public read receipt EDU.
pub async fn create_receipt_route( pub async fn create_receipt_route(body: Ruma<create_receipt::v3::Request>) -> Result<create_receipt::v3::Response> {
body: Ruma<create_receipt::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<create_receipt::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if matches!( if matches!(
&body.receipt_type, &body.receipt_type,
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
) { ) {
services() services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
.rooms }
.user
.reset_notification_counts(sender_user, &body.room_id)?;
}
match body.receipt_type { match body.receipt_type {
create_receipt::v3::ReceiptType::FullyRead => { create_receipt::v3::ReceiptType::FullyRead => {
let fully_read_event = ruma::events::fully_read::FullyReadEvent { let fully_read_event = ruma::events::fully_read::FullyReadEvent {
content: ruma::events::fully_read::FullyReadEventContent { content: ruma::events::fully_read::FullyReadEventContent {
event_id: body.event_id.clone(), event_id: body.event_id.clone(),
}, },
}; };
services().account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"), &serde_json::to_value(fully_read_event).expect("to json value always works"),
)?; )?;
} },
create_receipt::v3::ReceiptType::Read => { create_receipt::v3::ReceiptType::Read => {
let mut user_receipts = BTreeMap::new(); let mut user_receipts = BTreeMap::new();
user_receipts.insert( user_receipts.insert(
sender_user.clone(), sender_user.clone(),
ruma::events::receipt::Receipt { ruma::events::receipt::Receipt {
ts: Some(MilliSecondsSinceUnixEpoch::now()), ts: Some(MilliSecondsSinceUnixEpoch::now()),
thread: ReceiptThread::Unthreaded, thread: ReceiptThread::Unthreaded,
}, },
); );
let mut receipts = BTreeMap::new(); let mut receipts = BTreeMap::new();
receipts.insert(ReceiptType::Read, user_receipts); receipts.insert(ReceiptType::Read, user_receipts);
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(body.event_id.clone(), receipts); receipt_content.insert(body.event_id.clone(), receipts);
services().rooms.edus.read_receipt.readreceipt_update( services().rooms.edus.read_receipt.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
ruma::events::receipt::ReceiptEvent { ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
)?; )?;
} },
create_receipt::v3::ReceiptType::ReadPrivate => { create_receipt::v3::ReceiptType::ReadPrivate => {
let count = services() let count = services()
.rooms .rooms
.timeline .timeline
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
ErrorKind::InvalidParam, let count = match count {
"Event does not exist.", PduCount::Backfilled(_) => {
))?; return Err(Error::BadRequest(
let count = match count { ErrorKind::InvalidParam,
PduCount::Backfilled(_) => { "Read receipt is in backfilled timeline",
return Err(Error::BadRequest( ))
ErrorKind::InvalidParam, },
"Read receipt is in backfilled timeline", PduCount::Normal(c) => c,
)) };
} services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
PduCount::Normal(c) => c, },
}; _ => return Err(Error::bad_database("Unsupported receipt type")),
services().rooms.edus.read_receipt.private_read_set( }
&body.room_id,
sender_user,
count,
)?;
}
_ => return Err(Error::bad_database("Unsupported receipt type")),
}
Ok(create_receipt::v3::Response {}) Ok(create_receipt::v3::Response {})
} }
@@ -1,58 +1,51 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
use ruma::{ use ruma::{
api::client::redact::redact_event, api::client::redact::redact_event,
events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
/// ///
/// Tries to send a redaction event into the room. /// Tries to send a redaction event into the room.
/// ///
/// - TODO: Handle txn id /// - TODO: Handle txn id
pub async fn redact_event_route( pub async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result<redact_event::v3::Response> {
body: Ruma<redact_event::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<redact_event::v3::Response> { let body = body.body;
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body;
let mutex_state = Arc::clone( let mutex_state =
services() Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
.globals let state_lock = mutex_state.lock().await;
.roomid_mutex_state
.write()
.unwrap()
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let event_id = services() let event_id = services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomRedaction, event_type: TimelineEventType::RoomRedaction,
content: to_raw_value(&RoomRedactionEventContent { content: to_raw_value(&RoomRedactionEventContent {
redacts: Some(body.event_id.clone()), redacts: Some(body.event_id.clone()),
reason: body.reason.clone(), reason: body.reason.clone(),
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: None, state_key: None,
redacts: Some(body.event_id.into()), redacts: Some(body.event_id.into()),
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&state_lock, &state_lock,
) )
.await?; .await?;
drop(state_lock); drop(state_lock);
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(redact_event::v3::Response { event_id }) Ok(redact_event::v3::Response {
event_id,
})
} }
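// Illustrative sketch, not part of this commit: redacting an event through the
// endpoint above. The transaction id is client-generated and only needs to be
// unique per access token (the route's TODO notes it is not yet deduplicated);
// ids and reason are placeholders. Assumes `reqwest` and `serde_json`.
async fn example_redact_event() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let room_id = "!abcdefgh:example.org"; // hypothetical
	let event_id = "$someevent:example.org"; // hypothetical
	let txn_id = "m1700000000.1"; // hypothetical client-generated id
	let url = format!("{homeserver}/_matrix/client/r0/rooms/{room_id}/redact/{event_id}/{txn_id}");
	let resp = reqwest::Client::new()
		.put(url)
		.bearer_auth(access_token)
		.json(&serde_json::json!({ "reason": "spam" }))
		.send()
		.await?;
	// The response carries the event id of the new m.room.redaction event.
	println!("redaction response: {}", resp.text().await?);
	Ok(())
}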
@@ -1,146 +1,113 @@
use ruma::api::client::relations::{ use ruma::api::client::relations::{
get_relating_events, get_relating_events_with_rel_type, get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
get_relating_events_with_rel_type_and_event_type,
}; };
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma}; use crate::{service::rooms::timeline::PduCount, services, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
pub async fn get_relating_events_with_rel_type_and_event_type_route( pub async fn get_relating_events_with_rel_type_and_event_type_route(
body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() { let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?, Some(from) => PduCount::try_from_string(&from)?,
None => match ruma::api::Direction::Backward { None => match ruma::api::Direction::Backward {
// TODO: fix ruma so `body.dir` exists // TODO: fix ruma so `body.dir` exists
ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(), ruma::api::Direction::Backward => PduCount::max(),
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services() let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
.rooms sender_user,
.pdu_metadata &body.room_id,
.paginate_relations_with_filter( &body.event_id,
sender_user, Some(body.event_type.clone()),
&body.room_id, Some(body.rel_type.clone()),
&body.event_id, from,
Some(body.event_type.clone()), to,
Some(body.rel_type.clone()), limit,
from, )?;
to,
limit,
)?;
Ok( Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk,
chunk: res.chunk, next_batch: res.next_batch,
next_batch: res.next_batch, prev_batch: res.prev_batch,
prev_batch: res.prev_batch, })
},
)
} }
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
pub async fn get_relating_events_with_rel_type_route( pub async fn get_relating_events_with_rel_type_route(
body: Ruma<get_relating_events_with_rel_type::v1::Request>, body: Ruma<get_relating_events_with_rel_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() { let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?, Some(from) => PduCount::try_from_string(&from)?,
None => match ruma::api::Direction::Backward { None => match ruma::api::Direction::Backward {
// TODO: fix ruma so `body.dir` exists // TODO: fix ruma so `body.dir` exists
ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(), ruma::api::Direction::Backward => PduCount::max(),
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services() let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
.rooms sender_user,
.pdu_metadata &body.room_id,
.paginate_relations_with_filter( &body.event_id,
sender_user, None,
&body.room_id, Some(body.rel_type.clone()),
&body.event_id, from,
None, to,
Some(body.rel_type.clone()), limit,
from, )?;
to,
limit,
)?;
Ok(get_relating_events_with_rel_type::v1::Response { Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
next_batch: res.next_batch, next_batch: res.next_batch,
prev_batch: res.prev_batch, prev_batch: res.prev_batch,
}) })
} }
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}`
pub async fn get_relating_events_route( pub async fn get_relating_events_route(
body: Ruma<get_relating_events::v1::Request>, body: Ruma<get_relating_events::v1::Request>,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() { let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?, Some(from) => PduCount::try_from_string(&from)?,
None => match ruma::api::Direction::Backward { None => match ruma::api::Direction::Backward {
// TODO: fix ruma so `body.dir` exists // TODO: fix ruma so `body.dir` exists
ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(), ruma::api::Direction::Backward => PduCount::max(),
}, },
}; };
let to = body let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
services() services().rooms.pdu_metadata.paginate_relations_with_filter(
.rooms sender_user,
.pdu_metadata &body.room_id,
.paginate_relations_with_filter( &body.event_id,
sender_user, None,
&body.room_id, None,
&body.event_id, from,
None, to,
None, limit,
from, )
to,
limit,
)
} }
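// Illustrative sketch, not part of this commit: paginating over the relations
// of an event with the endpoint above (path taken from the doc comment). Per
// the TODO in the routes, the direction is currently fixed to backwards
// regardless of any `dir` parameter. Ids are placeholders; assumes `reqwest`.
async fn example_get_relations() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let room_id = "!abcdefgh:example.org"; // hypothetical
	let event_id = "$someevent:example.org"; // hypothetical
	let url = format!("{homeserver}/_matrix/client/r0/rooms/{room_id}/relations/{event_id}/m.annotation");
	let chunk: serde_json::Value = reqwest::Client::new()
		.get(url)
		.bearer_auth(access_token)
		.query(&[("limit", "10")]) // capped at 100 by the route above
		.send()
		.await?
		.json()
		.await?;
	println!("relations: {chunk}");
	Ok(())
}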
@@ -1,118 +1,112 @@
use std::time::Duration; use std::time::Duration;
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
use rand::Rng; use rand::Rng;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, room::report_content}, api::client::{error::ErrorKind, room::report_content},
events::room::message, events::room::message,
int, int,
}; };
use tokio::time::sleep; use tokio::time::sleep;
use tracing::{debug, info}; use tracing::{debug, info};
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
/// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}` /// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}`
/// ///
/// Reports an inappropriate event to homeserver admins /// Reports an inappropriate event to homeserver admins
/// pub async fn report_event_route(body: Ruma<report_content::v3::Request>) -> Result<report_content::v3::Response> {
pub async fn report_event_route( // user authentication
body: Ruma<report_content::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<report_content::v3::Response> {
// user authentication
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
info!("Received /report request by user {}", sender_user); info!("Received /report request by user {}", sender_user);
// check if we know about the reported event ID or if it's invalid // check if we know about the reported event ID or if it's invalid
let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? { let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
"Event ID is not known to us or Event ID is invalid", "Event ID is not known to us or Event ID is invalid",
)) ))
} },
}; };
// check if the room ID from the URI matches the PDU's room ID // check if the room ID from the URI matches the PDU's room ID
if body.room_id != pdu.room_id { if body.room_id != pdu.room_id {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
"Event ID does not belong to the reported room", "Event ID does not belong to the reported room",
)); ));
} }
// check if reporting user is in the reporting room // check if reporting user is in the reporting room
if !services() if !services()
.rooms .rooms
.state_cache .state_cache
.room_members(&pdu.room_id) .room_members(&pdu.room_id)
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.any(|user_id| user_id == *sender_user) .any(|user_id| user_id == *sender_user)
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
"You are not in the room you are reporting.", "You are not in the room you are reporting.",
)); ));
} }
// check if score is in valid range // check if score is in valid range
if let Some(true) = body.score.map(|s| s > int!(0) || s < int!(-100)) { if let Some(true) = body.score.map(|s| s > int!(0) || s < int!(-100)) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Invalid score, must be within 0 to -100", "Invalid score, must be within 0 to -100",
)); ));
}; };
// check if report reasoning is less than or equal to 750 characters // check if report reasoning is less than or equal to 750 characters
if let Some(true) = body.reason.clone().map(|s| s.chars().count() >= 750) { if let Some(true) = body.reason.clone().map(|s| s.chars().count() >= 750) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Reason too long, should be 750 characters or fewer", "Reason too long, should be 750 characters or fewer",
)); ));
}; };
// send admin room message that we received the report with an @room ping for urgency // send admin room message that we received the report with an @room ping for
services() // urgency
.admin services().admin.send_message(message::RoomMessageEventContent::text_html(
.send_message(message::RoomMessageEventContent::text_html( format!(
format!( "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \
"@room Report received from: {}\n\n\ Reason: {}",
Event ID: {}\n\ sender_user.to_owned(),
Room ID: {}\n\ pdu.event_id,
Sent By: {}\n\n\ pdu.room_id,
Report Score: {}\n\ pdu.sender.clone(),
Report Reason: {}", body.score.unwrap_or_else(|| ruma::Int::from(0)),
sender_user.to_owned(), body.reason.as_deref().unwrap_or("")
pdu.event_id, ),
pdu.room_id, format!(
pdu.sender.clone(), "<details><summary>@room Report received from: <a href=\"https://matrix.to/#/{0}\">{0}\
body.score.unwrap_or_else(|| ruma::Int::from(0)),
body.reason.as_deref().unwrap_or("")
),
format!(
"<details><summary>@room Report received from: <a href=\"https://matrix.to/#/{0}\">{0}\
</a></summary><ul><li>Event Info<ul><li>Event ID: <code>{1}</code>\ </a></summary><ul><li>Event Info<ul><li>Event ID: <code>{1}</code>\
<a href=\"https://matrix.to/#/{2}/{1}\">🔗</a></li><li>Room ID: <code>{2}</code>\ <a href=\"https://matrix.to/#/{2}/{1}\">🔗</a></li><li>Room ID: <code>{2}</code>\
</li><li>Sent By: <a href=\"https://matrix.to/#/{3}\">{3}</a></li></ul></li><li>\ </li><li>Sent By: <a href=\"https://matrix.to/#/{3}\">{3}</a></li></ul></li><li>\
Report Info<ul><li>Report Score: {4}</li><li>Report Reason: {5}</li></ul></li>\ Report Info<ul><li>Report Score: {4}</li><li>Report Reason: {5}</li></ul></li>\
</ul></details>", </ul></details>",
sender_user.to_owned(), sender_user.to_owned(),
pdu.event_id.clone(), pdu.event_id.clone(),
pdu.room_id.clone(), pdu.room_id.clone(),
pdu.sender.clone(), pdu.sender.clone(),
body.score.unwrap_or_else(|| ruma::Int::from(0)), body.score.unwrap_or_else(|| ruma::Int::from(0)),
HtmlEscape(body.reason.as_deref().unwrap_or("")) HtmlEscape(body.reason.as_deref().unwrap_or(""))
), ),
)); ));
// even though this is kinda security by obscurity, let's still make a small random delay sending a successful response // even though this is kinda security by obscurity, let's still make a small
// per spec suggestion regarding enumerating for potential events existing in our server. // random delay sending a successful response per spec suggestion regarding
let time_to_wait = rand::thread_rng().gen_range(8..21); // enumerating for potential events existing in our server.
debug!( let time_to_wait = rand::thread_rng().gen_range(8..21);
"Got successful /report request, waiting {} seconds before sending successful response.", debug!(
time_to_wait "Got successful /report request, waiting {} seconds before sending successful response.",
); time_to_wait
sleep(Duration::from_secs(time_to_wait)).await; );
sleep(Duration::from_secs(time_to_wait)).await;
Ok(report_content::v3::Response {}) Ok(report_content::v3::Response {})
} }
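// Illustrative sketch, not part of this commit: reporting an event through the
// endpoint above. The score must lie between -100 (most offensive) and 0 and
// the reason must stay under 750 characters, matching the checks in the route,
// which also sleeps 8-20 seconds before answering. Ids are placeholders;
// assumes `reqwest` and `serde_json`.
async fn example_report_event() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let room_id = "!abcdefgh:example.org"; // hypothetical
	let event_id = "$someevent:example.org"; // hypothetical
	let url = format!("{homeserver}/_matrix/client/v3/rooms/{room_id}/report/{event_id}");
	let status = reqwest::Client::new()
		.post(url)
		.bearer_auth(access_token)
		.json(&serde_json::json!({ "score": -100, "reason": "unsolicited advertising" }))
		.send()
		.await?
		.status();
	println!("report -> {status}");
	Ok(())
}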
File diff suppressed because it is too large
@@ -1,138 +1,120 @@
use crate::{services, Error, Result, Ruma}; use std::collections::BTreeMap;
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
search::search_events::{ search::search_events::{
self, self,
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
}, },
}; };
use std::collections::BTreeMap; use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/search` /// # `POST /_matrix/client/r0/search`
/// ///
/// Searches rooms for messages. /// Searches rooms for messages.
/// ///
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility) /// - Only works if the user is currently joined to the room (TODO: Respect
pub async fn search_events_route( /// history visibility)
body: Ruma<search_events::v3::Request>, pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Result<search_events::v3::Response> {
) -> Result<search_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let search_criteria = body.search_categories.room_events.as_ref().unwrap(); let search_criteria = body.search_categories.room_events.as_ref().unwrap();
let filter = &search_criteria.filter; let filter = &search_criteria.filter;
let room_ids = filter.rooms.clone().unwrap_or_else(|| { let room_ids = filter.rooms.clone().unwrap_or_else(|| {
services() services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok).collect()
.rooms });
.state_cache
.rooms_joined(sender_user)
.filter_map(std::result::Result::ok)
.collect()
});
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = filter.limit.map_or(10, u64::from).min(100) as usize; let limit = filter.limit.map_or(10, u64::from).min(100) as usize;
let mut searches = Vec::new(); let mut searches = Vec::new();
for room_id in room_ids { for room_id in room_ids {
if !services() if !services().rooms.state_cache.is_joined(sender_user, &room_id)? {
.rooms return Err(Error::BadRequest(
.state_cache ErrorKind::Forbidden,
.is_joined(sender_user, &room_id)? "You don't have permission to view this room.",
{ ));
return Err(Error::BadRequest( }
ErrorKind::Forbidden,
"You don't have permission to view this room.",
));
}
if let Some(search) = services() if let Some(search) = services().rooms.search.search_pdus(&room_id, &search_criteria.search_term)? {
.rooms searches.push(search.0.peekable());
.search }
.search_pdus(&room_id, &search_criteria.search_term)? }
{
searches.push(search.0.peekable());
}
}
let skip = match body.next_batch.as_ref().map(|s| s.parse()) { let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
Some(Ok(s)) => s, Some(Ok(s)) => s,
Some(Err(_)) => { Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
return Err(Error::BadRequest( None => 0, // Default to the start
ErrorKind::InvalidParam, };
"Invalid next_batch token.",
))
}
None => 0, // Default to the start
};
let mut results = Vec::new(); let mut results = Vec::new();
for _ in 0..skip + limit { for _ in 0..skip + limit {
if let Some(s) = searches if let Some(s) = searches
.iter_mut() .iter_mut()
.map(|s| (s.peek().cloned(), s)) .map(|s| (s.peek().cloned(), s))
.max_by_key(|(peek, _)| peek.clone()) .max_by_key(|(peek, _)| peek.clone())
.and_then(|(_, i)| i.next()) .and_then(|(_, i)| i.next())
{ {
results.push(s); results.push(s);
} }
} }
let results: Vec<_> = results let results: Vec<_> = results
.iter() .iter()
.filter_map(|result| { .filter_map(|result| {
services() services()
.rooms .rooms
.timeline .timeline
.get_pdu_from_id(result) .get_pdu_from_id(result)
.ok()? .ok()?
.filter(|pdu| { .filter(|pdu| {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.map(|pdu| pdu.to_room_event()) .map(|pdu| pdu.to_room_event())
}) })
.map(|result| { .map(|result| {
Ok::<_, Error>(SearchResult { Ok::<_, Error>(SearchResult {
context: EventContextResult { context: EventContextResult {
end: None, end: None,
events_after: Vec::new(), events_after: Vec::new(),
events_before: Vec::new(), events_before: Vec::new(),
profile_info: BTreeMap::new(), profile_info: BTreeMap::new(),
start: None, start: None,
}, },
rank: None, rank: None,
result: Some(result), result: Some(result),
}) })
}) })
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.skip(skip) .skip(skip)
.take(limit) .take(limit)
.collect(); .collect();
let next_batch = if results.len() < limit { let next_batch = if results.len() < limit {
None None
} else { } else {
Some((skip + limit).to_string()) Some((skip + limit).to_string())
}; };
Ok(search_events::v3::Response::new(ResultCategories { Ok(search_events::v3::Response::new(ResultCategories {
room_events: ResultRoomEvents { room_events: ResultRoomEvents {
count: Some((results.len() as u32).into()), // TODO: set this to none. Element shouldn't depend on it count: Some((results.len() as u32).into()), // TODO: set this to none. Element shouldn't depend on it
groups: BTreeMap::new(), // TODO groups: BTreeMap::new(), // TODO
next_batch, next_batch,
results, results,
state: BTreeMap::new(), // TODO state: BTreeMap::new(), // TODO
highlights: search_criteria highlights: search_criteria
.search_term .search_term
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
.map(str::to_lowercase) .map(str::to_lowercase)
.collect(), .collect(),
}, },
})) }))
} }
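// Illustrative sketch, not part of this commit: a minimal full-text search
// request for the route above. Only `search_term` is required; `limit` is
// capped at 100 server-side. The search term and credentials are placeholders;
// assumes `reqwest` and `serde_json`.
async fn example_search_events() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let body = serde_json::json!({
		"search_categories": {
			"room_events": {
				"search_term": "pizza",
				"filter": { "limit": 10 }
			}
		}
	});
	let results: serde_json::Value = reqwest::Client::new()
		.post(format!("{homeserver}/_matrix/client/r0/search"))
		.bearer_auth(access_token)
		.json(&body)
		.send()
		.await?
		.json()
		.await?;
	println!("search results: {results}");
	Ok(())
}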
@@ -1,246 +1,221 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use argon2::{PasswordHash, PasswordVerifier}; use argon2::{PasswordHash, PasswordVerifier};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
session::{ session::{
get_login_types::{ get_login_types::{
self, self,
v3::{ApplicationServiceLoginType, PasswordLoginType}, v3::{ApplicationServiceLoginType, PasswordLoginType},
}, },
login::{ login::{
self, self,
v3::{DiscoveryInfo, HomeserverInfo}, v3::{DiscoveryInfo, HomeserverInfo},
}, },
logout, logout_all, logout, logout_all,
}, },
uiaa::UserIdentifier, uiaa::UserIdentifier,
}, },
UserId, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Claims { struct Claims {
sub: String, sub: String,
//exp: usize, //exp: usize,
} }
/// # `GET /_matrix/client/v3/login` /// # `GET /_matrix/client/v3/login`
/// ///
/// Get the supported login types of this server. One of these should be used as the `type` field /// Get the supported login types of this server. One of these should be used as
/// when logging in. /// the `type` field when logging in.
pub async fn get_login_types_route( pub async fn get_login_types_route(_body: Ruma<get_login_types::v3::Request>) -> Result<get_login_types::v3::Response> {
_body: Ruma<get_login_types::v3::Request>, Ok(get_login_types::v3::Response::new(vec![
) -> Result<get_login_types::v3::Response> { get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
Ok(get_login_types::v3::Response::new(vec![ get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
get_login_types::v3::LoginType::Password(PasswordLoginType::default()), ]))
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
]))
} }
/// # `POST /_matrix/client/v3/login` /// # `POST /_matrix/client/v3/login`
/// ///
/// Authenticates the user and returns an access token it can use in subsequent requests. /// Authenticates the user and returns an access token it can use in subsequent
/// requests.
/// ///
/// - The user needs to authenticate using their password (or if enabled using a json web token) /// - The user needs to authenticate using their password (or if enabled using a
/// json web token)
/// - If `device_id` is known: invalidates old access token of that device /// - If `device_id` is known: invalidates old access token of that device
/// - If `device_id` is unknown: creates a new device /// - If `device_id` is unknown: creates a new device
/// - Returns access token that is associated with the user and device /// - Returns access token that is associated with the user and device
/// ///
/// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// Note: You can use [`GET
/// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
/// supported login types. /// supported login types.
pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> { pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> {
// Validate login method // Validate login method
// TODO: Other login methods // TODO: Other login methods
let user_id = match &body.login_info { let user_id = match &body.login_info {
#[allow(deprecated)] #[allow(deprecated)]
login::v3::LoginInfo::Password(login::v3::Password { login::v3::LoginInfo::Password(login::v3::Password {
identifier, identifier,
password, password,
user, user,
.. ..
}) => { }) => {
debug!("Got password login type"); debug!("Got password login type");
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
debug!("Using username from identifier field"); debug!("Using username from identifier field");
user_id.to_lowercase() user_id.to_lowercase()
} else if let Some(user_id) = user { } else if let Some(user_id) = user {
warn!("User \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id); warn!(
user_id.to_lowercase() "User \"{}\" is attempting to login with the deprecated \"user\" field at \
} else { \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
warn!("Bad login type: {:?}", &body.login_info); destined to be removed in a future Matrix release.",
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); user_id
}; );
user_id.to_lowercase()
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
};
let user_id = let user_id = UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
UserId::parse_with_server_name(username, services().globals.server_name()) warn!("Failed to parse username from user logging in: {}", e);
.map_err(|e| { Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
warn!("Failed to parse username from user logging in: {}", e); })?;
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
let hash = services() let hash = services()
.users .users
.password_hash(&user_id)? .password_hash(&user_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."))?;
ErrorKind::Forbidden,
"Wrong username or password.",
))?;
if hash.is_empty() { if hash.is_empty() {
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated"));
ErrorKind::UserDeactivated, }
"The user has been deactivated",
));
}
let Ok(parsed_hash) = PasswordHash::new(&hash) else { let Ok(parsed_hash) = PasswordHash::new(&hash) else {
error!("error while hashing user {}", user_id); error!("error while hashing user {}", user_id);
return Err(Error::BadServerResponse("could not hash")); return Err(Error::BadServerResponse("could not hash"));
}; };
let hash_matches = services() let hash_matches = services().globals.argon.verify_password(password.as_bytes(), &parsed_hash).is_ok();
.globals
.argon
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok();
if !hash_matches { if !hash_matches {
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."));
ErrorKind::Forbidden, }
"Wrong username or password.",
));
}
user_id user_id
} },
login::v3::LoginInfo::Token(login::v3::Token { token }) => { login::v3::LoginInfo::Token(login::v3::Token {
debug!("Got token login type"); token,
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { }) => {
let token = jsonwebtoken::decode::<Claims>( debug!("Got token login type");
token, if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
jwt_decoding_key, let token =
&jsonwebtoken::Validation::default(), jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default())
) .map_err(|e| {
.map_err(|e| { warn!("Failed to parse JWT token from user logging in: {}", e);
warn!("Failed to parse JWT token from user logging in: {}", e); Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") })?;
})?;
let username = token.claims.sub.to_lowercase(); let username = token.claims.sub.to_lowercase();
UserId::parse_with_server_name(username, services().globals.server_name()).map_err( UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|e| { warn!("Failed to parse username from user logging in: {}", e);
warn!("Failed to parse username from user logging in: {}", e); Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") })?
}, } else {
)? return Err(Error::BadRequest(
} else { ErrorKind::Unknown,
return Err(Error::BadRequest( "Token login is not supported (server has no jwt decoding key).",
ErrorKind::Unknown, ));
"Token login is not supported (server has no jwt decoding key).", }
)); },
} #[allow(deprecated)]
} login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
#[allow(deprecated)] identifier,
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { user,
identifier, }) => {
user, debug!("Got appservice login type");
}) => { if !body.from_appservice {
debug!("Got appservice login type"); info!(
if !body.from_appservice { "User tried logging in as an appservice, but request body is not from a known/registered \
info!("User tried logging in as an appservice, but request body is not from a known/registered appservice"); appservice"
return Err(Error::BadRequest( );
ErrorKind::Forbidden, return Err(Error::BadRequest(ErrorKind::Forbidden, "Forbidden login type."));
"Forbidden login type.", };
)); let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
}; user_id.to_lowercase()
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { } else if let Some(user_id) = user {
user_id.to_lowercase() warn!(
} else if let Some(user_id) = user { "Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \
warn!("Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id); \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
user_id.to_lowercase() destined to be removed in a future Matrix release.",
} else { user_id
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); );
}; user_id.to_lowercase()
} else {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
};
UserId::parse_with_server_name(username, services().globals.server_name()).map_err( UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|e| { warn!("Failed to parse username from appservice logging in: {}", e);
warn!("Failed to parse username from appservice logging in: {}", e); Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") })?
}, },
)? _ => {
} warn!("Unsupported or unknown login type: {:?}", &body.login_info);
_ => { debug!("JSON body: {:?}", &body.json_body);
warn!("Unsupported or unknown login type: {:?}", &body.login_info); return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported or unknown login type."));
debug!("JSON body: {:?}", &body.json_body); },
return Err(Error::BadRequest( };
ErrorKind::Unknown,
"Unsupported or unknown login type.",
));
}
};
// Generate new device id if the user didn't specify one // Generate new device id if the user didn't specify one
let device_id = body let device_id = body.device_id.clone().unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
.device_id
.clone()
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
// Generate a new token for the device // Generate a new token for the device
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Determine if device_id was provided and exists in the db for this user // Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| { let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
services() services().users.all_device_ids(&user_id).any(|x| x.as_ref().map_or(false, |v| v == device_id))
.users });
.all_device_ids(&user_id)
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
});
if device_exists { if device_exists {
services().users.set_token(&user_id, &device_id, &token)?; services().users.set_token(&user_id, &device_id, &token)?;
} else { } else {
services().users.create_device( services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
&user_id, }
&device_id,
&token,
body.initial_device_display_name.clone(),
)?;
}
// send client well-known if specified so the client knows to reconfigure itself // send client well-known if specified so the client knows to reconfigure itself
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new( let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
services() services().globals.well_known_client().to_owned().unwrap_or_else(|| "".to_owned()),
.globals ));
.well_known_client()
.to_owned()
.unwrap_or_else(|| "".to_owned()),
));
info!("{} logged in", user_id); info!("{} logged in", user_id);
// home_server is deprecated but apparently must still be sent despite it being deprecated over 6 years ago. // home_server is deprecated but apparently must still be sent despite it being
// initially i thought this macro was unnecessary, but ruma uses this same macro for the same reason so... // deprecated over 6 years ago. initially i thought this macro was unnecessary,
#[allow(deprecated)] // but ruma uses this same macro for the same reason so...
Ok(login::v3::Response { #[allow(deprecated)]
user_id, Ok(login::v3::Response {
access_token: token, user_id,
device_id, access_token: token,
well_known: { device_id,
if client_discovery_info.homeserver.base_url.as_str() == "" { well_known: {
None if client_discovery_info.homeserver.base_url.as_str() == "" {
} else { None
Some(client_discovery_info) } else {
} Some(client_discovery_info)
}, }
expires_in: None, },
home_server: Some(services().globals.server_name().to_owned()), expires_in: None,
refresh_token: None, home_server: Some(services().globals.server_name().to_owned()),
}) refresh_token: None,
})
} }
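// Illustrative sketch, not part of this commit: a password login against the
// route above and extraction of the returned access token. No access token is
// needed for this call; the credentials and device name are placeholders.
// Assumes `reqwest` and `serde_json`.
async fn example_password_login() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let body = serde_json::json!({
		"type": "m.login.password",
		"identifier": { "type": "m.id.user", "user": "alice" },
		"password": "correct horse battery staple",
		"initial_device_display_name": "Example client"
	});
	let resp: serde_json::Value = reqwest::Client::new()
		.post(format!("{homeserver}/_matrix/client/v3/login"))
		.json(&body)
		.send()
		.await?
		.json()
		.await?;
	// On success the response contains `user_id`, `access_token` and `device_id`.
	println!("access_token: {}", resp["access_token"]);
	Ok(())
}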
/// # `POST /_matrix/client/v3/logout` /// # `POST /_matrix/client/v3/logout`
@@ -248,19 +223,20 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
/// Log out the current device. /// Log out the current device.
/// ///
/// - Invalidates access token /// - Invalidates access token
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> { pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
services().users.remove_device(sender_user, sender_device)?; services().users.remove_device(sender_user, sender_device)?;
// send device list update for user after logout // send device list update for user after logout
services().users.mark_device_key_update(sender_user)?; services().users.mark_device_key_update(sender_user)?;
Ok(logout::v3::Response::new()) Ok(logout::v3::Response::new())
} }
/// # `POST /_matrix/client/r0/logout/all` /// # `POST /_matrix/client/r0/logout/all`
@@ -268,23 +244,23 @@ pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3:
/// Log out all devices of this user. /// Log out all devices of this user.
/// ///
/// - Invalidates all access tokens /// - Invalidates all access tokens
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts) /// - Deletes all device metadata (device id, device display name, last seen ip,
/// last seen ts)
/// - Forgets all to-device events /// - Forgets all to-device events
/// - Triggers device list updates /// - Triggers device list updates
/// ///
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html) /// Note: This is equivalent to calling [`GET
/// from each device of this user. /// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this
pub async fn logout_all_route( /// user.
body: Ruma<logout_all::v3::Request>, pub async fn logout_all_route(body: Ruma<logout_all::v3::Request>) -> Result<logout_all::v3::Response> {
) -> Result<logout_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for device_id in services().users.all_device_ids(sender_user).flatten() { for device_id in services().users.all_device_ids(sender_user).flatten() {
services().users.remove_device(sender_user, &device_id)?; services().users.remove_device(sender_user, &device_id)?;
} }
// send device list update for user after logout // send device list update for user after logout
services().users.mark_device_key_update(sender_user)?; services().users.mark_device_key_update(sender_user)?;
Ok(logout_all::v3::Response::new()) Ok(logout_all::v3::Response::new())
} }
@@ -1,34 +1,19 @@
use crate::{services, Result, Ruma};
use ruma::api::client::space::get_hierarchy; use ruma::api::client::space::get_hierarchy;
use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`
/// ///
/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space. /// Paginates over the space tree in a depth-first manner to locate child rooms
pub async fn get_hierarchy_route( /// of a given space.
body: Ruma<get_hierarchy::v1::Request>, pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> {
) -> Result<get_hierarchy::v1::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let skip = body let skip = body.from.as_ref().and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);
.from
.as_ref()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);
let limit = body.limit.map_or(10, u64::from).min(100) as usize; let limit = body.limit.map_or(10, u64::from).min(100) as usize;
let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself
services() services().rooms.spaces.get_hierarchy(sender_user, &body.room_id, limit, skip, max_depth, body.suggested_only).await
.rooms
.spaces
.get_hierarchy(
sender_user,
&body.room_id,
limit,
skip,
max_depth,
body.suggested_only,
)
.await
} }
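// Illustrative sketch, not part of this commit: walking a space's room tree via
// the endpoint above. The route caps `limit` at 100 and `max_depth` at 10. The
// space id and credentials are placeholders; assumes `reqwest`.
async fn example_get_hierarchy() -> Result<(), reqwest::Error> {
	let homeserver = "https://matrix.example.org"; // hypothetical
	let access_token = "syt_example_token"; // hypothetical
	let space_id = "!space:example.org"; // hypothetical
	let url = format!("{homeserver}/_matrix/client/v1/rooms/{space_id}/hierarchy");
	let page: serde_json::Value = reqwest::Client::new()
		.get(url)
		.bearer_auth(access_token)
		.query(&[("limit", "10"), ("max_depth", "3"), ("suggested_only", "false")])
		.send()
		.await?
		.json()
		.await?;
	// `rooms` holds the child summaries; `next_batch`, if present, resumes paging.
	println!("hierarchy page: {page}");
	Ok(())
}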
@@ -1,42 +1,44 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
state::{get_state_events, get_state_events_for_key, send_state_event}, state::{get_state_events, get_state_events_for_key, send_state_event},
}, },
events::{ events::{room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType},
room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType, serde::Raw,
}, EventId, RoomId, UserId,
serde::Raw,
EventId, RoomId, UserId,
}; };
use tracing::{error, log::warn}; use tracing::{error, log::warn};
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
/// ///
/// Sends a state event into the room. /// Sends a state event into the room.
/// ///
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
/// - If event is new canonical_alias: Rejects if alias is incorrect /// - If event is new canonical_alias: Rejects if alias is incorrect
pub async fn send_state_event_for_key_route( pub async fn send_state_event_for_key_route(
body: Ruma<send_state_event::v3::Request>, body: Ruma<send_state_event::v3::Request>,
) -> Result<send_state_event::v3::Response> { ) -> Result<send_state_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event_id = send_state_event_for_key_helper( let event_id = send_state_event_for_key_helper(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_type, &body.event_type,
&body.body.body, // Yes, I hate it too &body.body.body, // Yes, I hate it too
body.state_key.clone(), body.state_key.clone(),
) )
.await?; .await?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }) Ok(send_state_event::v3::Response {
event_id,
})
} }
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
@ -44,249 +46,209 @@ pub async fn send_state_event_for_key_route(
/// Sends a state event into the room. /// Sends a state event into the room.
/// ///
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed
/// - If event is new canonical_alias: Rejects if alias is incorrect /// - If event is new canonical_alias: Rejects if alias is incorrect
pub async fn send_state_event_for_empty_key_route( pub async fn send_state_event_for_empty_key_route(
body: Ruma<send_state_event::v3::Request>, body: Ruma<send_state_event::v3::Request>,
) -> Result<RumaResponse<send_state_event::v3::Response>> { ) -> Result<RumaResponse<send_state_event::v3::Response>> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Forbid m.room.encryption if encryption is disabled // Forbid m.room.encryption if encryption is disabled
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
ErrorKind::Forbidden, }
"Encryption has been disabled",
));
}
let event_id = send_state_event_for_key_helper( let event_id = send_state_event_for_key_helper(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_type.to_string().into(), &body.event_type.to_string().into(),
&body.body.body, &body.body.body,
body.state_key.clone(), body.state_key.clone(),
) )
.await?; .await?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }.into()) Ok(send_state_event::v3::Response {
event_id,
}
.into())
} }
/// # `GET /_matrix/client/r0/rooms/{roomid}/state` /// # `GET /_matrix/client/r0/rooms/{roomid}/state`
/// ///
/// Get all state events for a room. /// Get all state events for a room.
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world
/// readable
pub async fn get_state_events_route( pub async fn get_state_events_route(
body: Ruma<get_state_events::v3::Request>, body: Ruma<get_state_events::v3::Request>,
) -> Result<get_state_events::v3::Response> { ) -> Result<get_state_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
.rooms return Err(Error::BadRequest(
.state_accessor ErrorKind::Forbidden,
.user_can_see_state_events(sender_user, &body.room_id)? "You don't have permission to view the room state.",
{ ));
return Err(Error::BadRequest( }
ErrorKind::Forbidden,
"You don't have permission to view the room state.",
));
}
Ok(get_state_events::v3::Response { Ok(get_state_events::v3::Response {
room_state: services() room_state: services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_full(&body.room_id) .room_state_full(&body.room_id)
.await? .await?
.values() .values()
.map(|pdu| pdu.to_state_event()) .map(|pdu| pdu.to_state_event())
.collect(), .collect(),
}) })
} }
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}` /// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}`
/// ///
/// Get single state event of a room with the specified state key. /// Get single state event of a room with the specified state key.
/// The optional query parameter `?format=event|content` allows returning the full room state event /// The optional query parameter `?format=event|content` allows returning the
/// or just the state event's content (default behaviour) /// full room state event or just the state event's content (default behaviour)
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world
/// readable
pub async fn get_state_events_for_key_route( pub async fn get_state_events_for_key_route(
body: Ruma<get_state_events_for_key::v3::Request>, body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<get_state_events_for_key::v3::Response> { ) -> Result<get_state_events_for_key::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
.rooms return Err(Error::BadRequest(
.state_accessor ErrorKind::Forbidden,
.user_can_see_state_events(sender_user, &body.room_id)? "You don't have permission to view the room state.",
{ ));
return Err(Error::BadRequest( }
ErrorKind::Forbidden,
"You don't have permission to view the room state.",
));
}
let event = services() let event =
.rooms services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else(
.state_accessor || {
.room_state_get(&body.room_id, &body.event_type, &body.state_key)? warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
.ok_or_else(|| { Error::BadRequest(ErrorKind::NotFound, "State event not found.")
warn!( },
"State event {:?} not found in room {:?}", )?;
&body.event_type, &body.room_id if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
); Ok(get_state_events_for_key::v3::Response {
Error::BadRequest(ErrorKind::NotFound, "State event not found.") content: None,
})?; event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
if body error!("Invalid room state event in database: {}", e);
.format Error::bad_database("Invalid room state event in database")
.as_ref() })?,
.is_some_and(|f| f.to_lowercase().eq("event")) })
{ } else {
Ok(get_state_events_for_key::v3::Response { Ok(get_state_events_for_key::v3::Response {
content: None, content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { error!("Invalid room state event content in database: {}", e);
error!("Invalid room state event in database: {}", e); Error::bad_database("Invalid room state event content in database")
Error::bad_database("Invalid room state event in database") })?),
})?, event: None,
}) })
} else { }
Ok(get_state_events_for_key::v3::Response {
content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
error!("Invalid room state event content in database: {}", e);
Error::bad_database("Invalid room state event content in database")
})?),
event: None,
})
}
} }
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}` /// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}`
/// ///
/// Get single state event of a room. /// Get single state event of a room.
/// The optional query parameter `?format=event|content` allows returning the full room state event /// The optional query parameter `?format=event|content` allows returning the
/// or just the state event's content (default behaviour) /// full room state event or just the state event's content (default behaviour)
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world
/// readable
pub async fn get_state_events_for_empty_key_route( pub async fn get_state_events_for_empty_key_route(
body: Ruma<get_state_events_for_key::v3::Request>, body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
.rooms return Err(Error::BadRequest(
.state_accessor ErrorKind::Forbidden,
.user_can_see_state_events(sender_user, &body.room_id)? "You don't have permission to view the room state.",
{ ));
return Err(Error::BadRequest( }
ErrorKind::Forbidden,
"You don't have permission to view the room state.",
));
}
let event = services() let event =
.rooms services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| {
.state_accessor warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
.room_state_get(&body.room_id, &body.event_type, "")? Error::BadRequest(ErrorKind::NotFound, "State event not found.")
.ok_or_else(|| { })?;
warn!(
"State event {:?} not found in room {:?}",
&body.event_type, &body.room_id
);
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
})?;
if body if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
.format Ok(get_state_events_for_key::v3::Response {
.as_ref() content: None,
.is_some_and(|f| f.to_lowercase().eq("event")) event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
{ error!("Invalid room state event in database: {}", e);
Ok(get_state_events_for_key::v3::Response { Error::bad_database("Invalid room state event in database")
content: None, })?,
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { }
error!("Invalid room state event in database: {}", e); .into())
Error::bad_database("Invalid room state event in database") } else {
})?, Ok(get_state_events_for_key::v3::Response {
} content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
.into()) error!("Invalid room state event content in database: {}", e);
} else { Error::bad_database("Invalid room state event content in database")
Ok(get_state_events_for_key::v3::Response { })?),
content: Some(serde_json::from_str(event.content.get()).map_err(|e| { event: None,
error!("Invalid room state event content in database: {}", e); }
Error::bad_database("Invalid room state event content in database") .into())
})?), }
event: None,
}
.into())
}
} }
async fn send_state_event_for_key_helper( async fn send_state_event_for_key_helper(
sender: &UserId, sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String,
room_id: &RoomId,
event_type: &StateEventType,
json: &Raw<AnyStateEventContent>,
state_key: String,
) -> Result<Arc<EventId>> { ) -> Result<Arc<EventId>> {
let sender_user = sender; let sender_user = sender;
// TODO: Review this check, error if event is unparsable, use event type, allow alias if it // TODO: Review this check, error if event is unparsable, use event type, allow
// previously existed // alias if it previously existed
if let Ok(canonical_alias) = if let Ok(canonical_alias) = serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get()) {
serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get()) let mut aliases = canonical_alias.alt_aliases.clone();
{
let mut aliases = canonical_alias.alt_aliases.clone();
if let Some(alias) = canonical_alias.alias { if let Some(alias) = canonical_alias.alias {
aliases.push(alias); aliases.push(alias);
} }
for alias in aliases { for alias in aliases {
if alias.server_name() != services().globals.server_name() if alias.server_name() != services().globals.server_name()
|| services() || services()
.rooms .rooms
.alias .alias
.resolve_local_alias(&alias)? .resolve_local_alias(&alias)?
.filter(|room| room == room_id) // Make sure it's the right room .filter(|room| room == room_id) // Make sure it's the right room
.is_none() .is_none()
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You are only allowed to send canonical_alias \ "You are only allowed to send canonical_alias events when it's aliases already exists",
events when it's aliases already exists", ));
)); }
} }
} }
}
let mutex_state = Arc::clone( let mutex_state =
services() Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
.globals let state_lock = mutex_state.lock().await;
.roomid_mutex_state
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let event_id = services() let event_id = services()
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: event_type.to_string().into(), event_type: event_type.to_string().into(),
content: serde_json::from_str(json.json().get()).expect("content is valid json"), content: serde_json::from_str(json.json().get()).expect("content is valid json"),
unsigned: None, unsigned: None,
state_key: Some(state_key), state_key: Some(state_key),
redacts: None, redacts: None,
}, },
sender_user, sender_user,
room_id, room_id,
&state_lock, &state_lock,
) )
.await?; .await?;
Ok(event_id) Ok(event_id)
} }
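
The helper above special-cases `m.room.canonical_alias`: every alias named in the content must already resolve, on this server, to the room the event is being sent into. A minimal sketch of the content shape it iterates over, with illustrative alias values:

// A sketch of the m.room.canonical_alias content whose `alias` and
// `alt_aliases` fields the helper above checks; the alias values here are
// illustrative placeholders.
fn example_canonical_alias_content() -> serde_json::Value {
	serde_json::json!({
		"alias": "#main:example.org",
		"alt_aliases": ["#general:example.org"]
	})
}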

File diff suppressed because it is too large


@ -1,55 +1,45 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::client::tag::{create_tag, delete_tag, get_tags},
events::{
tag::{TagEvent, TagEventContent},
RoomAccountDataEventType,
},
};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use ruma::{
api::client::tag::{create_tag, delete_tag, get_tags},
events::{
tag::{TagEvent, TagEventContent},
RoomAccountDataEventType,
},
};
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
/// ///
/// Adds a tag to the room. /// Adds a tag to the room.
/// ///
/// - Inserts the tag into the tag event of the room account data. /// - Inserts the tag into the tag event of the room account data.
pub async fn update_tag_route( pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> {
body: Ruma<create_tag::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<create_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().account_data.get( let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
)?;
let mut tags_event = event let mut tags_event = event
.map(|e| { .map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
serde_json::from_str(e.get()) .unwrap_or_else(|| {
.map_err(|_| Error::bad_database("Invalid account data event in db.")) Ok(TagEvent {
}) content: TagEventContent {
.unwrap_or_else(|| { tags: BTreeMap::new(),
Ok(TagEvent { },
content: TagEventContent { })
tags: BTreeMap::new(), })?;
},
})
})?;
tags_event tags_event.content.tags.insert(body.tag.clone().into(), body.tag_info.clone());
.content
.tags
.insert(body.tag.clone().into(), body.tag_info.clone());
services().account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"), &serde_json::to_value(tags_event).expect("to json value always works"),
)?; )?;
Ok(create_tag::v3::Response {}) Ok(create_tag::v3::Response {})
} }
/// # `DELETE /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// # `DELETE /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
@ -57,40 +47,31 @@ pub async fn update_tag_route(
/// Deletes a tag from the room. /// Deletes a tag from the room.
/// ///
/// - Removes the tag from the tag event of the room account data. /// - Removes the tag from the tag event of the room account data.
pub async fn delete_tag_route( pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> {
body: Ruma<delete_tag::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<delete_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().account_data.get( let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
)?;
let mut tags_event = event let mut tags_event = event
.map(|e| { .map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
serde_json::from_str(e.get()) .unwrap_or_else(|| {
.map_err(|_| Error::bad_database("Invalid account data event in db.")) Ok(TagEvent {
}) content: TagEventContent {
.unwrap_or_else(|| { tags: BTreeMap::new(),
Ok(TagEvent { },
content: TagEventContent { })
tags: BTreeMap::new(), })?;
},
})
})?;
tags_event.content.tags.remove(&body.tag.clone().into()); tags_event.content.tags.remove(&body.tag.clone().into());
services().account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"), &serde_json::to_value(tags_event).expect("to json value always works"),
)?; )?;
Ok(delete_tag::v3::Response {}) Ok(delete_tag::v3::Response {})
} }
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags` /// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags`
@ -99,28 +80,21 @@ pub async fn delete_tag_route(
/// ///
/// - Gets the tag event of the room account data. /// - Gets the tag event of the room account data.
pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> { pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().account_data.get( let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
)?;
let tags_event = event let tags_event = event
.map(|e| { .map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
serde_json::from_str(e.get()) .unwrap_or_else(|| {
.map_err(|_| Error::bad_database("Invalid account data event in db.")) Ok(TagEvent {
}) content: TagEventContent {
.unwrap_or_else(|| { tags: BTreeMap::new(),
Ok(TagEvent { },
content: TagEventContent { })
tags: BTreeMap::new(), })?;
},
})
})?;
Ok(get_tags::v3::Response { Ok(get_tags::v3::Response {
tags: tags_event.content.tags, tags: tags_event.content.tags,
}) })
} }
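
For reference, all three tag routes above round-trip the `m.tag` room account-data event. A minimal sketch of its content shape, with an illustrative tag name and order value:

// A sketch of the m.tag account-data content the tag routes read and write;
// the tag name and order value are illustrative placeholders.
fn example_tag_content() -> serde_json::Value {
	serde_json::json!({
		"tags": {
			"u.work": { "order": 0.5 }
		}
	})
}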


@ -1,16 +1,15 @@
use crate::{Result, Ruma}; use std::collections::BTreeMap;
use ruma::api::client::thirdparty::get_protocols; use ruma::api::client::thirdparty::get_protocols;
use std::collections::BTreeMap; use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/thirdparty/protocols` /// # `GET /_matrix/client/r0/thirdparty/protocols`
/// ///
/// TODO: Fetches all metadata about protocols supported by the homeserver. /// TODO: Fetches all metadata about protocols supported by the homeserver.
pub async fn get_protocols_route( pub async fn get_protocols_route(_body: Ruma<get_protocols::v3::Request>) -> Result<get_protocols::v3::Response> {
_body: Ruma<get_protocols::v3::Request>, // TODO
) -> Result<get_protocols::v3::Response> { Ok(get_protocols::v3::Response {
// TODO protocols: BTreeMap::new(),
Ok(get_protocols::v3::Response { })
protocols: BTreeMap::new(),
})
} }


@ -3,47 +3,37 @@ use ruma::api::client::{error::ErrorKind, threads::get_threads};
use crate::{services, Error, Result, Ruma}; use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/threads` /// # `GET /_matrix/client/r0/rooms/{roomId}/threads`
pub async fn get_threads_route( pub async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<get_threads::v1::Response> {
body: Ruma<get_threads::v1::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<get_threads::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body let limit = body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100);
.limit
.and_then(|l| l.try_into().ok())
.unwrap_or(10)
.min(100);
let from = if let Some(from) = &body.from { let from = if let Some(from) = &body.from {
from.parse() from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? } else {
} else { u64::MAX
u64::MAX };
};
let threads = services() let threads = services()
.rooms .rooms
.threads .threads
.threads_until(sender_user, &body.room_id, from, &body.include)? .threads_until(sender_user, &body.room_id, from, &body.include)?
.take(limit) .take(limit)
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let next_batch = threads.last().map(|(count, _)| count.to_string()); let next_batch = threads.last().map(|(count, _)| count.to_string());
Ok(get_threads::v1::Response { Ok(get_threads::v1::Response {
chunk: threads chunk: threads.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(),
.into_iter() next_batch,
.map(|(_, pdu)| pdu.to_room_event()) })
.collect(),
next_batch,
})
} }


@ -1,92 +1,85 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::{services, Error, Result, Ruma};
use ruma::{ use ruma::{
api::{ api::{
client::{error::ErrorKind, to_device::send_event_to_device}, client::{error::ErrorKind, to_device::send_event_to_device},
federation::{self, transactions::edu::DirectDeviceContent}, federation::{self, transactions::edu::DirectDeviceContent},
}, },
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
}; };
use crate::{services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
/// ///
/// Send a to-device event to a set of client devices. /// Send a to-device event to a set of client devices.
pub async fn send_event_to_device_route( pub async fn send_event_to_device_route(
body: Ruma<send_event_to_device::v3::Request>, body: Ruma<send_event_to_device::v3::Request>,
) -> Result<send_event_to_device::v3::Response> { ) -> Result<send_event_to_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
// Check if this is a new transaction id // Check if this is a new transaction id
if services() if services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)?.is_some() {
.transaction_ids return Ok(send_event_to_device::v3::Response {});
.existing_txnid(sender_user, sender_device, &body.txn_id)? }
.is_some()
{
return Ok(send_event_to_device::v3::Response {});
}
for (target_user_id, map) in &body.messages { for (target_user_id, map) in &body.messages {
for (target_device_id_maybe, event) in map { for (target_device_id_maybe, event) in map {
if target_user_id.server_name() != services().globals.server_name() { if target_user_id.server_name() != services().globals.server_name() {
let mut map = BTreeMap::new(); let mut map = BTreeMap::new();
map.insert(target_device_id_maybe.clone(), event.clone()); map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages.insert(target_user_id.clone(), map); messages.insert(target_user_id.clone(), map);
let count = services().globals.next_count()?; let count = services().globals.next_count()?;
services().sending.send_reliable_edu( services().sending.send_reliable_edu(
target_user_id.server_name(), target_user_id.server_name(),
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
DirectDeviceContent { sender: sender_user.clone(),
sender: sender_user.clone(), ev_type: body.event_type.clone(),
ev_type: body.event_type.clone(), message_id: count.to_string().into(),
message_id: count.to_string().into(), messages,
messages, }))
}, .expect("DirectToDevice EDU can be serialized"),
)) count,
.expect("DirectToDevice EDU can be serialized"), )?;
count,
)?;
continue; continue;
} }
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => { DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services().users.add_to_device_event( services().users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
target_device_id, target_device_id,
&body.event_type.to_string(), &body.event_type.to_string(),
event.deserialize_as().map_err(|_| { event
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") .deserialize_as()
})?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
)?; )?;
} },
DeviceIdOrAllDevices::AllDevices => { DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services().users.all_device_ids(target_user_id) { for target_device_id in services().users.all_device_ids(target_user_id) {
services().users.add_to_device_event( services().users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
&target_device_id?, &target_device_id?,
&body.event_type.to_string(), &body.event_type.to_string(),
event.deserialize_as().map_err(|_| { event
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") .deserialize_as()
})?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
)?; )?;
} }
} },
} }
} }
} }
// Save transaction id with empty data // Save transaction id with empty data
services() services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
Ok(send_event_to_device::v3::Response {}) Ok(send_event_to_device::v3::Response {})
} }


@ -1,40 +1,30 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
use crate::{services, utils, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
/// ///
/// Sets the typing state of the sender user. /// Sets the typing state of the sender user.
pub async fn create_typing_event_route( pub async fn create_typing_event_route(
body: Ruma<create_typing_event::v3::Request>, body: Ruma<create_typing_event::v3::Request>,
) -> Result<create_typing_event::v3::Response> { ) -> Result<create_typing_event::v3::Response> {
use create_typing_event::v3::Typing; use create_typing_event::v3::Typing;
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
.rooms return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room."));
.state_cache }
.is_joined(sender_user, &body.room_id)?
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You are not in this room.",
));
}
if let Typing::Yes(duration) = body.state { if let Typing::Yes(duration) = body.state {
services().rooms.edus.typing.typing_add( services().rooms.edus.typing.typing_add(
sender_user, sender_user,
&body.room_id, &body.room_id,
duration.as_millis() as u64 + utils::millis_since_unix_epoch(), duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
)?; )?;
} else { } else {
services() services().rooms.edus.typing.typing_remove(sender_user, &body.room_id)?;
.rooms }
.edus
.typing
.typing_remove(sender_user, &body.room_id)?;
}
Ok(create_typing_event::v3::Response {}) Ok(create_typing_event::v3::Response {})
} }


@ -7,72 +7,74 @@ use crate::{services, Error, Result, Ruma};
/// # `GET /_matrix/client/versions` /// # `GET /_matrix/client/versions`
/// ///
/// Get the versions of the specification and unstable features supported by this server. /// Get the versions of the specification and unstable features supported by
/// this server.
/// ///
/// - Versions take the form MAJOR.MINOR.PATCH /// - Versions take the form MAJOR.MINOR.PATCH
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value /// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
/// - Unstable features are namespaced and may include version information in their name /// - Unstable features are namespaced and may include version information in
/// their name
/// ///
/// Note: Unstable features are used while developing new features. Clients should avoid using /// Note: Unstable features are used while developing new features. Clients
/// unstable features in their stable releases /// should avoid using unstable features in their stable releases
pub async fn get_supported_versions_route( pub async fn get_supported_versions_route(
_body: Ruma<get_supported_versions::Request>, _body: Ruma<get_supported_versions::Request>,
) -> Result<get_supported_versions::Response> { ) -> Result<get_supported_versions::Response> {
let resp = get_supported_versions::Response { let resp = get_supported_versions::Response {
versions: vec![ versions: vec![
"r0.0.1".to_owned(), "r0.0.1".to_owned(),
"r0.1.0".to_owned(), "r0.1.0".to_owned(),
"r0.2.0".to_owned(), "r0.2.0".to_owned(),
"r0.3.0".to_owned(), "r0.3.0".to_owned(),
"r0.4.0".to_owned(), "r0.4.0".to_owned(),
"r0.5.0".to_owned(), "r0.5.0".to_owned(),
"r0.6.0".to_owned(), "r0.6.0".to_owned(),
"r0.6.1".to_owned(), "r0.6.1".to_owned(),
"v1.1".to_owned(), "v1.1".to_owned(),
"v1.2".to_owned(), "v1.2".to_owned(),
"v1.3".to_owned(), "v1.3".to_owned(),
"v1.4".to_owned(), "v1.4".to_owned(),
"v1.5".to_owned(), "v1.5".to_owned(),
], ],
unstable_features: BTreeMap::from_iter([ unstable_features: BTreeMap::from_iter([
("org.matrix.e2e_cross_signing".to_owned(), true), ("org.matrix.e2e_cross_signing".to_owned(), true),
("org.matrix.msc2836".to_owned(), true), ("org.matrix.msc2836".to_owned(), true),
("org.matrix.msc3827".to_owned(), true), ("org.matrix.msc3827".to_owned(), true),
("org.matrix.msc2946".to_owned(), true), ("org.matrix.msc2946".to_owned(), true),
]), ]),
}; };
Ok(resp) Ok(resp)
} }
/// # `GET /.well-known/matrix/client` /// # `GET /.well-known/matrix/client`
pub async fn well_known_client_route() -> Result<impl IntoResponse> { pub async fn well_known_client_route() -> Result<impl IntoResponse> {
let client_url = match services().globals.well_known_client() { let client_url = match services().globals.well_known_client() {
Some(url) => url.clone(), Some(url) => url.clone(),
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
}; };
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"m.homeserver": {"base_url": client_url}, "m.homeserver": {"base_url": client_url},
"org.matrix.msc3575.proxy": {"url": client_url} "org.matrix.msc3575.proxy": {"url": client_url}
}))) })))
} }
/// # `GET /client/server.json` /// # `GET /client/server.json`
/// ///
/// Endpoint provided by sliding sync proxy used by some clients such as Element Web /// Endpoint provided by sliding sync proxy used by some clients such as Element
/// as a non-standard health check. /// Web as a non-standard health check.
pub async fn syncv3_client_server_json() -> Result<impl IntoResponse> { pub async fn syncv3_client_server_json() -> Result<impl IntoResponse> {
let server_url = match services().globals.well_known_client() { let server_url = match services().globals.well_known_client() {
Some(url) => url.clone(), Some(url) => url.clone(),
None => match services().globals.well_known_server() { None => match services().globals.well_known_server() {
Some(url) => url.clone(), Some(url) => url.clone(),
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
}, },
}; };
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"server": server_url, "server": server_url,
"version": format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")) "version": format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))
}))) })))
} }


@ -1,94 +1,78 @@
use crate::{services, Result, Ruma};
use ruma::{ use ruma::{
api::client::user_directory::search_users, api::client::user_directory::search_users,
events::{ events::{
room::join_rules::{JoinRule, RoomJoinRulesEventContent}, room::join_rules::{JoinRule, RoomJoinRulesEventContent},
StateEventType, StateEventType,
}, },
}; };
use crate::{services, Result, Ruma};
/// # `POST /_matrix/client/r0/user_directory/search` /// # `POST /_matrix/client/r0/user_directory/search`
/// ///
/// Searches all known users for a match. /// Searches all known users for a match.
/// ///
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) /// - Hides any local users that aren't in any public rooms (i.e. those that
/// have the join rule set to public)
/// and don't share a room with the sender /// and don't share a room with the sender
pub async fn search_users_route( pub async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result<search_users::v3::Response> {
body: Ruma<search_users::v3::Request>, let sender_user = body.sender_user.as_ref().expect("user is authenticated");
) -> Result<search_users::v3::Response> { let limit = u64::from(body.limit) as usize;
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = u64::from(body.limit) as usize;
let mut users = services().users.iter().filter_map(|user_id| { let mut users = services().users.iter().filter_map(|user_id| {
// Filter out buggy users (they should not exist, but you never know...) // Filter out buggy users (they should not exist, but you never know...)
let user_id = user_id.ok()?; let user_id = user_id.ok()?;
let user = search_users::v3::User { let user = search_users::v3::User {
user_id: user_id.clone(), user_id: user_id.clone(),
display_name: services().users.displayname(&user_id).ok()?, display_name: services().users.displayname(&user_id).ok()?,
avatar_url: services().users.avatar_url(&user_id).ok()?, avatar_url: services().users.avatar_url(&user_id).ok()?,
}; };
let user_id_matches = user let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase());
.user_id
.to_string()
.to_lowercase()
.contains(&body.search_term.to_lowercase());
let user_displayname_matches = user let user_displayname_matches = user
.display_name .display_name
.as_ref() .as_ref()
.filter(|name| { .filter(|name| name.to_lowercase().contains(&body.search_term.to_lowercase()))
name.to_lowercase() .is_some();
.contains(&body.search_term.to_lowercase())
})
.is_some();
if !user_id_matches && !user_displayname_matches { if !user_id_matches && !user_displayname_matches {
return None; return None;
} }
let user_is_in_public_rooms = services() let user_is_in_public_rooms =
.rooms services().rooms.state_cache.rooms_joined(&user_id).filter_map(std::result::Result::ok).any(|room| {
.state_cache services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or(
.rooms_joined(&user_id) false,
.filter_map(std::result::Result::ok) |event| {
.any(|room| { event.map_or(false, |event| {
services() serde_json::from_str(event.content.get())
.rooms .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
.state_accessor })
.room_state_get(&room, &StateEventType::RoomJoinRules, "") },
.map_or(false, |event| { )
event.map_or(false, |event| { });
serde_json::from_str(event.content.get())
.map_or(false, |r: RoomJoinRulesEventContent| {
r.join_rule == JoinRule::Public
})
})
})
});
if user_is_in_public_rooms { if user_is_in_public_rooms {
return Some(user); return Some(user);
} }
let user_is_in_shared_rooms = services() let user_is_in_shared_rooms =
.rooms services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some();
.user
.get_shared_rooms(vec![sender_user.clone(), user_id])
.ok()?
.next()
.is_some();
if user_is_in_shared_rooms { if user_is_in_shared_rooms {
return Some(user); return Some(user);
} }
None None
}); });
let results = users.by_ref().take(limit).collect(); let results = users.by_ref().take(limit).collect();
let limited = users.next().is_some(); let limited = users.next().is_some();
Ok(search_users::v3::Response { results, limited }) Ok(search_users::v3::Response {
results,
limited,
})
} }


@ -1,9 +1,11 @@
use crate::{services, Result, Ruma}; use std::time::{Duration, SystemTime};
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
use sha1::Sha1; use sha1::Sha1;
use std::time::{Duration, SystemTime};
use crate::{services, Result, Ruma};
type HmacSha1 = Hmac<Sha1>; type HmacSha1 = Hmac<Sha1>;
@ -11,38 +13,37 @@ type HmacSha1 = Hmac<Sha1>;
/// ///
/// TODO: Returns information about the recommended turn server. /// TODO: Returns information about the recommended turn server.
pub async fn turn_server_route( pub async fn turn_server_route(
body: Ruma<get_turn_server_info::v3::Request>, body: Ruma<get_turn_server_info::v3::Request>,
) -> Result<get_turn_server_info::v3::Response> { ) -> Result<get_turn_server_info::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let turn_secret = services().globals.turn_secret().clone(); let turn_secret = services().globals.turn_secret().clone();
let (username, password) = if !turn_secret.is_empty() { let (username, password) = if !turn_secret.is_empty() {
let expiry = SecondsSinceUnixEpoch::from_system_time( let expiry = SecondsSinceUnixEpoch::from_system_time(
SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()),
) )
.expect("time is valid"); .expect("time is valid");
let username: String = format!("{}:{}", expiry.get(), sender_user); let username: String = format!("{}:{}", expiry.get(), sender_user);
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()) let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size");
.expect("HMAC can take key of any size"); mac.update(username.as_bytes());
mac.update(username.as_bytes());
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes()); let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
(username, password) (username, password)
} else { } else {
( (
services().globals.turn_username().clone(), services().globals.turn_username().clone(),
services().globals.turn_password().clone(), services().globals.turn_password().clone(),
) )
}; };
Ok(get_turn_server_info::v3::Response { Ok(get_turn_server_info::v3::Response {
username, username,
password, password,
uris: services().globals.turn_uris().to_vec(), uris: services().globals.turn_uris().to_vec(),
ttl: Duration::from_secs(services().globals.turn_ttl()), ttl: Duration::from_secs(services().globals.turn_ttl()),
}) })
} }
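
When a shared `turn_secret` is configured, the route above hands out short-lived credentials following the common TURN REST API convention: the expiry timestamp is baked into the username and the password is an HMAC over that username. A minimal standalone sketch of that derivation, assuming the same `hmac`, `sha1`, and `base64` crates and illustrative inputs:

// A standalone sketch of the ephemeral TURN credential derivation used above;
// `turn_secret`, `turn_ttl`, and `user_id` are illustrative inputs.
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use sha1::Sha1;

type HmacSha1 = Hmac<Sha1>;

fn ephemeral_turn_credentials(turn_secret: &str, turn_ttl: u64, user_id: &str) -> (String, String) {
	// Credentials expire at now + ttl; the expiry is encoded into the username.
	let expiry = (SystemTime::now() + Duration::from_secs(turn_ttl))
		.duration_since(UNIX_EPOCH)
		.expect("system time is after the unix epoch")
		.as_secs();
	let username = format!("{expiry}:{user_id}");

	// The password is the base64 encoding of HMAC-SHA1(secret, username).
	let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size");
	mac.update(username.as_bytes());
	let password = general_purpose::STANDARD.encode(mac.finalize().into_bytes());

	(username, password)
}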


@ -1,21 +1,21 @@
use std::{collections::BTreeMap, str}; use std::{collections::BTreeMap, str};
use axum::{ use axum::{
async_trait, async_trait,
body::{Full, HttpBody}, body::{Full, HttpBody},
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
headers::{ headers::{
authorization::{Bearer, Credentials}, authorization::{Bearer, Credentials},
Authorization, Authorization,
}, },
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError, RequestExt, RequestPartsExt, BoxError, RequestExt, RequestPartsExt,
}; };
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use http::{Request, StatusCode}; use http::{Request, StatusCode};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
@ -25,400 +25,333 @@ use crate::{services, Error, Result};
#[derive(Deserialize)] #[derive(Deserialize)]
struct QueryParams { struct QueryParams {
access_token: Option<String>, access_token: Option<String>,
user_id: Option<String>, user_id: Option<String>,
} }
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Ruma<T> impl<T, S, B> FromRequest<S, B> for Ruma<T>
where where
T: IncomingRequest, T: IncomingRequest,
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
{ {
type Rejection = Error; type Rejection = Error;
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, mut body) = match req.with_limited_body() { let (mut parts, mut body) = match req.with_limited_body() {
Ok(limited_req) => { Ok(limited_req) => {
let (parts, body) = limited_req.into_parts(); let (parts, body) = limited_req.into_parts();
let body = to_bytes(body) let body =
.await to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; (parts, body)
(parts, body) },
} Err(original_req) => {
Err(original_req) => { let (parts, body) = original_req.into_parts();
let (parts, body) = original_req.into_parts(); let body =
let body = to_bytes(body) to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
.await (parts, body)
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; },
(parts, body) };
}
};
let metadata = T::METADATA; let metadata = T::METADATA;
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?; let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
let path_params: Path<Vec<String>> = parts.extract().await?; let path_params: Path<Vec<String>> = parts.extract().await?;
let query = parts.uri.query().unwrap_or_default(); let query = parts.uri.query().unwrap_or_default();
let query_params: QueryParams = match serde_html_form::from_str(query) { let query_params: QueryParams = match serde_html_form::from_str(query) {
Ok(params) => params, Ok(params) => params,
Err(e) => { Err(e) => {
error!(%query, "Failed to deserialize query parameters: {}", e); error!(%query, "Failed to deserialize query parameters: {}", e);
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"));
ErrorKind::Unknown, },
"Failed to read query parameters", };
));
}
};
let token = match &auth_header { let token = match &auth_header {
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
None => query_params.access_token.as_deref(), None => query_params.access_token.as_deref(),
}; };
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let appservices = services().appservice.all().unwrap(); let appservices = services().appservice.all().unwrap();
let appservice_registration = appservices let appservice_registration =
.iter() appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
.find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
let (sender_user, sender_device, sender_servername, from_appservice) = let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
if let Some((_id, registration)) = appservice_registration { appservice_registration
match metadata.authentication { {
AuthScheme::AccessToken => { match metadata.authentication {
let user_id = query_params.user_id.map_or_else( AuthScheme::AccessToken => {
|| { let user_id = query_params.user_id.map_or_else(
UserId::parse_with_server_name( || {
registration.sender_localpart.as_str(), UserId::parse_with_server_name(
services().globals.server_name(), registration.sender_localpart.as_str(),
) services().globals.server_name(),
.unwrap() )
}, .unwrap()
|s| UserId::parse(s).unwrap(), },
); |s| UserId::parse(s).unwrap(),
);
if !services().users.exists(&user_id).unwrap() { if !services().users.exists(&user_id).unwrap() {
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist."));
ErrorKind::Forbidden, }
"User does not exist.",
));
}
// TODO: Check if appservice is allowed to be that user // TODO: Check if appservice is allowed to be that user
(Some(user_id), None, None, true) (Some(user_id), None, None, true)
} },
AuthScheme::ServerSignatures => (None, None, None, true), AuthScheme::ServerSignatures => (None, None, None, true),
AuthScheme::None => (None, None, None, true), AuthScheme::None => (None, None, None, true),
} }
} else { } else {
match metadata.authentication { match metadata.authentication {
AuthScheme::AccessToken => { AuthScheme::AccessToken => {
let token = match token { let token = match token {
Some(token) => token, Some(token) => token,
_ => { _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
return Err(Error::BadRequest( };
ErrorKind::MissingToken,
"Missing access token.",
))
}
};
match services().users.find_from_token(token).unwrap() { match services().users.find_from_token(token).unwrap() {
None => { None => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false }, ErrorKind::UnknownToken {
"Unknown access token.", soft_logout: false,
)) },
} "Unknown access token.",
Some((user_id, device_id)) => ( ))
Some(user_id), },
Some(OwnedDeviceId::from(device_id)), Some((user_id, device_id)) => {
None, (Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
false, },
), }
} },
} AuthScheme::ServerSignatures => {
AuthScheme::ServerSignatures => { let TypedHeader(Authorization(x_matrix)) =
let TypedHeader(Authorization(x_matrix)) = parts parts.extract::<TypedHeader<Authorization<XMatrix>>>().await.map_err(|e| {
.extract::<TypedHeader<Authorization<XMatrix>>>() warn!("Missing or invalid Authorization header: {}", e);
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
let msg = match e.reason() { let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => { TypedHeaderRejectionReason::Missing => "Missing Authorization header.",
"Missing Authorization header." TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.",
} _ => "Unknown header-related error",
TypedHeaderRejectionReason::Error(_) => { };
"Invalid X-Matrix signatures."
}
_ => "Unknown header-related error",
};
Error::BadRequest(ErrorKind::Forbidden, msg) Error::BadRequest(ErrorKind::Forbidden, msg)
})?; })?;
let origin_signatures = BTreeMap::from_iter([( let origin_signatures =
x_matrix.key.clone(), BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]);
CanonicalJsonValue::String(x_matrix.sig),
)]);
let signatures = BTreeMap::from_iter([( let signatures = BTreeMap::from_iter([(
x_matrix.origin.as_str().to_owned(), x_matrix.origin.as_str().to_owned(),
CanonicalJsonValue::Object(origin_signatures), CanonicalJsonValue::Object(origin_signatures),
)]); )]);
let server_destination = let server_destination = services().globals.server_name().as_str().to_owned();
services().globals.server_name().as_str().to_owned();
if let Some(destination) = x_matrix.destination.as_ref() { if let Some(destination) = x_matrix.destination.as_ref() {
if destination != &server_destination { if destination != &server_destination {
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Forbidden, "Invalid authorization."));
ErrorKind::Forbidden, }
"Invalid authorization.", }
));
}
}
let mut request_map = BTreeMap::from_iter([ let mut request_map = BTreeMap::from_iter([
( ("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())),
"method".to_owned(), ("uri".to_owned(), CanonicalJsonValue::String(parts.uri.to_string())),
CanonicalJsonValue::String(parts.method.to_string()), (
), "origin".to_owned(),
( CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
"uri".to_owned(), ),
CanonicalJsonValue::String(parts.uri.to_string()), ("destination".to_owned(), CanonicalJsonValue::String(server_destination)),
), ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)),
( ]);
"origin".to_owned(),
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
),
(
"destination".to_owned(),
CanonicalJsonValue::String(server_destination),
),
(
"signatures".to_owned(),
CanonicalJsonValue::Object(signatures),
),
]);
if let Some(json_body) = &json_body { if let Some(json_body) = &json_body {
request_map.insert("content".to_owned(), json_body.clone()); request_map.insert("content".to_owned(), json_body.clone());
}; };
let keys_result = services() let keys_result = services()
.rooms .rooms
.event_handler .event_handler
.fetch_signing_keys_for_server( .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()])
&x_matrix.origin, .await;
vec![x_matrix.key.clone()],
)
.await;
let keys = match keys_result { let keys = match keys_result {
Ok(b) => b, Ok(b) => b,
Err(e) => { Err(e) => {
warn!("Failed to fetch signing keys: {}", e); warn!("Failed to fetch signing keys: {}", e);
return Err(Error::BadRequest( return Err(Error::BadRequest(ErrorKind::Forbidden, "Failed to fetch signing keys."));
ErrorKind::Forbidden, },
"Failed to fetch signing keys.", };
));
}
};
let pub_key_map = let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
match ruma::signatures::verify_json(&pub_key_map, &request_map) { match ruma::signatures::verify_json(&pub_key_map, &request_map) {
Ok(()) => (None, None, Some(x_matrix.origin), false), Ok(()) => (None, None, Some(x_matrix.origin), false),
Err(e) => { Err(e) => {
warn!( warn!(
"Failed to verify json request from {}: {}\n{:?}", "Failed to verify json request from {}: {}\n{:?}",
x_matrix.origin, e, request_map x_matrix.origin, e, request_map
); );
if parts.uri.to_string().contains('@') { if parts.uri.to_string().contains('@') {
warn!( warn!(
"Request uri contained '@' character. Make sure your \ "Request uri contained '@' character. Make sure your reverse proxy gives Conduit \
reverse proxy gives Conduit the raw uri (apache: use \ the raw uri (apache: use nocanon)"
nocanon)" );
); }
}
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Failed to verify X-Matrix signatures.", "Failed to verify X-Matrix signatures.",
)); ));
} },
} }
} },
AuthScheme::None => match parts.uri.path() { AuthScheme::None => match parts.uri.path() {
// allow_public_room_directory_without_auth // allow_public_room_directory_without_auth
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
if !services() if !services().globals.config.allow_public_room_directory_without_auth {
.globals let token = match token {
.config Some(token) => token,
.allow_public_room_directory_without_auth _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
{ };
let token = match token {
Some(token) => token,
_ => {
return Err(Error::BadRequest(
ErrorKind::MissingToken,
"Missing access token.",
))
}
};
match services().users.find_from_token(token).unwrap() { match services().users.find_from_token(token).unwrap() {
None => { None => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false }, ErrorKind::UnknownToken {
"Unknown access token.", soft_logout: false,
)) },
} "Unknown access token.",
Some((user_id, device_id)) => ( ))
Some(user_id), },
Some(OwnedDeviceId::from(device_id)), Some((user_id, device_id)) => {
None, (Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
false, },
), }
} } else {
} else { (None, None, None, false)
(None, None, None, false) }
} },
} _ => (None, None, None, false),
_ => (None, None, None, false), },
}, }
} };
};
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers; *http_request.headers_mut().unwrap() = parts.headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| { let user_id = sender_user.clone().unwrap_or_else(|| {
UserId::parse_with_server_name("", services().globals.server_name()) UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid")
.expect("we know this is valid") });
});
let uiaa_request = json_body let uiaa_request = json_body
.get("auth") .get("auth")
.and_then(|auth| auth.as_object()) .and_then(|auth| auth.as_object())
.and_then(|auth| auth.get("session")) .and_then(|auth| auth.get("session"))
.and_then(|session| session.as_str()) .and_then(|session| session.as_str())
.and_then(|session| { .and_then(|session| {
services().uiaa.get_uiaa_request( services().uiaa.get_uiaa_request(
&user_id, &user_id,
&sender_device.clone().unwrap_or_else(|| "".into()), &sender_device.clone().unwrap_or_else(|| "".into()),
session, session,
) )
}); });
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
for (key, value) in initial_request { for (key, value) in initial_request {
json_body.entry(key).or_insert(value); json_body.entry(key).or_insert(value);
} }
} }
let mut buf = BytesMut::new().writer(); let mut buf = BytesMut::new().writer();
serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail"); serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail");
body = buf.into_inner().freeze(); body = buf.into_inner().freeze();
} }
let http_request = http_request.body(&*body).unwrap(); let http_request = http_request.body(&*body).unwrap();
debug!("{:?}", http_request); debug!("{:?}", http_request);
let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { let body = T::try_from_http_request(http_request, &path_params).map_err(|e| {
warn!("try_from_http_request failed: {:?}", e); warn!("try_from_http_request failed: {:?}", e);
debug!("JSON body: {:?}", json_body); debug!("JSON body: {:?}", json_body);
Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.")
})?; })?;
Ok(Ruma { Ok(Ruma {
body, body,
sender_user, sender_user,
sender_device, sender_device,
sender_servername, sender_servername,
from_appservice, from_appservice,
json_body, json_body,
}) })
} }
} }
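A minimal sketch of what an endpoint sees once the extractor above has run; the helper is hypothetical and the request type is left generic:

#[allow(dead_code)]
fn inspect_request_sketch<T>(r: &Ruma<T>) {
	// sender_user / sender_servername are filled in by the AuthScheme handling above
	if let Some(user) = &r.sender_user {
		debug!("request authenticated as {user} (appservice: {})", r.from_appservice);
	} else if let Some(server) = &r.sender_servername {
		debug!("request signed by federation peer {server}");
	} else {
		debug!("unauthenticated request");
	}
}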
struct XMatrix { struct XMatrix {
origin: OwnedServerName, origin: OwnedServerName,
destination: Option<String>, destination: Option<String>,
key: String, // KeyName? key: String, // KeyName?
sig: String, sig: String,
} }
impl Credentials for XMatrix { impl Credentials for XMatrix {
const SCHEME: &'static str = "X-Matrix"; const SCHEME: &'static str = "X-Matrix";
fn decode(value: &http::HeaderValue) -> Option<Self> { fn decode(value: &http::HeaderValue) -> Option<Self> {
debug_assert!( debug_assert!(
value.as_bytes().starts_with(b"X-Matrix "), value.as_bytes().starts_with(b"X-Matrix "),
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
); );
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start();
.ok()?
.trim_start();
let mut origin = None; let mut origin = None;
let mut destination = None; let mut destination = None;
let mut key = None; let mut key = None;
let mut sig = None; let mut sig = None;
for entry in parameters.split_terminator(',') { for entry in parameters.split_terminator(',') {
let (name, value) = entry.split_once('=')?; let (name, value) = entry.split_once('=')?;
// It's not at all clear why some fields are quoted and others not in the spec, // It's not at all clear why some fields are quoted and others not in the spec,
// let's simply accept either form for every field. // let's simply accept either form for every field.
let value = value let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value);
.strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"'))
.unwrap_or(value);
// FIXME: Catch multiple fields of the same name // FIXME: Catch multiple fields of the same name
match name { match name {
"origin" => origin = Some(value.try_into().ok()?), "origin" => origin = Some(value.try_into().ok()?),
"key" => key = Some(value.to_owned()), "key" => key = Some(value.to_owned()),
"sig" => sig = Some(value.to_owned()), "sig" => sig = Some(value.to_owned()),
"destination" => destination = Some(value.to_owned()), "destination" => destination = Some(value.to_owned()),
_ => debug!( _ => debug!("Unexpected field `{}` in X-Matrix Authorization header", name),
"Unexpected field `{}` in X-Matrix Authorization header", }
name }
),
}
}
Some(Self { Some(Self {
origin: origin?, origin: origin?,
key: key?, key: key?,
sig: sig?, sig: sig?,
destination, destination,
}) })
} }
fn encode(&self) -> http::HeaderValue { fn encode(&self) -> http::HeaderValue { todo!() }
todo!()
}
} }
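To make the field splitting in decode concrete, a small test sketch with a made-up header value; note that quoted and unquoted forms are both accepted:

#[cfg(test)]
mod xmatrix_decode_sketch {
	use super::*;

	#[test]
	fn decodes_a_typical_header() {
		// hypothetical header value, for illustration only
		let value = http::HeaderValue::from_static(
			"X-Matrix origin=\"other.example\",destination=\"this.example\",key=\"ed25519:1\",sig=abc123",
		);
		let creds = XMatrix::decode(&value).expect("header should parse");
		assert_eq!(creds.origin.as_str(), "other.example");
		assert_eq!(creds.key, "ed25519:1");
		assert_eq!(creds.sig, "abc123");
		assert_eq!(creds.destination.as_deref(), Some("this.example"));
	}
}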
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self.0.try_into_http_response::<BytesMut>() { match self.0.try_into_http_response::<BytesMut>() {
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
} }
} }
} }
// copied from hyper under the following license: // copied from hyper under the following license:
@@ -443,32 +376,32 @@ impl<T: OutgoingResponse> IntoResponse for RumaResponse<T>
// THE SOFTWARE. // THE SOFTWARE.
pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error> pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
where where
T: HttpBody, T: HttpBody,
{ {
futures_util::pin_mut!(body); futures_util::pin_mut!(body);
// If there's only 1 chunk, we can just return Buf::to_bytes() // If there's only 1 chunk, we can just return Buf::to_bytes()
let mut first = if let Some(buf) = body.data().await { let mut first = if let Some(buf) = body.data().await {
buf? buf?
} else { } else {
return Ok(Bytes::new()); return Ok(Bytes::new());
}; };
let second = if let Some(buf) = body.data().await { let second = if let Some(buf) = body.data().await {
buf? buf?
} else { } else {
return Ok(first.copy_to_bytes(first.remaining())); return Ok(first.copy_to_bytes(first.remaining()));
}; };
// With more than 1 buf, we gotta flatten into a Vec first. // With more than 1 buf, we gotta flatten into a Vec first.
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
let mut vec = Vec::with_capacity(cap); let mut vec = Vec::with_capacity(cap);
vec.put(first); vec.put(first);
vec.put(second); vec.put(second);
while let Some(buf) = body.data().await { while let Some(buf) = body.data().await {
vec.put(buf?); vec.put(buf?);
} }
Ok(vec.into()) Ok(vec.into())
} }
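A hedged usage sketch of the helper above, assuming the Full body type already imported in this file is the single-chunk, infallible one from the http-body family:

#[allow(dead_code)]
async fn to_bytes_sketch() {
	let body = Full::new(Bytes::from_static(b"hello world"));
	// a Full body yields exactly one chunk, so this hits the early-return path
	let bytes = to_bytes(body).await.expect("a Full body cannot fail");
	assert_eq!(&bytes[..], b"hello world");
}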


@@ -1,43 +1,36 @@
use crate::Error;
use ruma::{
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
OwnedUserId,
};
use std::ops::Deref; use std::ops::Deref;
use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId};
use crate::Error;
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
mod axum; mod axum;
/// Extractor for Ruma request structs /// Extractor for Ruma request structs
pub struct Ruma<T> { pub struct Ruma<T> {
pub body: T, pub body: T,
pub sender_user: Option<OwnedUserId>, pub sender_user: Option<OwnedUserId>,
pub sender_device: Option<OwnedDeviceId>, pub sender_device: Option<OwnedDeviceId>,
pub sender_servername: Option<OwnedServerName>, pub sender_servername: Option<OwnedServerName>,
// This is None when body is not a valid string // This is None when body is not a valid string
pub json_body: Option<CanonicalJsonValue>, pub json_body: Option<CanonicalJsonValue>,
pub from_appservice: bool, pub from_appservice: bool,
} }
impl<T> Deref for Ruma<T> { impl<T> Deref for Ruma<T> {
type Target = T; type Target = T;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target { &self.body }
&self.body
}
} }
#[derive(Clone)] #[derive(Clone)]
pub struct RumaResponse<T>(pub T); pub struct RumaResponse<T>(pub T);
impl<T> From<T> for RumaResponse<T> { impl<T> From<T> for RumaResponse<T> {
fn from(t: T) -> Self { fn from(t: T) -> Self { Self(t) }
Self(t)
}
} }
impl From<Error> for RumaResponse<UiaaResponse> { impl From<Error> for RumaResponse<UiaaResponse> {
fn from(t: Error) -> Self { fn from(t: Error) -> Self { t.to_response() }
t.to_response()
}
} }
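A tiny sketch of what the Deref impl above buys: handlers can reach the inner request type through the wrapper without naming .body:

#[allow(dead_code)]
fn inner_request_sketch<T>(r: &Ruma<T>) -> &T {
	// explicit form of the deref coercion handlers rely on: &Ruma<T> -> &T
	&**r
}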

File diff suppressed because it is too large


@@ -1,9 +1,9 @@
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
fmt, fmt,
fmt::Write as _, fmt::Write as _,
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
path::PathBuf, path::PathBuf,
}; };
use either::Either; use either::Either;
@@ -21,539 +21,464 @@ mod proxy;
#[derive(Deserialize, Clone, Debug)] #[derive(Deserialize, Clone, Debug)]
#[serde(transparent)] #[serde(transparent)]
pub struct ListeningPort { pub struct ListeningPort {
#[serde(with = "either::serde_untagged")] #[serde(with = "either::serde_untagged")]
pub ports: Either<u16, Vec<u16>>, pub ports: Either<u16, Vec<u16>>,
} }
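For illustration, the transparent wrapper plus untagged Either above accepts both a single port and a list; serde_json stands in here for the real TOML/Figment pipeline and is only an assumption for the sketch:

#[cfg(test)]
mod listening_port_sketch {
	use super::*;

	#[test]
	fn single_or_multiple_ports() {
		let single: ListeningPort = serde_json::from_str("8008").unwrap();
		assert_eq!(single.ports.left(), Some(8008));

		let multiple: ListeningPort = serde_json::from_str("[80, 443]").unwrap();
		assert_eq!(multiple.ports.right(), Some(vec![80, 443]));
	}
}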
/// all the config options for conduwuit /// all the config options for conduwuit
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Config { pub struct Config {
/// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6)
#[serde(default = "default_address")] #[serde(default = "default_address")]
pub address: IpAddr, pub address: IpAddr,
/// default TCP port(s) conduwuit will listen on /// default TCP port(s) conduwuit will listen on
#[serde(default = "default_port")] #[serde(default = "default_port")]
pub port: ListeningPort, pub port: ListeningPort,
pub tls: Option<TlsConfig>, pub tls: Option<TlsConfig>,
pub unix_socket_path: Option<PathBuf>, pub unix_socket_path: Option<PathBuf>,
#[serde(default = "default_unix_socket_perms")] #[serde(default = "default_unix_socket_perms")]
pub unix_socket_perms: u32, pub unix_socket_perms: u32,
pub server_name: OwnedServerName, pub server_name: OwnedServerName,
#[serde(default = "default_database_backend")] #[serde(default = "default_database_backend")]
pub database_backend: String, pub database_backend: String,
pub database_path: String, pub database_path: String,
#[serde(default = "default_db_cache_capacity_mb")] #[serde(default = "default_db_cache_capacity_mb")]
pub db_cache_capacity_mb: f64, pub db_cache_capacity_mb: f64,
#[serde(default = "default_new_user_displayname_suffix")] #[serde(default = "default_new_user_displayname_suffix")]
pub new_user_displayname_suffix: String, pub new_user_displayname_suffix: String,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub allow_check_for_updates: bool, pub allow_check_for_updates: bool,
#[serde(default = "default_conduit_cache_capacity_modifier")] #[serde(default = "default_conduit_cache_capacity_modifier")]
pub conduit_cache_capacity_modifier: f64, pub conduit_cache_capacity_modifier: f64,
#[serde(default = "default_pdu_cache_capacity")] #[serde(default = "default_pdu_cache_capacity")]
pub pdu_cache_capacity: u32, pub pdu_cache_capacity: u32,
#[serde(default = "default_cleanup_second_interval")] #[serde(default = "default_cleanup_second_interval")]
pub cleanup_second_interval: u32, pub cleanup_second_interval: u32,
#[serde(default = "default_max_request_size")] #[serde(default = "default_max_request_size")]
pub max_request_size: u32, pub max_request_size: u32,
#[serde(default = "default_max_concurrent_requests")] #[serde(default = "default_max_concurrent_requests")]
pub max_concurrent_requests: u16, pub max_concurrent_requests: u16,
#[serde(default = "default_max_fetch_prev_events")] #[serde(default = "default_max_fetch_prev_events")]
pub max_fetch_prev_events: u16, pub max_fetch_prev_events: u16,
#[serde(default)] #[serde(default)]
pub allow_registration: bool, pub allow_registration: bool,
#[serde(default)] #[serde(default)]
pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool,
pub registration_token: Option<String>, pub registration_token: Option<String>,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub allow_encryption: bool, pub allow_encryption: bool,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub allow_federation: bool, pub allow_federation: bool,
#[serde(default)] #[serde(default)]
pub allow_public_room_directory_over_federation: bool, pub allow_public_room_directory_over_federation: bool,
#[serde(default)] #[serde(default)]
pub allow_public_room_directory_without_auth: bool, pub allow_public_room_directory_without_auth: bool,
#[serde(default)] #[serde(default)]
pub allow_device_name_federation: bool, pub allow_device_name_federation: bool,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub allow_room_creation: bool, pub allow_room_creation: bool,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub allow_unstable_room_versions: bool, pub allow_unstable_room_versions: bool,
#[serde(default = "default_default_room_version")] #[serde(default = "default_default_room_version")]
pub default_room_version: RoomVersionId, pub default_room_version: RoomVersionId,
pub well_known_client: Option<String>, pub well_known_client: Option<String>,
pub well_known_server: Option<String>, pub well_known_server: Option<String>,
#[serde(default)] #[serde(default)]
pub allow_jaeger: bool, pub allow_jaeger: bool,
#[serde(default)] #[serde(default)]
pub tracing_flame: bool, pub tracing_flame: bool,
#[serde(default)] #[serde(default)]
pub proxy: ProxyConfig, pub proxy: ProxyConfig,
pub jwt_secret: Option<String>, pub jwt_secret: Option<String>,
#[serde(default = "default_trusted_servers")] #[serde(default = "default_trusted_servers")]
pub trusted_servers: Vec<OwnedServerName>, pub trusted_servers: Vec<OwnedServerName>,
#[serde(default = "true_fn")] #[serde(default = "true_fn")]
pub query_trusted_key_servers_first: bool, pub query_trusted_key_servers_first: bool,
#[serde(default = "default_log")] #[serde(default = "default_log")]
pub log: String, pub log: String,
#[serde(default)] #[serde(default)]
pub turn_username: String, pub turn_username: String,
#[serde(default)] #[serde(default)]
pub turn_password: String, pub turn_password: String,
#[serde(default = "Vec::new")] #[serde(default = "Vec::new")]
pub turn_uris: Vec<String>, pub turn_uris: Vec<String>,
#[serde(default)] #[serde(default)]
pub turn_secret: String, pub turn_secret: String,
#[serde(default = "default_turn_ttl")] #[serde(default = "default_turn_ttl")]
pub turn_ttl: u64, pub turn_ttl: u64,
#[serde(default = "default_rocksdb_log_level")] #[serde(default = "default_rocksdb_log_level")]
pub rocksdb_log_level: String, pub rocksdb_log_level: String,
#[serde(default = "default_rocksdb_max_log_file_size")] #[serde(default = "default_rocksdb_max_log_file_size")]
pub rocksdb_max_log_file_size: usize, pub rocksdb_max_log_file_size: usize,
#[serde(default = "default_rocksdb_log_time_to_roll")] #[serde(default = "default_rocksdb_log_time_to_roll")]
pub rocksdb_log_time_to_roll: usize, pub rocksdb_log_time_to_roll: usize,
#[serde(default)] #[serde(default)]
pub rocksdb_optimize_for_spinning_disks: bool, pub rocksdb_optimize_for_spinning_disks: bool,
#[serde(default = "default_rocksdb_parallelism_threads")] #[serde(default = "default_rocksdb_parallelism_threads")]
pub rocksdb_parallelism_threads: usize, pub rocksdb_parallelism_threads: usize,
pub emergency_password: Option<String>, pub emergency_password: Option<String>,
#[serde(default = "default_notification_push_path")] #[serde(default = "default_notification_push_path")]
pub notification_push_path: String, pub notification_push_path: String,
#[serde(default)] #[serde(default)]
pub allow_local_presence: bool, pub allow_local_presence: bool,
#[serde(default)] #[serde(default)]
pub allow_incoming_presence: bool, pub allow_incoming_presence: bool,
#[serde(default)] #[serde(default)]
pub allow_outgoing_presence: bool, pub allow_outgoing_presence: bool,
#[serde(default = "default_presence_idle_timeout_s")] #[serde(default = "default_presence_idle_timeout_s")]
pub presence_idle_timeout_s: u64, pub presence_idle_timeout_s: u64,
#[serde(default = "default_presence_offline_timeout_s")] #[serde(default = "default_presence_offline_timeout_s")]
pub presence_offline_timeout_s: u64, pub presence_offline_timeout_s: u64,
#[serde(default)] #[serde(default)]
pub zstd_compression: bool, pub zstd_compression: bool,
#[serde(default)] #[serde(default)]
pub allow_guest_registration: bool, pub allow_guest_registration: bool,
#[serde(default = "Vec::new")] #[serde(default = "Vec::new")]
pub prevent_media_downloads_from: Vec<OwnedServerName>, pub prevent_media_downloads_from: Vec<OwnedServerName>,
#[serde(default = "default_ip_range_denylist")] #[serde(default = "default_ip_range_denylist")]
pub ip_range_denylist: Vec<String>, pub ip_range_denylist: Vec<String>,
#[serde(default = "Vec::new")] #[serde(default = "Vec::new")]
pub url_preview_domain_contains_allowlist: Vec<String>, pub url_preview_domain_contains_allowlist: Vec<String>,
#[serde(default = "Vec::new")] #[serde(default = "Vec::new")]
pub url_preview_domain_explicit_allowlist: Vec<String>, pub url_preview_domain_explicit_allowlist: Vec<String>,
#[serde(default = "Vec::new")] #[serde(default = "Vec::new")]
pub url_preview_url_contains_allowlist: Vec<String>, pub url_preview_url_contains_allowlist: Vec<String>,
#[serde(default = "default_url_preview_max_spider_size")] #[serde(default = "default_url_preview_max_spider_size")]
pub url_preview_max_spider_size: usize, pub url_preview_max_spider_size: usize,
#[serde(default)] #[serde(default)]
pub url_preview_check_root_domain: bool, pub url_preview_check_root_domain: bool,
#[serde(default = "RegexSet::empty")] #[serde(default = "RegexSet::empty")]
#[serde(with = "serde_regex")] #[serde(with = "serde_regex")]
pub forbidden_room_names: RegexSet, pub forbidden_room_names: RegexSet,
#[serde(default = "RegexSet::empty")] #[serde(default = "RegexSet::empty")]
#[serde(with = "serde_regex")] #[serde(with = "serde_regex")]
pub forbidden_usernames: RegexSet, pub forbidden_usernames: RegexSet,
#[serde(default)] #[serde(default)]
pub block_non_admin_invites: bool, pub block_non_admin_invites: bool,
#[serde(flatten)] #[serde(flatten)]
pub catchall: BTreeMap<String, IgnoredAny>, pub catchall: BTreeMap<String, IgnoredAny>,
} }
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct TlsConfig { pub struct TlsConfig {
pub certs: String, pub certs: String,
pub key: String, pub key: String,
#[serde(default)] #[serde(default)]
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!) /// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
/// Only works / does something if the `axum_dual_protocol` feature flag was built /// Only works / does something if the `axum_dual_protocol` feature flag was
pub dual_protocol: bool, /// built
pub dual_protocol: bool,
} }
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config { impl Config {
/// Iterates over all the keys in the config file and warns if there is a deprecated key specified /// Iterates over all the keys in the config file and warns if there is a
pub fn warn_deprecated(&self) { /// deprecated key specified
debug!("Checking for deprecated config keys"); pub fn warn_deprecated(&self) {
let mut was_deprecated = false; debug!("Checking for deprecated config keys");
for key in self let mut was_deprecated = false;
.catchall for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) {
.keys() warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
.filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) was_deprecated = true;
{ }
warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
was_deprecated = true;
}
if was_deprecated { if was_deprecated {
warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted"); warn!(
} "Read conduit documentation and check your configuration if any new configuration parameters should \
} be adjusted"
);
}
}
/// iterates over all the catchall keys (unknown config options) and warns if there are any. /// iterates over all the catchall keys (unknown config options) and warns
pub fn warn_unknown_key(&self) { /// if there are any.
debug!("Checking for unknown config keys"); pub fn warn_unknown_key(&self) {
for key in self.catchall.keys().filter( debug!("Checking for unknown config keys");
|key| "config".to_owned().ne(key.to_owned()), /* "config" is expected */ for key in
) { self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */)
warn!( {
"Config parameter \"{}\" is unknown to conduwuit, ignoring.", warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key);
key }
); }
}
}
/// Checks the presence of the `address` and `unix_socket_path` keys in the raw_config, exiting the process if both keys were detected. /// Checks the presence of the `address` and `unix_socket_path` keys in the
pub fn is_dual_listening(&self, raw_config: Figment) -> bool { /// raw_config, exiting the process if both keys were detected.
let check_address = raw_config.find_value("address"); pub fn is_dual_listening(&self, raw_config: Figment) -> bool {
let check_unix_socket = raw_config.find_value("unix_socket_path"); let check_address = raw_config.find_value("address");
let check_unix_socket = raw_config.find_value("unix_socket_path");
// are the check_address and check_unix_socket keys both Ok (specified) at the same time? // are the check_address and check_unix_socket keys both Ok (specified) at the
if check_address.is_ok() && check_unix_socket.is_ok() { // same time?
error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option."); if check_address.is_ok() && check_unix_socket.is_ok() {
return true; error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option.");
} return true;
}
false false
} }
} }
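A hedged sketch of how the checks above might be wired together at startup; the file name and environment prefix are assumptions for illustration, not taken from this diff:

#[allow(dead_code)]
fn load_config_sketch() -> Result<Config, figment::Error> {
	use figment::{
		providers::{Env, Format, Toml},
		Figment,
	};

	let raw = Figment::new()
		.merge(Toml::file("conduwuit.toml"))
		.merge(Env::prefixed("CONDUWUIT_"));
	let config: Config = raw.extract()?;

	config.warn_deprecated();
	config.warn_unknown_key();
	if config.is_dual_listening(raw) {
		// both `address` and `unix_socket_path` were set; a real caller would abort here
	}

	Ok(config)
}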
impl fmt::Display for Config { impl fmt::Display for Config {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// Prepare a list of config values to show // Prepare a list of config values to show
let lines = [ let lines = [
("Server name", self.server_name.host()), ("Server name", self.server_name.host()),
("Database backend", &self.database_backend), ("Database backend", &self.database_backend),
("Database path", &self.database_path), ("Database path", &self.database_path),
( ("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()),
"Database cache capacity (MB)", ("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()),
&self.db_cache_capacity_mb.to_string(), ("PDU cache capacity", &self.pdu_cache_capacity.to_string()),
), ("Cleanup interval in seconds", &self.cleanup_second_interval.to_string()),
( ("Maximum request size (bytes)", &self.max_request_size.to_string()),
"Cache capacity modifier", ("Maximum concurrent requests", &self.max_concurrent_requests.to_string()),
&self.conduit_cache_capacity_modifier.to_string(), ("Allow registration", &self.allow_registration.to_string()),
), (
("PDU cache capacity", &self.pdu_cache_capacity.to_string()), "Registration token",
( match self.registration_token {
"Cleanup interval in seconds", Some(_) => "set",
&self.cleanup_second_interval.to_string(), None => "not set (open registration!)",
), },
("Maximum request size (bytes)", &self.max_request_size.to_string()), ),
( (
"Maximum concurrent requests", "Allow guest registration (inherently false if allow registration is false)",
&self.max_concurrent_requests.to_string(), &self.allow_guest_registration.to_string(),
), ),
( ("New user display name suffix", &self.new_user_displayname_suffix),
"Allow registration", ("Allow encryption", &self.allow_encryption.to_string()),
&self.allow_registration.to_string(), ("Allow federation", &self.allow_federation.to_string()),
), (
( "Allow incoming federated presence requests (updates)",
"Registration token", &self.allow_incoming_presence.to_string(),
match self.registration_token { ),
Some(_) => "set", (
None => "not set (open registration!)", "Allow outgoing federated presence requests (updates)",
}, &self.allow_outgoing_presence.to_string(),
), ),
( (
"Allow guest registration (inherently false if allow registration is false)", "Allow local presence requests (updates)",
&self.allow_guest_registration.to_string(), &self.allow_local_presence.to_string(),
), ),
( (
"New user display name suffix", "Block non-admin room invites (local and remote, admins can still send and receive invites)",
&self.new_user_displayname_suffix, &self.block_non_admin_invites.to_string(),
), ),
("Allow encryption", &self.allow_encryption.to_string()), ("Allow device name federation", &self.allow_device_name_federation.to_string()),
("Allow federation", &self.allow_federation.to_string()), ("Notification push path", &self.notification_push_path),
( ("Allow room creation", &self.allow_room_creation.to_string()),
"Allow incoming federated presence requests (updates)", (
&self.allow_incoming_presence.to_string(), "Allow public room directory over federation",
), &self.allow_public_room_directory_over_federation.to_string(),
( ),
"Allow outgoing federated presence requests (updates)", (
&self.allow_outgoing_presence.to_string(), "Allow public room directory without authentication",
), &self.allow_public_room_directory_without_auth.to_string(),
( ),
"Allow local presence requests (updates)", (
&self.allow_local_presence.to_string(), "JWT secret",
), match self.jwt_secret {
( Some(_) => "set",
"Block non-admin room invites (local and remote, admins can still send and receive invites)", None => "not set",
&self.block_non_admin_invites.to_string(), },
), ),
( ("Trusted servers", {
"Allow device name federation", let mut lst = vec![];
&self.allow_device_name_federation.to_string(), for server in &self.trusted_servers {
), lst.push(server.host());
("Notification push path", &self.notification_push_path), }
("Allow room creation", &self.allow_room_creation.to_string()), &lst.join(", ")
( }),
"Allow public room directory over federation", (
&self.allow_public_room_directory_over_federation.to_string(), "Query Trusted Key Servers First",
), &self.query_trusted_key_servers_first.to_string(),
( ),
"Allow public room directory without authentication", (
&self.allow_public_room_directory_without_auth.to_string(), "TURN username",
), if self.turn_username.is_empty() {
( "not set"
"JWT secret", } else {
match self.jwt_secret { &self.turn_username
Some(_) => "set", },
None => "not set", ),
}, ("TURN password", {
), if self.turn_password.is_empty() {
("Trusted servers", { "not set"
let mut lst = vec![]; } else {
for server in &self.trusted_servers { "set"
lst.push(server.host()); }
} }),
&lst.join(", ") ("TURN secret", {
}), if self.turn_secret.is_empty() {
( "not set"
"Query Trusted Key Servers First", } else {
&self.query_trusted_key_servers_first.to_string(), "set"
), }
( }),
"TURN username", ("Turn TTL", &self.turn_ttl.to_string()),
if self.turn_username.is_empty() { ("Turn URIs", {
"not set" let mut lst = vec![];
} else { for item in self.turn_uris.iter().cloned().enumerate() {
&self.turn_username let (_, uri): (usize, String) = item;
}, lst.push(uri);
), }
("TURN password", { &lst.join(", ")
if self.turn_password.is_empty() { }),
"not set" ("zstd Response Body Compression", &self.zstd_compression.to_string()),
} else { ("RocksDB database log level", &self.rocksdb_log_level),
"set" ("RocksDB database log time-to-roll", &self.rocksdb_log_time_to_roll.to_string()),
} (
}), "RocksDB database max log file size",
("TURN secret", { &self.rocksdb_max_log_file_size.to_string(),
if self.turn_secret.is_empty() { ),
"not set" (
} else { "RocksDB database optimize for spinning disks",
"set" &self.rocksdb_optimize_for_spinning_disks.to_string(),
} ),
}), ("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()),
("Turn TTL", &self.turn_ttl.to_string()), ("Prevent Media Downloads From", {
("Turn URIs", { let mut lst = vec![];
let mut lst = vec![]; for domain in &self.prevent_media_downloads_from {
for item in self.turn_uris.iter().cloned().enumerate() { lst.push(domain.host());
let (_, uri): (usize, String) = item; }
lst.push(uri); &lst.join(", ")
} }),
&lst.join(", ") ("Outbound Request IP Range Denylist", {
}), let mut lst = vec![];
( for item in self.ip_range_denylist.iter().cloned().enumerate() {
"zstd Response Body Compression", let (_, ip): (usize, String) = item;
&self.zstd_compression.to_string(), lst.push(ip);
), }
("RocksDB database log level", &self.rocksdb_log_level), &lst.join(", ")
( }),
"RocksDB database log time-to-roll", ("Forbidden usernames", {
&self.rocksdb_log_time_to_roll.to_string(), &self.forbidden_usernames.patterns().iter().join(", ")
), }),
( ("Forbidden room names", {
"RocksDB database max log file size", &self.forbidden_room_names.patterns().iter().join(", ")
&self.rocksdb_max_log_file_size.to_string(), }),
), (
( "URL preview domain contains allowlist",
"RocksDB database optimize for spinning disks", &self.url_preview_domain_contains_allowlist.join(", "),
&self.rocksdb_optimize_for_spinning_disks.to_string(), ),
), (
( "URL preview domain explicit allowlist",
"RocksDB Parallelism Threads", &self.url_preview_domain_explicit_allowlist.join(", "),
&self.rocksdb_parallelism_threads.to_string(), ),
), (
("Prevent Media Downloads From", { "URL preview URL contains allowlist",
let mut lst = vec![]; &self.url_preview_url_contains_allowlist.join(", "),
for domain in &self.prevent_media_downloads_from { ),
lst.push(domain.host()); ("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()),
} ("URL preview check root domain", &self.url_preview_check_root_domain.to_string()),
&lst.join(", ") ];
}),
("Outbound Request IP Range Denylist", {
let mut lst = vec![];
for item in self.ip_range_denylist.iter().cloned().enumerate() {
let (_, ip): (usize, String) = item;
lst.push(ip);
}
&lst.join(", ")
}),
("Forbidden usernames", {
&self.forbidden_usernames.patterns().iter().join(", ")
}),
("Forbidden room names", {
&self.forbidden_room_names.patterns().iter().join(", ")
}),
(
"URL preview domain contains allowlist",
&self.url_preview_domain_contains_allowlist.join(", "),
),
(
"URL preview domain explicit allowlist",
&self.url_preview_domain_explicit_allowlist.join(", "),
),
(
"URL preview URL contains allowlist",
&self.url_preview_url_contains_allowlist.join(", "),
),
(
"URL preview maximum spider size",
&self.url_preview_max_spider_size.to_string(),
),
(
"URL preview check root domain",
&self.url_preview_check_root_domain.to_string(),
),
];
let mut msg: String = "Active config values:\n\n".to_owned(); let mut msg: String = "Active config values:\n\n".to_owned();
for line in lines.into_iter().enumerate() { for line in lines.into_iter().enumerate() {
let _ = writeln!(msg, "{}: {}", line.1 .0, line.1 .1); let _ = writeln!(msg, "{}: {}", line.1 .0, line.1 .1);
} }
write!(f, "{msg}") write!(f, "{msg}")
} }
} }
fn true_fn() -> bool { fn true_fn() -> bool { true }
true
}
fn default_address() -> IpAddr { fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() }
Ipv4Addr::LOCALHOST.into()
}
fn default_port() -> ListeningPort { fn default_port() -> ListeningPort {
ListeningPort { ListeningPort {
ports: Either::Left(8008), ports: Either::Left(8008),
} }
} }
fn default_unix_socket_perms() -> u32 { fn default_unix_socket_perms() -> u32 { 660 }
660
}
fn default_database_backend() -> String { fn default_database_backend() -> String { "rocksdb".to_owned() }
"rocksdb".to_owned()
}
fn default_db_cache_capacity_mb() -> f64 { fn default_db_cache_capacity_mb() -> f64 { 300.0 }
300.0
}
fn default_conduit_cache_capacity_modifier() -> f64 { fn default_conduit_cache_capacity_modifier() -> f64 { 1.0 }
1.0
}
fn default_pdu_cache_capacity() -> u32 { fn default_pdu_cache_capacity() -> u32 { 150_000 }
150_000
}
fn default_cleanup_second_interval() -> u32 { fn default_cleanup_second_interval() -> u32 {
60 // every minute 60 // every minute
} }
fn default_max_request_size() -> u32 { fn default_max_request_size() -> u32 {
20 * 1024 * 1024 // Default to 20 MB 20 * 1024 * 1024 // Default to 20 MB
} }
fn default_max_concurrent_requests() -> u16 { fn default_max_concurrent_requests() -> u16 { 500 }
500
}
fn default_max_fetch_prev_events() -> u16 { fn default_max_fetch_prev_events() -> u16 { 100_u16 }
100_u16
}
fn default_trusted_servers() -> Vec<OwnedServerName> { fn default_trusted_servers() -> Vec<OwnedServerName> { vec![OwnedServerName::try_from("matrix.org").unwrap()] }
vec![OwnedServerName::try_from("matrix.org").unwrap()]
}
fn default_log() -> String { fn default_log() -> String { "warn,state_res=warn".to_owned() }
"warn,state_res=warn".to_owned()
}
fn default_notification_push_path() -> String { fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() }
"/_matrix/push/v1/notify".to_owned()
}
fn default_turn_ttl() -> u64 { fn default_turn_ttl() -> u64 { 60 * 60 * 24 }
60 * 60 * 24
}
fn default_presence_idle_timeout_s() -> u64 { fn default_presence_idle_timeout_s() -> u64 { 5 * 60 }
5 * 60
}
fn default_presence_offline_timeout_s() -> u64 { fn default_presence_offline_timeout_s() -> u64 { 30 * 60 }
30 * 60
}
fn default_rocksdb_log_level() -> String { fn default_rocksdb_log_level() -> String { "warn".to_owned() }
"warn".to_owned()
}
fn default_rocksdb_log_time_to_roll() -> usize { fn default_rocksdb_log_time_to_roll() -> usize { 0 }
0
}
fn default_rocksdb_parallelism_threads() -> usize { fn default_rocksdb_parallelism_threads() -> usize { num_cpus::get_physical() / 2 }
num_cpus::get_physical() / 2
}
// I know, it's a great name // I know, it's a great name
pub(crate) fn default_default_room_version() -> RoomVersionId { pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 }
RoomVersionId::V10
}
fn default_rocksdb_max_log_file_size() -> usize { fn default_rocksdb_max_log_file_size() -> usize {
// 4 megabytes // 4 megabytes
4 * 1024 * 1024 4 * 1024 * 1024
} }
fn default_ip_range_denylist() -> Vec<String> { fn default_ip_range_denylist() -> Vec<String> {
vec![ vec![
"127.0.0.0/8".to_owned(), "127.0.0.0/8".to_owned(),
"10.0.0.0/8".to_owned(), "10.0.0.0/8".to_owned(),
"172.16.0.0/12".to_owned(), "172.16.0.0/12".to_owned(),
"192.168.0.0/16".to_owned(), "192.168.0.0/16".to_owned(),
"100.64.0.0/10".to_owned(), "100.64.0.0/10".to_owned(),
"192.0.0.0/24".to_owned(), "192.0.0.0/24".to_owned(),
"169.254.0.0/16".to_owned(), "169.254.0.0/16".to_owned(),
"192.88.99.0/24".to_owned(), "192.88.99.0/24".to_owned(),
"198.18.0.0/15".to_owned(), "198.18.0.0/15".to_owned(),
"192.0.2.0/24".to_owned(), "192.0.2.0/24".to_owned(),
"198.51.100.0/24".to_owned(), "198.51.100.0/24".to_owned(),
"203.0.113.0/24".to_owned(), "203.0.113.0/24".to_owned(),
"224.0.0.0/4".to_owned(), "224.0.0.0/4".to_owned(),
"::1/128".to_owned(), "::1/128".to_owned(),
"fe80::/10".to_owned(), "fe80::/10".to_owned(),
"fc00::/7".to_owned(), "fc00::/7".to_owned(),
"2001:db8::/32".to_owned(), "2001:db8::/32".to_owned(),
"ff00::/8".to_owned(), "ff00::/8".to_owned(),
"fec0::/10".to_owned(), "fec0::/10".to_owned(),
] ]
} }
fn default_url_preview_max_spider_size() -> usize { fn default_url_preview_max_spider_size() -> usize {
1_000_000 // 1MB 1_000_000 // 1MB
} }
fn default_new_user_displayname_suffix() -> String { fn default_new_user_displayname_suffix() -> String { "🏳️‍⚧️".to_owned() }
"🏳️‍⚧️".to_owned()
}


@@ -24,119 +24,124 @@ use crate::Result;
/// ## Include vs. Exclude /// ## Include vs. Exclude
/// If include is an empty list, it is assumed to be `["*"]`. /// If include is an empty list, it is assumed to be `["*"]`.
/// ///
/// If a domain matches both the exclude and include list, the proxy will only be used if it was /// If a domain matches both the exclude and include list, the proxy will only
/// included because of a more specific rule than it was excluded. In the above example, the proxy /// be used if it was included because of a more specific rule than it was
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. /// excluded. In the above example, the proxy would be used for
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
#[derive(Clone, Default, Debug, Deserialize)] #[derive(Clone, Default, Debug, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum ProxyConfig { pub enum ProxyConfig {
#[default] #[default]
None, None,
Global { Global {
#[serde(deserialize_with = "crate::utils::deserialize_from_str")] #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url, url: Url,
}, },
ByDomain(Vec<PartialProxyConfig>), ByDomain(Vec<PartialProxyConfig>),
} }
impl ProxyConfig { impl ProxyConfig {
pub fn to_proxy(&self) -> Result<Option<Proxy>> { pub fn to_proxy(&self) -> Result<Option<Proxy>> {
Ok(match self.clone() { Ok(match self.clone() {
ProxyConfig::None => None, ProxyConfig::None => None,
ProxyConfig::Global { url } => Some(Proxy::all(url)?), ProxyConfig::Global {
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { url,
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy } => Some(Proxy::all(url)?),
})), ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
}) proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching
} // proxy
})),
})
}
} }
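A hedged sketch of plugging the result of to_proxy into a reqwest client builder, since the Proxy type above is reqwest's:

#[allow(dead_code)]
fn proxied_client_sketch(config: &ProxyConfig) -> Result<reqwest::ClientBuilder> {
	let mut builder = reqwest::Client::builder();
	if let Some(proxy) = config.to_proxy()? {
		builder = builder.proxy(proxy);
	}
	Ok(builder)
}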
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct PartialProxyConfig { pub struct PartialProxyConfig {
#[serde(deserialize_with = "crate::utils::deserialize_from_str")] #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url, url: Url,
#[serde(default)] #[serde(default)]
include: Vec<WildCardedDomain>, include: Vec<WildCardedDomain>,
#[serde(default)] #[serde(default)]
exclude: Vec<WildCardedDomain>, exclude: Vec<WildCardedDomain>,
} }
impl PartialProxyConfig { impl PartialProxyConfig {
pub fn for_url(&self, url: &Url) -> Option<&Url> { pub fn for_url(&self, url: &Url) -> Option<&Url> {
let domain = url.domain()?; let domain = url.domain()?;
let mut included_because = None; // most specific reason it was included let mut included_because = None; // most specific reason it was included
let mut excluded_because = None; // most specific reason it was excluded let mut excluded_because = None; // most specific reason it was excluded
if self.include.is_empty() { if self.include.is_empty() {
// treat empty include list as `*` // treat empty include list as `*`
included_because = Some(&WildCardedDomain::WildCard); included_because = Some(&WildCardedDomain::WildCard);
} }
for wc_domain in &self.include { for wc_domain in &self.include {
if wc_domain.matches(domain) { if wc_domain.matches(domain) {
match included_because { match included_because {
Some(prev) if !wc_domain.more_specific_than(prev) => (), Some(prev) if !wc_domain.more_specific_than(prev) => (),
_ => included_because = Some(wc_domain), _ => included_because = Some(wc_domain),
} }
} }
} }
for wc_domain in &self.exclude { for wc_domain in &self.exclude {
if wc_domain.matches(domain) { if wc_domain.matches(domain) {
match excluded_because { match excluded_because {
Some(prev) if !wc_domain.more_specific_than(prev) => (), Some(prev) if !wc_domain.more_specific_than(prev) => (),
_ => excluded_because = Some(wc_domain), _ => excluded_because = Some(wc_domain),
} }
} }
} }
match (included_because, excluded_because) { match (included_because, excluded_because) {
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for a more specific reason */
(Some(_), None) => Some(&self.url), // than excluded
_ => None, (Some(_), None) => Some(&self.url),
} _ => None,
} }
}
} }
/// A domain name, that optionally allows a * as its first subdomain. /// A domain name, that optionally allows a * as its first subdomain.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum WildCardedDomain { enum WildCardedDomain {
WildCard, WildCard,
WildCarded(String), WildCarded(String),
Exact(String), Exact(String),
} }
impl WildCardedDomain { impl WildCardedDomain {
fn matches(&self, domain: &str) -> bool { fn matches(&self, domain: &str) -> bool {
match self { match self {
WildCardedDomain::WildCard => true, WildCardedDomain::WildCard => true,
WildCardedDomain::WildCarded(d) => domain.ends_with(d), WildCardedDomain::WildCarded(d) => domain.ends_with(d),
WildCardedDomain::Exact(d) => domain == d, WildCardedDomain::Exact(d) => domain == d,
} }
} }
fn more_specific_than(&self, other: &Self) -> bool {
match (self, other) { fn more_specific_than(&self, other: &Self) -> bool {
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, match (self, other) {
(_, WildCardedDomain::WildCard) => true, (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), (_, WildCardedDomain::WildCard) => true,
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
a != b && a.ends_with(b) (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && a.ends_with(b),
} _ => false,
_ => false, }
} }
}
} }
impl std::str::FromStr for WildCardedDomain { impl std::str::FromStr for WildCardedDomain {
type Err = std::convert::Infallible; type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// maybe do some domain validation? fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(if s.starts_with("*.") { // maybe do some domain validation?
WildCardedDomain::WildCarded(s[1..].to_owned()) Ok(if s.starts_with("*.") {
} else if s == "*" { WildCardedDomain::WildCarded(s[1..].to_owned())
WildCardedDomain::WildCarded("".to_owned()) } else if s == "*" {
} else { WildCardedDomain::WildCarded("".to_owned())
WildCardedDomain::Exact(s.to_owned()) } else {
}) WildCardedDomain::Exact(s.to_owned())
} })
}
} }
impl<'de> Deserialize<'de> for WildCardedDomain { impl<'de> Deserialize<'de> for WildCardedDomain {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
D: serde::de::Deserializer<'de>, D: serde::de::Deserializer<'de>,
{ {
crate::utils::deserialize_from_str(deserializer) crate::utils::deserialize_from_str(deserializer)
} }
} }
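To make the include/exclude precedence concrete, a test sketch built from the example domains in the doc comment above; the proxy URL itself is made up:

#[cfg(test)]
mod proxy_matching_sketch {
	use super::*;

	#[test]
	fn doc_comment_example() {
		let proxy = PartialProxyConfig {
			url: Url::parse("socks5://localhost:9050").unwrap(),
			include: vec!["*.onion".parse().unwrap(), "matrix.myspecial.onion".parse().unwrap()],
			exclude: vec!["*.myspecial.onion".parse().unwrap()],
		};

		// included: the exact include rule is more specific than the wildcard exclude
		assert!(proxy.for_url(&Url::parse("https://matrix.myspecial.onion").unwrap()).is_some());
		// excluded: only the wildcard include matches, and the exclude is more specific
		assert!(proxy.for_url(&Url::parse("https://hello.myspecial.onion").unwrap()).is_none());
		// included: the wildcard include matches and no exclude rule does
		assert!(proxy.for_url(&Url::parse("https://ordinary.onion").unwrap()).is_some());
	}
}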


@@ -1,8 +1,8 @@
use std::{future::Future, pin::Pin, sync::Arc};
use super::Config; use super::Config;
use crate::Result; use crate::Result;
use std::{future::Future, pin::Pin, sync::Arc};
#[cfg(feature = "sqlite")] #[cfg(feature = "sqlite")]
pub mod sqlite; pub mod sqlite;
@@ -13,53 +13,44 @@ pub(crate) mod rocksdb;
pub(crate) mod watchers; pub(crate) mod watchers;
pub(crate) trait KeyValueDatabaseEngine: Send + Sync { pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
fn open(config: &Config) -> Result<Self> fn open(config: &Config) -> Result<Self>
where where
Self: Sized; Self: Sized;
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>; fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
fn flush(&self) -> Result<()>; fn flush(&self) -> Result<()>;
fn cleanup(&self) -> Result<()> { fn cleanup(&self) -> Result<()> { Ok(()) }
Ok(()) fn memory_usage(&self) -> Result<String> {
} Ok("Current database engine does not support memory usage reporting.".to_owned())
fn memory_usage(&self) -> Result<String> { }
Ok("Current database engine does not support memory usage reporting.".to_owned())
}
#[allow(dead_code)] #[allow(dead_code)]
fn clear_caches(&self) {} fn clear_caches(&self) {}
} }
pub(crate) trait KvTree: Send + Sync { pub(crate) trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>; fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>;
fn remove(&self, key: &[u8]) -> Result<()>; fn remove(&self, key: &[u8]) -> Result<()>;
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn iter_from<'a>( fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
&'a self,
from: &[u8],
backwards: bool,
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>; fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
fn scan_prefix<'a>( fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
fn clear(&self) -> Result<()> { fn clear(&self) -> Result<()> {
for (key, _) in self.iter() { for (key, _) in self.iter() {
self.remove(&key)?; self.remove(&key)?;
} }
Ok(()) Ok(())
} }
} }
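For orientation, a hedged sketch of the two traits above working together; the tree name is made up:

#[allow(dead_code)]
fn kv_roundtrip_sketch(engine: &impl KeyValueDatabaseEngine) -> Result<()> {
	let tree = engine.open_tree("sketchtree")?;

	tree.insert(b"key", b"value")?;
	assert_eq!(tree.get(b"key")?, Some(b"value".to_vec()));

	// scan_prefix walks every entry whose key starts with the given bytes
	for (key, _value) in tree.scan_prefix(b"ke".to_vec()) {
		debug_assert!(key.starts_with(b"ke"));
	}

	tree.remove(b"key")?;
	Ok(())
}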


@@ -1,293 +1,265 @@
use std::{ use std::{
future::Future, future::Future,
pin::Pin, pin::Pin,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn}; use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn};
use tracing::{debug, info}; use tracing::{debug, info};
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{utils, Result}; use crate::{utils, Result};
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
pub(crate) struct Engine { pub(crate) struct Engine {
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>, rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
cache: rocksdb::Cache, cache: rocksdb::Cache,
old_cfs: Vec<String>, old_cfs: Vec<String>,
config: Config, config: Config,
} }
struct RocksDbEngineTree<'a> { struct RocksDbEngineTree<'a> {
db: Arc<Engine>, db: Arc<Engine>,
name: &'a str, name: &'a str,
watchers: Watchers, watchers: Watchers,
write_lock: RwLock<()>, write_lock: RwLock<()>,
} }
fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Options { fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Options {
// block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html# // block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html#
let mut block_based_options = rocksdb::BlockBasedOptions::default(); let mut block_based_options = rocksdb::BlockBasedOptions::default();
block_based_options.set_block_cache(rocksdb_cache); block_based_options.set_block_cache(rocksdb_cache);
// "Difference of spinning disk" // "Difference of spinning disk"
// https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html // https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html
block_based_options.set_block_size(64 * 1024); block_based_options.set_block_size(64 * 1024);
block_based_options.set_cache_index_and_filter_blocks(true); block_based_options.set_cache_index_and_filter_blocks(true);
// database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html# // database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html#
let mut db_opts = rocksdb::Options::default(); let mut db_opts = rocksdb::Options::default();
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() { let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
"debug" => Debug, "debug" => Debug,
"info" => Info, "info" => Info,
"error" => Error, "error" => Error,
"fatal" => Fatal, "fatal" => Fatal,
_ => Warn, _ => Warn,
}; };
let threads = if config.rocksdb_parallelism_threads == 0 { let threads = if config.rocksdb_parallelism_threads == 0 {
num_cpus::get_physical() // max cores if user specified 0 num_cpus::get_physical() // max cores if user specified 0
} else { } else {
config.rocksdb_parallelism_threads config.rocksdb_parallelism_threads
}; };
db_opts.set_log_level(rocksdb_log_level); db_opts.set_log_level(rocksdb_log_level);
db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size); db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll); db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
if config.rocksdb_optimize_for_spinning_disks { if config.rocksdb_optimize_for_spinning_disks {
db_opts.set_skip_stats_update_on_db_open(true); db_opts.set_skip_stats_update_on_db_open(true);
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for spinning hard drives. these are not really important db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for
} else { // spinning hard drives. these are not really
db_opts.set_skip_stats_update_on_db_open(false); // important
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024); } else {
db_opts.set_use_direct_reads(true); db_opts.set_skip_stats_update_on_db_open(false);
db_opts.set_use_direct_io_for_flush_and_compaction(true); db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
db_opts.set_keep_log_file_num(20); db_opts.set_use_direct_reads(true);
} db_opts.set_use_direct_io_for_flush_and_compaction(true);
db_opts.set_keep_log_file_num(20);
}
db_opts.set_block_based_table_factory(&block_based_options); db_opts.set_block_based_table_factory(&block_based_options);
db_opts.set_level_compaction_dynamic_level_bytes(true); db_opts.set_level_compaction_dynamic_level_bytes(true);
db_opts.create_if_missing(true); db_opts.create_if_missing(true);
db_opts.increase_parallelism( db_opts.increase_parallelism(
threads threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
.try_into() );
.expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"), //db_opts.set_max_open_files(config.rocksdb_max_open_files);
); db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
//db_opts.set_max_open_files(config.rocksdb_max_open_files); db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd); db_opts.optimize_level_style_compaction(10 * 1024 * 1024);
db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
db_opts.optimize_level_style_compaction(10 * 1024 * 1024);
// https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning // https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning
db_opts.set_max_background_jobs(6); db_opts.set_max_background_jobs(6);
db_opts.set_bytes_per_sync(1_048_576); db_opts.set_bytes_per_sync(1_048_576);
// https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords // https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
// //
// Unclean shutdowns of a Matrix homeserver are likely to be fine when // Unclean shutdowns of a Matrix homeserver are likely to be fine when
// recovered in this manner as it's likely any lost information will be // recovered in this manner as it's likely any lost information will be
// restored via federation. // restored via federation.
db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords); db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords);
let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1); let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1);
db_opts.set_prefix_extractor(prefix_extractor); db_opts.set_prefix_extractor(prefix_extractor);
db_opts db_opts
} }
impl KeyValueDatabaseEngine for Arc<Engine> { impl KeyValueDatabaseEngine for Arc<Engine> {
fn open(config: &Config) -> Result<Self> { fn open(config: &Config) -> Result<Self> {
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes); let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes);
let db_opts = db_options(&rocksdb_cache, config); let db_opts = db_options(&rocksdb_cache, config);
debug!("Listing column families in database"); debug!("Listing column families in database");
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf( let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(&db_opts, &config.database_path)
&db_opts, .unwrap_or_default();
&config.database_path,
)
.unwrap_or_default();
debug!("Opening column family descriptors in database"); debug!("Opening column family descriptors in database");
info!("RocksDB database compaction will take place now, a delay in startup is expected"); info!("RocksDB database compaction will take place now, a delay in startup is expected");
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors( let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
&db_opts, &db_opts,
&config.database_path, &config.database_path,
cfs.iter().map(|name| { cfs.iter().map(|name| rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))),
rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config)) )?;
}),
)?;
Ok(Arc::new(Engine { Ok(Arc::new(Engine {
rocks: db, rocks: db,
cache: rocksdb_cache, cache: rocksdb_cache,
old_cfs: cfs, old_cfs: cfs,
config: config.clone(), config: config.clone(),
})) }))
} }
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> { fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
if !self.old_cfs.contains(&name.to_owned()) { if !self.old_cfs.contains(&name.to_owned()) {
// Create if it didn't exist // Create if it didn't exist
debug!("Creating new column family in database: {}", name); debug!("Creating new column family in database: {}", name);
let _ = self let _ = self.rocks.create_cf(name, &db_options(&self.cache, &self.config));
.rocks }
.create_cf(name, &db_options(&self.cache, &self.config));
}
Ok(Arc::new(RocksDbEngineTree { Ok(Arc::new(RocksDbEngineTree {
name, name,
db: Arc::clone(self), db: Arc::clone(self),
watchers: Watchers::default(), watchers: Watchers::default(),
write_lock: RwLock::new(()), write_lock: RwLock::new(()),
})) }))
} }
fn flush(&self) -> Result<()> { fn flush(&self) -> Result<()> {
// TODO? // TODO?
Ok(()) Ok(())
} }
fn memory_usage(&self) -> Result<String> { fn memory_usage(&self) -> Result<String> {
let stats = let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; Ok(format!(
Ok(format!( "Approximate memory usage of all the mem-tables: {:.3} MB\nApproximate memory usage of un-flushed \
"Approximate memory usage of all the mem-tables: {:.3} MB\n\ mem-tables: {:.3} MB\nApproximate memory usage of all the table readers: {:.3} MB\nApproximate memory \
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\ usage by cache: {:.3} MB\nApproximate memory usage by cache pinned: {:.3} MB\n",
Approximate memory usage of all the table readers: {:.3} MB\n\ stats.mem_table_total as f64 / 1024.0 / 1024.0,
Approximate memory usage by cache: {:.3} MB\n\ stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
Approximate memory usage by cache pinned: {:.3} MB\n\ stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
", stats.cache_total as f64 / 1024.0 / 1024.0,
stats.mem_table_total as f64 / 1024.0 / 1024.0, self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0,
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, ))
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, }
stats.cache_total as f64 / 1024.0 / 1024.0,
self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0,
))
}
// TODO: figure out if this is needed for rocksdb // TODO: figure out if this is needed for rocksdb
#[allow(dead_code)] #[allow(dead_code)]
fn clear_caches(&self) {} fn clear_caches(&self) {}
} }
impl RocksDbEngineTree<'_> { impl RocksDbEngineTree<'_> {
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> { fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> { self.db.rocks.cf_handle(self.name).unwrap() }
self.db.rocks.cf_handle(self.name).unwrap()
}
} }
impl KvTree for RocksDbEngineTree<'_> { impl KvTree for RocksDbEngineTree<'_> {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) }
Ok(self.db.rocks.get_cf(&self.cf(), key)?)
}
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let lock = self.write_lock.read().unwrap(); let lock = self.write_lock.read().unwrap();
self.db.rocks.put_cf(&self.cf(), key, value)?; self.db.rocks.put_cf(&self.cf(), key, value)?;
drop(lock); drop(lock);
self.watchers.wake(key); self.watchers.wake(key);
Ok(()) Ok(())
} }
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> { fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
for (key, value) in iter { for (key, value) in iter {
self.db.rocks.put_cf(&self.cf(), key, value)?; self.db.rocks.put_cf(&self.cf(), key, value)?;
} }
Ok(()) Ok(())
} }
fn remove(&self, key: &[u8]) -> Result<()> { fn remove(&self, key: &[u8]) -> Result<()> { Ok(self.db.rocks.delete_cf(&self.cf(), key)?) }
Ok(self.db.rocks.delete_cf(&self.cf(), key)?)
}
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> { fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
Box::new( Box::new(
self.db self.db
.rocks .rocks
.iterator_cf(&self.cf(), rocksdb::IteratorMode::Start) .iterator_cf(&self.cf(), rocksdb::IteratorMode::Start)
.map(std::result::Result::unwrap) .map(std::result::Result::unwrap)
.map(|(k, v)| (Vec::from(k), Vec::from(v))), .map(|(k, v)| (Vec::from(k), Vec::from(v))),
) )
} }
fn iter_from<'a>( fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
&'a self, Box::new(
from: &[u8], self.db
backwards: bool, .rocks
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> { .iterator_cf(
Box::new( &self.cf(),
self.db rocksdb::IteratorMode::From(
.rocks from,
.iterator_cf( if backwards {
&self.cf(), rocksdb::Direction::Reverse
rocksdb::IteratorMode::From( } else {
from, rocksdb::Direction::Forward
if backwards { },
rocksdb::Direction::Reverse ),
} else { )
rocksdb::Direction::Forward .map(std::result::Result::unwrap)
}, .map(|(k, v)| (Vec::from(k), Vec::from(v))),
), )
) }
.map(std::result::Result::unwrap)
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
)
}
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let lock = self.write_lock.write().unwrap(); let lock = self.write_lock.write().unwrap();
let old = self.db.rocks.get_cf(&self.cf(), key)?; let old = self.db.rocks.get_cf(&self.cf(), key)?;
let new = utils::increment(old.as_deref()).unwrap(); let new = utils::increment(old.as_deref()).unwrap();
self.db.rocks.put_cf(&self.cf(), key, &new)?; self.db.rocks.put_cf(&self.cf(), key, &new)?;
drop(lock); drop(lock);
Ok(new) Ok(new)
} }
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> { fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
let lock = self.write_lock.write().unwrap(); let lock = self.write_lock.write().unwrap();
for key in iter { for key in iter {
let old = self.db.rocks.get_cf(&self.cf(), &key)?; let old = self.db.rocks.get_cf(&self.cf(), &key)?;
let new = utils::increment(old.as_deref()).unwrap(); let new = utils::increment(old.as_deref()).unwrap();
self.db.rocks.put_cf(&self.cf(), key, new)?; self.db.rocks.put_cf(&self.cf(), key, new)?;
} }
drop(lock); drop(lock);
Ok(()) Ok(())
} }
fn scan_prefix<'a>( fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
&'a self, Box::new(
prefix: Vec<u8>, self.db
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> { .rocks
Box::new( .iterator_cf(&self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward))
self.db .map(std::result::Result::unwrap)
.rocks .map(|(k, v)| (Vec::from(k), Vec::from(v)))
.iterator_cf( .take_while(move |(k, _)| k.starts_with(&prefix)),
&self.cf(), )
rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), }
)
.map(std::result::Result::unwrap)
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
.take_while(move |(k, _)| k.starts_with(&prefix)),
)
}
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
self.watchers.watch(prefix) self.watchers.watch(prefix)
} }
} }
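The `increment` and `increment_batch` methods above take the write half of `write_lock` because they read, bump and rewrite a counter, whereas a plain `insert` only needs the read half. A minimal sketch of such a byte-level counter bump, assuming counters are stored as big-endian `u64` bytes (the real `utils::increment` is not shown in this diff and may differ):

// Hypothetical counter bump over raw bytes; assumes the stored value is a
// big-endian u64 and that counting starts at 1 for a missing/malformed key.
fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
	let number = match old.map(TryInto::try_into) {
		Some(Ok(bytes)) => u64::from_be_bytes(bytes).wrapping_add(1),
		_ => 1,
	};
	Some(number.to_be_bytes().to_vec())
}

fn main() {
	assert_eq!(increment(None), Some(1u64.to_be_bytes().to_vec()));
	assert_eq!(increment(Some(&1u64.to_be_bytes())), Some(2u64.to_be_bytes().to_vec()));
}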


@ -1,340 +1,305 @@
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use std::{
use crate::{database::Config, Result}; cell::RefCell,
future::Future,
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
};
use parking_lot::{Mutex, MutexGuard}; use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use std::{
cell::RefCell,
future::Future,
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
};
use thread_local::ThreadLocal; use thread_local::ThreadLocal;
use tracing::debug; use tracing::debug;
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
thread_local! { thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None); static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None); static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
} }
struct PreparedStatementIterator<'a> { struct PreparedStatementIterator<'a> {
pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>, pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>,
pub _statement_ref: NonAliasingBox<rusqlite::Statement<'a>>, pub _statement_ref: NonAliasingBox<rusqlite::Statement<'a>>,
} }
impl Iterator for PreparedStatementIterator<'_> { impl Iterator for PreparedStatementIterator<'_> {
type Item = TupleOfBytes; type Item = TupleOfBytes;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> { self.iterator.next() }
self.iterator.next()
}
} }
struct NonAliasingBox<T>(*mut T); struct NonAliasingBox<T>(*mut T);
impl<T> Drop for NonAliasingBox<T> { impl<T> Drop for NonAliasingBox<T> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
let _ = Box::from_raw(self.0); let _ = Box::from_raw(self.0);
}; };
} }
} }
pub struct Engine { pub struct Engine {
writer: Mutex<Connection>, writer: Mutex<Connection>,
read_conn_tls: ThreadLocal<Connection>, read_conn_tls: ThreadLocal<Connection>,
read_iterator_conn_tls: ThreadLocal<Connection>, read_iterator_conn_tls: ThreadLocal<Connection>,
path: PathBuf, path: PathBuf,
cache_size_per_thread: u32, cache_size_per_thread: u32,
} }
impl Engine { impl Engine {
fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> {
let conn = Connection::open(path)?; let conn = Connection::open(path)?;
conn.pragma_update(Some(Main), "page_size", 2048)?; conn.pragma_update(Some(Main), "page_size", 2048)?;
conn.pragma_update(Some(Main), "journal_mode", "WAL")?; conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?; conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?;
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
Ok(conn) Ok(conn)
} }
fn write_lock(&self) -> MutexGuard<'_, Connection> { fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() }
self.writer.lock()
}
fn read_lock(&self) -> &Connection { fn read_lock(&self) -> &Connection {
self.read_conn_tls self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) }
}
fn read_lock_iterator(&self) -> &Connection { fn read_lock_iterator(&self) -> &Connection {
self.read_iterator_conn_tls self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) }
}
pub fn flush_wal(self: &Arc<Self>) -> Result<()> { pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock() self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; Ok(())
Ok(()) }
}
} }
impl KeyValueDatabaseEngine for Arc<Engine> { impl KeyValueDatabaseEngine for Arc<Engine> {
fn open(config: &Config) -> Result<Self> { fn open(config: &Config) -> Result<Self> {
let path = Path::new(&config.database_path).join("conduit.db"); let path = Path::new(&config.database_path).join("conduit.db");
// calculates cache-size per permanent connection // calculates cache-size per permanent connection
// 1. convert MB to KiB // 1. convert MB to KiB
// 2. divide by permanent connections + permanent iter connections + write connection // 2. divide by permanent connections + permanent iter connections + write
// 3. round down to nearest integer // connection
let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0) // 3. round down to nearest integer
/ ((num_cpus::get().max(1) * 2) + 1) as f64) let cache_size_per_thread: u32 =
as u32; ((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32;
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
let arc = Arc::new(Engine { let arc = Arc::new(Engine {
writer, writer,
read_conn_tls: ThreadLocal::new(), read_conn_tls: ThreadLocal::new(),
read_iterator_conn_tls: ThreadLocal::new(), read_iterator_conn_tls: ThreadLocal::new(),
path, path,
cache_size_per_thread, cache_size_per_thread,
}); });
Ok(arc) Ok(arc)
} }
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> { fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; self.write_lock().execute(
&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"),
[],
)?;
Ok(Arc::new(SqliteTable { Ok(Arc::new(SqliteTable {
engine: Arc::clone(self), engine: Arc::clone(self),
name: name.to_owned(), name: name.to_owned(),
watchers: Watchers::default(), watchers: Watchers::default(),
})) }))
} }
fn flush(&self) -> Result<()> { fn flush(&self) -> Result<()> {
// we enabled PRAGMA synchronous=normal, so this should not be necessary // we enabled PRAGMA synchronous=normal, so this should not be necessary
Ok(()) Ok(())
} }
fn cleanup(&self) -> Result<()> { fn cleanup(&self) -> Result<()> { self.flush_wal() }
self.flush_wal()
}
} }
pub struct SqliteTable { pub struct SqliteTable {
engine: Arc<Engine>, engine: Arc<Engine>,
name: String, name: String,
watchers: Watchers, watchers: Watchers,
} }
type TupleOfBytes = (Vec<u8>, Vec<u8>); type TupleOfBytes = (Vec<u8>, Vec<u8>);
impl SqliteTable { impl SqliteTable {
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(guard Ok(guard
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())?
.query_row([key], |row| row.get(0)) .query_row([key], |row| row.get(0))
.optional()?) .optional()?)
} }
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
guard.execute( guard.execute(
format!( format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(),
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", [key, value],
self.name )?;
) Ok(())
.as_str(), }
[key, value],
)?;
Ok(())
}
pub fn iter_with_guard<'a>( pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
&'a self, let statement = Box::leak(Box::new(
guard: &'a Connection, guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(),
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { ));
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement); let statement_ref = NonAliasingBox(statement);
//let name = self.name.clone(); //let name = self.name.clone();
let iterator = Box::new( let iterator = Box::new(
statement statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(move |r| r.unwrap()),
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) );
.unwrap()
.map(move |r| r.unwrap()),
);
Box::new(PreparedStatementIterator { Box::new(PreparedStatementIterator {
iterator, iterator,
_statement_ref: statement_ref, _statement_ref: statement_ref,
}) })
} }
} }
impl KvTree for SqliteTable { impl KvTree for SqliteTable {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { self.get_with_guard(self.engine.read_lock(), key) }
self.get_with_guard(self.engine.read_lock(), key)
}
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
self.insert_with_guard(&guard, key, value)?; self.insert_with_guard(&guard, key, value)?;
drop(guard); drop(guard);
self.watchers.wake(key); self.watchers.wake(key);
Ok(()) Ok(())
} }
fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> { fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?; guard.execute("BEGIN", [])?;
for (key, value) in iter { for (key, value) in iter {
self.insert_with_guard(&guard, &key, &value)?; self.insert_with_guard(&guard, &key, &value)?;
} }
guard.execute("COMMIT", [])?; guard.execute("COMMIT", [])?;
drop(guard); drop(guard);
Ok(()) Ok(())
} }
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> { fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?; guard.execute("BEGIN", [])?;
for key in iter { for key in iter {
let old = self.get_with_guard(&guard, &key)?; let old = self.get_with_guard(&guard, &key)?;
let new = crate::utils::increment(old.as_deref()) let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
.expect("utils::increment always returns Some"); self.insert_with_guard(&guard, &key, &new)?;
self.insert_with_guard(&guard, &key, &new)?; }
} guard.execute("COMMIT", [])?;
guard.execute("COMMIT", [])?;
drop(guard); drop(guard);
Ok(()) Ok(())
} }
fn remove(&self, key: &[u8]) -> Result<()> { fn remove(&self, key: &[u8]) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
guard.execute( guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?;
format!("DELETE FROM {} WHERE key = ?", self.name).as_str(),
[key],
)?;
Ok(()) Ok(())
} }
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock_iterator(); let guard = self.engine.read_lock_iterator();
self.iter_with_guard(guard) self.iter_with_guard(guard)
} }
fn iter_from<'a>( fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
&'a self, let guard = self.engine.read_lock_iterator();
from: &[u8], let from = from.to_vec(); // TODO change interface?
backwards: bool,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock_iterator();
let from = from.to_vec(); // TODO change interface?
//let name = self.name.clone(); //let name = self.name.clone();
if backwards { if backwards {
let statement = Box::leak(Box::new( let statement = Box::leak(Box::new(
guard guard
.prepare(&format!( .prepare(&format!(
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC",
&self.name &self.name
)) ))
.unwrap(), .unwrap(),
)); ));
let statement_ref = NonAliasingBox(statement); let statement_ref = NonAliasingBox(statement);
let iterator = Box::new( let iterator = Box::new(
statement statement
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap() .unwrap()
.map(move |r| r.unwrap()), .map(move |r| r.unwrap()),
); );
Box::new(PreparedStatementIterator { Box::new(PreparedStatementIterator {
iterator, iterator,
_statement_ref: statement_ref, _statement_ref: statement_ref,
}) })
} else { } else {
let statement = Box::leak(Box::new( let statement = Box::leak(Box::new(
guard guard
.prepare(&format!( .prepare(&format!(
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC",
&self.name &self.name
)) ))
.unwrap(), .unwrap(),
)); ));
let statement_ref = NonAliasingBox(statement); let statement_ref = NonAliasingBox(statement);
let iterator = Box::new( let iterator = Box::new(
statement statement
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap() .unwrap()
.map(move |r| r.unwrap()), .map(move |r| r.unwrap()),
); );
Box::new(PreparedStatementIterator { Box::new(PreparedStatementIterator {
iterator, iterator,
_statement_ref: statement_ref, _statement_ref: statement_ref,
}) })
} }
} }
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
let old = self.get_with_guard(&guard, key)?; let old = self.get_with_guard(&guard, key)?;
let new = let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
self.insert_with_guard(&guard, key, &new)?; self.insert_with_guard(&guard, key, &new)?;
Ok(new) Ok(new)
} }
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
Box::new( Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix)))
self.iter_from(&prefix, false) }
.take_while(move |(key, _)| key.starts_with(&prefix)),
)
}
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
self.watchers.watch(prefix) self.watchers.watch(prefix)
} }
fn clear(&self) -> Result<()> { fn clear(&self) -> Result<()> {
debug!("clear: running"); debug!("clear: running");
self.engine self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
.write_lock() debug!("clear: ran");
.execute(format!("DELETE FROM {}", self.name).as_str(), [])?; Ok(())
debug!("clear: ran"); }
Ok(())
}
} }
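The `open()` implementation above splits the configured cache between one writer connection plus a thread-local read connection and a thread-local iterator connection per CPU, so the per-connection SQLite cache shrinks as core count grows. A small worked sketch of that arithmetic, with illustrative inputs (300 MB and 4 CPUs are made-up values, not defaults from this codebase):

// Mirrors the cache budget computed in open(): MB -> KiB, divided across
// (2 * CPUs read connections + 1 writer), truncated to whole KiB.
fn cache_size_per_thread(db_cache_capacity_mb: f64, cpus: usize) -> u32 {
	((db_cache_capacity_mb * 1024.0) / ((cpus.max(1) * 2) + 1) as f64) as u32
}

fn main() {
	// 300 MB across 2 * 4 + 1 = 9 connections gives 34_133 KiB each (truncated).
	assert_eq!(cache_size_per_thread(300.0, 4), 34_133);
}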


@ -1,56 +1,55 @@
use std::{ use std::{
collections::{hash_map, HashMap}, collections::{hash_map, HashMap},
future::Future, future::Future,
pin::Pin, pin::Pin,
sync::RwLock, sync::RwLock,
}; };
use tokio::sync::watch; use tokio::sync::watch;
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>; type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
#[derive(Default)] #[derive(Default)]
pub(super) struct Watchers { pub(super) struct Watchers {
watchers: Watcher, watchers: Watcher,
} }
impl Watchers { impl Watchers {
pub(super) fn watch<'a>( pub(super) fn watch<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
&'a self, let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
prefix: &[u8], hash_map::Entry::Occupied(o) => o.get().1.clone(),
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { hash_map::Entry::Vacant(v) => {
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { let (tx, rx) = tokio::sync::watch::channel(());
hash_map::Entry::Occupied(o) => o.get().1.clone(), v.insert((tx, rx.clone()));
hash_map::Entry::Vacant(v) => { rx
let (tx, rx) = tokio::sync::watch::channel(()); },
v.insert((tx, rx.clone())); };
rx
}
};
Box::pin(async move { Box::pin(async move {
// Tx is never destroyed // Tx is never destroyed
rx.changed().await.unwrap(); rx.changed().await.unwrap();
}) })
} }
pub(super) fn wake(&self, key: &[u8]) {
let watchers = self.watchers.read().unwrap();
let mut triggered = Vec::new();
for length in 0..=key.len() { pub(super) fn wake(&self, key: &[u8]) {
if watchers.contains_key(&key[..length]) { let watchers = self.watchers.read().unwrap();
triggered.push(&key[..length]); let mut triggered = Vec::new();
}
}
drop(watchers); for length in 0..=key.len() {
if watchers.contains_key(&key[..length]) {
triggered.push(&key[..length]);
}
}
if !triggered.is_empty() { drop(watchers);
let mut watchers = self.watchers.write().unwrap();
for prefix in triggered { if !triggered.is_empty() {
if let Some(tx) = watchers.remove(prefix) { let mut watchers = self.watchers.write().unwrap();
let _ = tx.0.send(()); for prefix in triggered {
} if let Some(tx) = watchers.remove(prefix) {
} let _ = tx.0.send(());
}; }
} }
};
}
} }
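`wake()` above walks every prefix length of the written key so that a watcher registered on any leading slice of that key is notified, and it only upgrades to the write lock when something actually matched. A tiny standalone sketch of the prefix enumeration it relies on (not the real `Watchers` type):

// A write to b"ab" must wake watchers registered for "", "a" and "ab".
fn prefixes(key: &[u8]) -> impl Iterator<Item = &[u8]> {
	(0..=key.len()).map(move |len| &key[..len])
}

fn main() {
	let hit: Vec<&[u8]> = prefixes(b"ab").collect();
	assert_eq!(hit, vec![&b""[..], &b"a"[..], &b"ab"[..]]);
}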


@ -1,148 +1,120 @@
use std::collections::HashMap; use std::collections::HashMap;
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
use tracing::warn; use tracing::warn;
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::account_data::Data for KeyValueDatabase { impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] /// previous entry.
fn update( #[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
&self, fn update(
room_id: Option<&RoomId>, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
user_id: &UserId, data: &serde_json::Value,
event_type: RoomAccountDataEventType, ) -> Result<()> {
data: &serde_json::Value, let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
) -> Result<()> { prefix.push(0xFF);
let mut prefix = room_id prefix.extend_from_slice(user_id.as_bytes());
.map(std::string::ToString::to_string) prefix.push(0xFF);
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xff);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xff);
let mut roomuserdataid = prefix.clone(); let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xff); roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix; let mut key = prefix;
key.extend_from_slice(event_type.to_string().as_bytes()); key.extend_from_slice(event_type.to_string().as_bytes());
if data.get("type").is_none() || data.get("content").is_none() { if data.get("type").is_none() || data.get("content").is_none() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Account data doesn't have all required fields.", "Account data doesn't have all required fields.",
)); ));
} }
self.roomuserdataid_accountdata.insert( self.roomuserdataid_accountdata.insert(
&roomuserdataid, &roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"), &serde_json::to_vec(&data).expect("to_vec always works on json values"),
)?; )?;
let prev = self.roomusertype_roomuserdataid.get(&key)?; let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
.insert(&key, &roomuserdataid)?;
// Remove old entry // Remove old entry
if let Some(prev) = prev { if let Some(prev) = prev {
self.roomuserdataid_accountdata.remove(&prev)?; self.roomuserdataid_accountdata.remove(&prev)?;
} }
Ok(()) Ok(())
} }
/// Searches the account data for a specific kind. /// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, kind))] #[tracing::instrument(skip(self, room_id, user_id, kind))]
fn get( fn get(
&self, &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
room_id: Option<&RoomId>, ) -> Result<Option<Box<serde_json::value::RawValue>>> {
user_id: &UserId, let mut key = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
kind: RoomAccountDataEventType, key.push(0xFF);
) -> Result<Option<Box<serde_json::value::RawValue>>> { key.extend_from_slice(user_id.as_bytes());
let mut key = room_id key.push(0xFF);
.map(std::string::ToString::to_string) key.extend_from_slice(kind.to_string().as_bytes());
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xff);
key.extend_from_slice(user_id.as_bytes());
key.push(0xff);
key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid self.roomusertype_roomuserdataid
.get(&key)? .get(&key)?
.and_then(|roomuserdataid| { .and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose())
self.roomuserdataid_accountdata .transpose()?
.get(&roomuserdataid) .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose() .transpose()
}) }
.transpose()?
.map(|data| {
serde_json::from_slice(&data)
.map_err(|_| Error::bad_database("could not deserialize"))
})
.transpose()
}
/// Returns all changes to the account data that happened after `since`. /// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))] #[tracing::instrument(skip(self, room_id, user_id, since))]
fn changes_since( fn changes_since(
&self, &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
room_id: Option<&RoomId>, ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
user_id: &UserId, let mut userdata = HashMap::new();
since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new();
let mut prefix = room_id let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
.map(std::string::ToString::to_string) prefix.push(0xFF);
.unwrap_or_default() prefix.extend_from_slice(user_id.as_bytes());
.as_bytes() prefix.push(0xFF);
.to_vec();
prefix.push(0xff);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xff);
// Skip the data that's exactly at since, because we sent that last time // Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone(); let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since + 1).to_be_bytes()); first_possible.extend_from_slice(&(since + 1).to_be_bytes());
for r in self for r in self
.roomuserdataid_accountdata .roomuserdataid_accountdata
.iter_from(&first_possible, false) .iter_from(&first_possible, false)
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| { .map(|(k, v)| {
Ok::<_, Error>(( Ok::<_, Error>((
RoomAccountDataEventType::from( RoomAccountDataEventType::from(
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else( utils::string_from_bytes(
|| Error::bad_database("RoomUserData ID in db is invalid."), k.rsplit(|&b| b == 0xFF)
)?) .next()
.map_err(|e| { .ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
warn!("RoomUserData ID in database is invalid: {}", e); )
Error::bad_database("RoomUserData ID in db is invalid.") .map_err(|e| {
})?, warn!("RoomUserData ID in database is invalid: {}", e);
), Error::bad_database("RoomUserData ID in db is invalid.")
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| { })?,
Error::bad_database("Database contains invalid account data.") ),
})?, serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
)) .map_err(|_| Error::bad_database("Database contains invalid account data."))?,
}) ))
{ }) {
let (kind, data) = r?; let (kind, data) = r?;
userdata.insert(kind, data); userdata.insert(kind, data);
} }
Ok(userdata) Ok(userdata)
} }
} }
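The keys built above concatenate the room ID (empty for global account data), the user ID and the event type, separated by 0xFF bytes, which is why `changes_since()` can recover the event type with `rsplit(|&b| b == 0xFF)`; the `roomuserdataid` variant additionally embeds the big-endian counter. A small sketch of that key layout with made-up identifiers:

// Join key components with 0xFF, a byte assumed not to occur in the IDs.
fn make_key(parts: &[&[u8]]) -> Vec<u8> {
	let mut key = Vec::new();
	for (i, part) in parts.iter().enumerate() {
		if i > 0 {
			key.push(0xFF);
		}
		key.extend_from_slice(part);
	}
	key
}

fn main() {
	let key = make_key(&[b"!room:example.org", b"@alice:example.org", b"m.fully_read"]);
	// The last 0xFF-separated segment is the event type, as in changes_since().
	assert_eq!(key.rsplit(|&b| b == 0xFF).next().unwrap(), &b"m.fully_read"[..]);
}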


@ -3,78 +3,58 @@ use ruma::api::appservice::Registration;
use crate::{database::KeyValueDatabase, service, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::appservice::Data for KeyValueDatabase { impl service::appservice::Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> { fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str(); let id = yaml.id.as_str();
self.id_appserviceregistrations.insert( self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
id.as_bytes(), self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
)?;
self.cached_registrations
.write()
.unwrap()
.insert(id.to_owned(), yaml.clone());
Ok(id.to_owned()) Ok(id.to_owned())
} }
/// Remove an appservice registration /// Remove an appservice registration
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> { fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations self.id_appserviceregistrations.remove(service_name.as_bytes())?;
.remove(service_name.as_bytes())?; self.cached_registrations.write().unwrap().remove(service_name);
self.cached_registrations Ok(())
.write() }
.unwrap()
.remove(service_name);
Ok(())
}
fn get_registration(&self, id: &str) -> Result<Option<Registration>> { fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.cached_registrations self.cached_registrations.read().unwrap().get(id).map_or_else(
.read() || {
.unwrap() self.id_appserviceregistrations
.get(id) .get(id.as_bytes())?
.map_or_else( .map(|bytes| {
|| { serde_yaml::from_slice(&bytes).map_err(|_| {
self.id_appserviceregistrations Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
.get(id.as_bytes())? })
.map(|bytes| { })
serde_yaml::from_slice(&bytes).map_err(|_| { .transpose()
Error::bad_database( },
"Invalid registration bytes in id_appserviceregistrations.", |r| Ok(Some(r.clone())),
) )
}) }
})
.transpose()
},
|r| Ok(Some(r.clone())),
)
}
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map( Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
|(id, _)| { utils::string_from_bytes(&id)
utils::string_from_bytes(&id).map_err(|_| { .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
Error::bad_database("Invalid id bytes in id_appserviceregistrations.") })))
}) }
},
)))
}
fn all(&self) -> Result<Vec<(String, Registration)>> { fn all(&self) -> Result<Vec<(String, Registration)>> {
self.iter_ids()? self.iter_ids()?
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.map(move |id| { .map(move |id| {
Ok(( Ok((
id.clone(), id.clone(),
self.get_registration(&id)? self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"),
.expect("iter_ids only returns appservices that exist"), ))
)) })
}) .collect()
.collect() }
}
} }
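`get_registration()` above is a read-through cache: it answers from the in-memory `cached_registrations` map when possible and only deserializes from `id_appserviceregistrations` on a miss, while `register_appservice()` populates the cache on write. A stripped-down sketch of that pattern with plain `String`s standing in for the YAML `Registration` (names here are illustrative, not the real types):

use std::{collections::HashMap, sync::RwLock};

struct Cache {
	cached: RwLock<HashMap<String, String>>,
}

impl Cache {
	// Answer from memory first; `load` stands in for the database lookup.
	fn get(&self, id: &str, load: impl Fn(&str) -> Option<String>) -> Option<String> {
		if let Some(hit) = self.cached.read().unwrap().get(id) {
			return Some(hit.clone());
		}
		load(id)
	}
}

fn main() {
	let cache = Cache { cached: RwLock::new(HashMap::new()) };
	cache.cached.write().unwrap().insert("bridge".to_owned(), "registration yaml".to_owned());
	assert_eq!(cache.get("bridge", |_| None), Some("registration yaml".to_owned()));
	assert_eq!(cache.get("missing", |_| Some("from db".to_owned())), Some("from db".to_owned()));
}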


@ -4,9 +4,9 @@ use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache; use lru_cache::LruCache;
use ruma::{ use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey}, api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair, signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
@ -16,139 +16,118 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait] #[async_trait]
impl service::globals::Data for KeyValueDatabase { impl service::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> { fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?) utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes.")) .map_err(|_| Error::bad_database("Count has invalid bytes."))
} }
fn current_count(&self) -> Result<u64> { fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
.map_err(|_| Error::bad_database("Count has invalid bytes.")) })
}) }
}
fn last_check_for_updates_id(&self) -> Result<u64> { fn last_check_for_updates_id(&self) -> Result<u64> {
self.global self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| {
.get(LAST_CHECK_FOR_UPDATES_COUNT)? utils::u64_from_bytes(&bytes)
.map_or(Ok(0_u64), |bytes| { .map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
utils::u64_from_bytes(&bytes).map_err(|_| { })
Error::bad_database("last check for updates count has invalid bytes.") }
})
})
}
fn update_check_for_updates_id(&self, id: u64) -> Result<()> { fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(()) Ok(())
} }
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec(); let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone(); let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xff); userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone(); let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xff); userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new(); let mut futures = FuturesUnordered::new();
// Return when *any* user changed his key // Return when *any* user changed his key
// TODO: only send for user they share a room with // TODO: only send for user they share a room with
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push( futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix));
self.userroomid_notificationcount futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in // Events for rooms we are in
for room_id in services() for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
.rooms let short_roomid = services()
.state_cache .rooms
.rooms_joined(user_id) .short
.filter_map(std::result::Result::ok) .get_shortroomid(&room_id)
{ .ok()
let short_roomid = services() .flatten()
.rooms .expect("room exists")
.short .to_be_bytes()
.get_shortroomid(&room_id) .to_vec();
.ok()
.flatten()
.expect("room exists")
.to_be_bytes()
.to_vec();
let roomid_bytes = room_id.as_bytes().to_vec(); let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone(); let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xff); roomid_prefix.push(0xFF);
// PDUs // PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs // EDUs
futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes)); futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
// Key changes // Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
// Room account data // Room account data
let mut roomuser_prefix = roomid_prefix.clone(); let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix); roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push( futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix));
self.roomusertype_roomuserdataid }
.watch_prefix(&roomuser_prefix),
);
}
let mut globaluserdata_prefix = vec![0xff]; let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix); globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push( futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix));
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms) // More key changes (used when user is not joined to any rooms)
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys // One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch())); futures.push(Box::pin(services().globals.rotate.watch()));
// Wait until one of them finds something // Wait until one of them finds something
futures.next().await; futures.next().await;
Ok(()) Ok(())
} }
fn cleanup(&self) -> Result<()> { fn cleanup(&self) -> Result<()> { self.db.cleanup() }
self.db.cleanup()
}
fn memory_usage(&self) -> String { fn memory_usage(&self) -> String {
let pdu_cache = self.pdu_cache.lock().unwrap().len(); let pdu_cache = self.pdu_cache.lock().unwrap().len();
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len(); let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len(); let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len(); let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len();
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
let mut response = format!( let mut response = format!(
"\ "\
pdu_cache: {pdu_cache} pdu_cache: {pdu_cache}
shorteventid_cache: {shorteventid_cache} shorteventid_cache: {shorteventid_cache}
auth_chain_cache: {auth_chain_cache} auth_chain_cache: {auth_chain_cache}
@ -157,155 +136,137 @@ statekeyshort_cache: {statekeyshort_cache}
our_real_users_cache: {our_real_users_cache} our_real_users_cache: {our_real_users_cache}
appservice_in_room_cache: {appservice_in_room_cache} appservice_in_room_cache: {appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache}\n" lasttimelinecount_cache: {lasttimelinecount_cache}\n"
); );
if let Ok(db_stats) = self.db.memory_usage() { if let Ok(db_stats) = self.db.memory_usage() {
response += &db_stats; response += &db_stats;
} }
response response
} }
fn clear_caches(&self, amount: u32) { fn clear_caches(&self, amount: u32) {
if amount > 0 { if amount > 0 {
let c = &mut *self.pdu_cache.lock().unwrap(); let c = &mut *self.pdu_cache.lock().unwrap();
*c = LruCache::new(c.capacity()); *c = LruCache::new(c.capacity());
} }
if amount > 1 { if amount > 1 {
let c = &mut *self.shorteventid_cache.lock().unwrap(); let c = &mut *self.shorteventid_cache.lock().unwrap();
*c = LruCache::new(c.capacity()); *c = LruCache::new(c.capacity());
} }
if amount > 2 { if amount > 2 {
let c = &mut *self.auth_chain_cache.lock().unwrap(); let c = &mut *self.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity()); *c = LruCache::new(c.capacity());
} }
if amount > 3 { if amount > 3 {
let c = &mut *self.eventidshort_cache.lock().unwrap(); let c = &mut *self.eventidshort_cache.lock().unwrap();
*c = LruCache::new(c.capacity()); *c = LruCache::new(c.capacity());
} }
if amount > 4 { if amount > 4 {
let c = &mut *self.statekeyshort_cache.lock().unwrap(); let c = &mut *self.statekeyshort_cache.lock().unwrap();
*c = LruCache::new(c.capacity()); *c = LruCache::new(c.capacity());
} }
if amount > 5 { if amount > 5 {
let c = &mut *self.our_real_users_cache.write().unwrap(); let c = &mut *self.our_real_users_cache.write().unwrap();
*c = HashMap::new(); *c = HashMap::new();
} }
if amount > 6 { if amount > 6 {
let c = &mut *self.appservice_in_room_cache.write().unwrap(); let c = &mut *self.appservice_in_room_cache.write().unwrap();
*c = HashMap::new(); *c = HashMap::new();
} }
if amount > 7 { if amount > 7 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new(); *c = HashMap::new();
} }
} }
fn load_keypair(&self) -> Result<Ed25519KeyPair> { fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else( let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| { || {
let keypair = utils::generate_keypair(); let keypair = utils::generate_keypair();
self.global.insert(b"keypair", &keypair)?; self.global.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair) Ok::<_, Error>(keypair)
}, },
Ok, Ok,
)?; )?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes( utils::string_from_bytes(
// 1. version // 1. version
parts parts.next().expect("splitn always returns at least one element"),
.next() )
.expect("splitn always returns at least one element"), .map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
) .and_then(|version| {
.map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) // 2. key
.and_then(|version| { parts
// 2. key .next()
parts .ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
.next() .map(|key| (version, key))
.ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) })
.map(|key| (version, key)) .and_then(|(version, key)| {
}) Ed25519KeyPair::from_der(key, version)
.and_then(|(version, key)| { .map_err(|_| Error::bad_database("Private or public keys are invalid."))
Ed25519KeyPair::from_der(key, version) })
.map_err(|_| Error::bad_database("Private or public keys are invalid.")) }
})
}
fn remove_keypair(&self) -> Result<()> {
self.global.remove(b"keypair")
}
fn add_signing_key( fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
&self,
origin: &ServerName,
new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys fn add_signing_key(
.and_then(|keys| serde_json::from_slice(&keys).ok()) &self, origin: &ServerName, new_keys: ServerSigningKeys,
.unwrap_or_else(|| { ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Just insert "now", it doesn't matter // Not atomic, but this is not critical
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
});
let ServerSigningKeys { let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| {
verify_keys, // Just insert "now", it doesn't matter
old_verify_keys, ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
.. });
} = new_keys;
keys.verify_keys.extend(verify_keys); let ServerSigningKeys {
keys.old_verify_keys.extend(old_verify_keys); verify_keys,
old_verify_keys,
..
} = new_keys;
self.server_signingkeys.insert( keys.verify_keys.extend(verify_keys);
origin.as_bytes(), keys.old_verify_keys.extend(old_verify_keys);
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
let mut tree = keys.verify_keys; self.server_signingkeys.insert(
tree.extend( origin.as_bytes(),
keys.old_verify_keys &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
.into_iter() )?;
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree) let mut tree = keys.verify_keys;
} tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. Ok(tree)
fn signing_keys_for( }
&self,
origin: &ServerName,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map(|keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree
})
.unwrap_or_else(BTreeMap::new);
Ok(signingkeys) /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
} /// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map(|keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
tree
})
.unwrap_or_else(BTreeMap::new);
fn database_version(&self) -> Result<u64> { Ok(signingkeys)
self.global.get(b"version")?.map_or(Ok(0), |version| { }
utils::u64_from_bytes(&version)
.map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
fn bump_database_version(&self, new_version: u64) -> Result<()> { fn database_version(&self) -> Result<u64> {
self.global.insert(b"version", &new_version.to_be_bytes())?; self.global.get(b"version")?.map_or(Ok(0), |version| {
Ok(()) utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
} })
}
fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
} }


@ -1,364 +1,292 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use ruma::{ use ruma::{
api::client::{ api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
error::ErrorKind, error::ErrorKind,
}, },
serde::Raw, serde::Raw,
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::key_backups::Data for KeyValueDatabase { impl service::key_backups::Data for KeyValueDatabase {
fn create_backup( fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
&self, let version = services().globals.next_count()?.to_string();
user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert( self.backupid_algorithm.insert(
&key, &key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?; )?;
self.backupid_etag self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version)
Ok(version) }
}
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?; self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?; self.backupid_etag.remove(&key)?;
key.push(0xff); key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?; self.backupkeyid_backup.remove(&outdated_key)?;
} }
Ok(()) Ok(())
} }
fn update_backup( fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
&self, let mut key = user_id.as_bytes().to_vec();
		user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
	) -> Result<String> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		if self.backupid_algorithm.get(&key)?.is_none() {
			return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
		}
		self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?;
		self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
		Ok(version.to_owned())
	}

	fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
		let mut prefix = user_id.as_bytes().to_vec();
		prefix.push(0xFF);
		let mut last_possible_key = prefix.clone();
		last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
		self.backupid_algorithm
			.iter_from(&last_possible_key, true)
			.take_while(move |(k, _)| k.starts_with(&prefix))
			.next()
			.map(|(key, _)| {
				utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
					.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
			})
			.transpose()
	}

	fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
		let mut prefix = user_id.as_bytes().to_vec();
		prefix.push(0xFF);
		let mut last_possible_key = prefix.clone();
		last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
		self.backupid_algorithm
			.iter_from(&last_possible_key, true)
			.take_while(move |(k, _)| k.starts_with(&prefix))
			.next()
			.map(|(key, value)| {
				let version = utils::string_from_bytes(
					key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"),
				)
				.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;

				Ok((
					version,
					serde_json::from_slice(&value)
						.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
				))
			})
			.transpose()
	}

	fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
			serde_json::from_slice(&bytes)
				.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
		})
	}

	fn add_key(
		&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
	) -> Result<()> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		if self.backupid_algorithm.get(&key)?.is_none() {
			return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
		}
		self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
		key.push(0xFF);
		key.extend_from_slice(room_id.as_bytes());
		key.push(0xFF);
		key.extend_from_slice(session_id.as_bytes());
		self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?;
		Ok(())
	}

	fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
		let mut prefix = user_id.as_bytes().to_vec();
		prefix.push(0xFF);
		prefix.extend_from_slice(version.as_bytes());
		Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
	}

	fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		Ok(utils::u64_from_bytes(
			&self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
		)
		.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
		.to_string())
	}

	fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
		let mut prefix = user_id.as_bytes().to_vec();
		prefix.push(0xFF);
		prefix.extend_from_slice(version.as_bytes());
		prefix.push(0xFF);
		let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
		for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
			let mut parts = key.rsplit(|&b| b == 0xFF);
			let session_id = utils::string_from_bytes(
				parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
			)
			.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
			let room_id = RoomId::parse(
				utils::string_from_bytes(
					parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
				)
				.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
			)
			.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
			let key_data = serde_json::from_slice(&value)
				.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
			Ok::<_, Error>((room_id, session_id, key_data))
		}) {
			let (room_id, session_id, key_data) = result?;
			rooms
				.entry(room_id)
				.or_insert_with(|| RoomKeyBackup {
					sessions: BTreeMap::new(),
				})
				.sessions
				.insert(session_id, key_data);
		}
		Ok(rooms)
	}

	fn get_room(
		&self, user_id: &UserId, version: &str, room_id: &RoomId,
	) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
		let mut prefix = user_id.as_bytes().to_vec();
		prefix.push(0xFF);
		prefix.extend_from_slice(version.as_bytes());
		prefix.push(0xFF);
		prefix.extend_from_slice(room_id.as_bytes());
		prefix.push(0xFF);
		Ok(self
			.backupkeyid_backup
			.scan_prefix(prefix)
			.map(|(key, value)| {
				let mut parts = key.rsplit(|&b| b == 0xFF);
				let session_id = utils::string_from_bytes(
					parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
				)
				.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
				let key_data = serde_json::from_slice(&value)
					.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
				Ok::<_, Error>((session_id, key_data))
			})
			.filter_map(std::result::Result::ok)
			.collect())
	}

	fn get_session(
		&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
	) -> Result<Option<Raw<KeyBackupData>>> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		key.push(0xFF);
		key.extend_from_slice(room_id.as_bytes());
		key.push(0xFF);
		key.extend_from_slice(session_id.as_bytes());
		self.backupkeyid_backup
			.get(&key)?
			.map(|value| {
				serde_json::from_slice(&value)
					.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
			})
			.transpose()
	}

	fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		key.push(0xFF);
		for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
			self.backupkeyid_backup.remove(&outdated_key)?;
		}
		Ok(())
	}

	fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		key.push(0xFF);
		key.extend_from_slice(room_id.as_bytes());
		key.push(0xFF);
		for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
			self.backupkeyid_backup.remove(&outdated_key)?;
		}
		Ok(())
	}

	fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(version.as_bytes());
		key.push(0xFF);
		key.extend_from_slice(room_id.as_bytes());
		key.push(0xFF);
		key.extend_from_slice(session_id.as_bytes());
		for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
			self.backupkeyid_backup.remove(&outdated_key)?;
		}
		Ok(())
	}
}
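All of these tables use the same convention: a key is the concatenation of its components with a 0xFF byte between them, which works because 0xFF can never appear inside valid UTF-8 data. A minimal standalone sketch of that scheme (illustrative only; the helper names below are not part of this codebase):

fn make_key(parts: &[&[u8]]) -> Vec<u8> {
	let mut key = Vec::new();
	for (i, part) in parts.iter().enumerate() {
		if i > 0 {
			// 0xFF is the separator; it never occurs in UTF-8 identifiers.
			key.push(0xFF);
		}
		key.extend_from_slice(part);
	}
	key
}

fn split_key(key: &[u8]) -> Vec<&[u8]> {
	key.split(|&b| b == 0xFF).collect()
}

// e.g. the backup key above is effectively make_key(&[user_id.as_bytes(), version.as_bytes()]).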

@@ -2,245 +2,182 @@ use ruma::api::client::error::ErrorKind;
use tracing::debug;

use crate::{
	database::KeyValueDatabase,
	service::{self, media::UrlPreviewData},
	utils, Error, Result,
};

impl service::media::Data for KeyValueDatabase {
	fn create_file_metadata(
		&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
	) -> Result<Vec<u8>> {
		let mut key = mxc.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(&width.to_be_bytes());
		key.extend_from_slice(&height.to_be_bytes());
		key.push(0xFF);
		key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
		key.push(0xFF);
		key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
		self.mediaid_file.insert(&key, &[])?;
		Ok(key)
	}

	fn delete_file_mxc(&self, mxc: String) -> Result<()> {
		debug!("MXC URI: {:?}", mxc);
		let mut prefix = mxc.as_bytes().to_vec();
		prefix.push(0xFF);
		debug!("MXC db prefix: {:?}", prefix);
		for (key, _) in self.mediaid_file.scan_prefix(prefix) {
			debug!("Deleting key: {:?}", key);
			self.mediaid_file.remove(&key)?;
		}
		Ok(())
	}

	/// Searches for all files with the given MXC
	fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
		debug!("MXC URI: {:?}", mxc);
		let mut prefix = mxc.as_bytes().to_vec();
		prefix.push(0xFF);
		let mut keys: Vec<Vec<u8>> = vec![];
		for (key, _) in self.mediaid_file.scan_prefix(prefix) {
			keys.push(key);
		}
		if keys.is_empty() {
			return Err(Error::bad_database(
				"Failed to find any keys in database with the provided MXC.",
			));
		}
		debug!("Got the following keys: {:?}", keys);
		Ok(keys)
	}

	fn search_file_metadata(
		&self, mxc: String, width: u32, height: u32,
	) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
		let mut prefix = mxc.as_bytes().to_vec();
		prefix.push(0xFF);
		prefix.extend_from_slice(&width.to_be_bytes());
		prefix.extend_from_slice(&height.to_be_bytes());
		prefix.push(0xFF);
		let (key, _) = self
			.mediaid_file
			.scan_prefix(prefix)
			.next()
			.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
		let mut parts = key.rsplit(|&b| b == 0xFF);
		let content_type = parts
			.next()
			.map(|bytes| {
				utils::string_from_bytes(bytes)
					.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
			})
			.transpose()?;
		let content_disposition_bytes =
			parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
		let content_disposition = if content_disposition_bytes.is_empty() {
			None
		} else {
			Some(
				utils::string_from_bytes(content_disposition_bytes)
					.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
			)
		};
		Ok((content_disposition, content_type, key))
	}

	/// Gets all the media keys in our database (this includes all the metadata
	/// associated with it such as width, height, content-type, etc)
	fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
		let mut keys: Vec<Vec<u8>> = vec![];
		for (key, _) in self.mediaid_file.iter() {
			keys.push(key);
		}
		Ok(keys)
	}

	fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }

	fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
		let mut value = Vec::<u8>::new();
		value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
		value.push(0xFF);
		value.extend_from_slice(data.title.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
		value.push(0xFF);
		value.extend_from_slice(data.description.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
		value.push(0xFF);
		value.extend_from_slice(data.image.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
		value.push(0xFF);
		value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
		value.push(0xFF);
		value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
		value.push(0xFF);
		value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
		self.url_previews.insert(url.as_bytes(), &value)
	}

	fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
		let values = self.url_previews.get(url.as_bytes()).ok()??;
		let mut values = values.split(|&b| b == 0xFF);
		let _ts = match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
			Some(0) => None,
			x => x,
		};
		let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
			Some(s) if s.is_empty() => None,
			x => x,
		};
		let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
			Some(s) if s.is_empty() => None,
			x => x,
		};
		let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
			Some(s) if s.is_empty() => None,
			x => x,
		};
		let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) {
			Some(0) => None,
			x => x,
		};
		let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
			Some(0) => None,
			x => x,
		};
		let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
			Some(0) => None,
			x => x,
		};
		Some(UrlPreviewData {
			title,
			description,
			image,
			image_size,
			image_width,
			image_height,
		})
	}
}
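For reference, create_file_metadata above lays a media key out as mxc ++ 0xFF ++ width (4 bytes BE) ++ height (4 bytes BE) ++ 0xFF ++ content disposition ++ 0xFF ++ content type. A small sketch of reading the dimensions back out of such a key (illustrative only; this helper is not part of the codebase):

fn thumbnail_dimensions(key: &[u8], mxc_len: usize) -> Option<(u32, u32)> {
	// Skip the MXC URI and its trailing 0xFF separator, then read the two
	// big-endian u32 values exactly where create_file_metadata wrote them.
	let dims = key.get(mxc_len + 1..mxc_len + 9)?;
	let width = u32::from_be_bytes(dims[0..4].try_into().ok()?);
	let height = u32::from_be_bytes(dims[4..8].try_into().ok()?);
	Some((width, height))
}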

@@ -1,79 +1,63 @@
use ruma::{
	api::client::push::{set_pusher, Pusher},
	UserId,
};

use crate::{database::KeyValueDatabase, service, utils, Error, Result};

impl service::pusher::Data for KeyValueDatabase {
	fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
		match &pusher {
			set_pusher::v3::PusherAction::Post(data) => {
				let mut key = sender.as_bytes().to_vec();
				key.push(0xFF);
				key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
				self.senderkey_pusher
					.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
				Ok(())
			},
			set_pusher::v3::PusherAction::Delete(ids) => {
				let mut key = sender.as_bytes().to_vec();
				key.push(0xFF);
				key.extend_from_slice(ids.pushkey.as_bytes());
				self.senderkey_pusher.remove(&key).map(|_| ()).map_err(Into::into)
			},
		}
	}

	fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
		let mut senderkey = sender.as_bytes().to_vec();
		senderkey.push(0xFF);
		senderkey.extend_from_slice(pushkey.as_bytes());
		self.senderkey_pusher
			.get(&senderkey)?
			.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
			.transpose()
	}

	fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
		let mut prefix = sender.as_bytes().to_vec();
		prefix.push(0xFF);
		self.senderkey_pusher
			.scan_prefix(prefix)
			.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
			.collect()
	}

	fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
		let mut prefix = sender.as_bytes().to_vec();
		prefix.push(0xFF);
		Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
			let mut parts = k.splitn(2, |&b| b == 0xFF);
			let _senderkey = parts.next();
			let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
			let push_key_string = utils::string_from_bytes(push_key)
				.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
			Ok(push_key_string)
		}))
	}
}

@@ -3,82 +3,68 @@ use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAli
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};

impl service::rooms::alias::Data for KeyValueDatabase {
	fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
		self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
		let mut aliasid = room_id.as_bytes().to_vec();
		aliasid.push(0xFF);
		aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
		self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
		Ok(())
	}

	fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
		if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
			let mut prefix = room_id;
			prefix.push(0xFF);
			for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
				self.aliasid_alias.remove(&key)?;
			}
			self.alias_roomid.remove(alias.alias().as_bytes())?;
		} else {
			return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
		}
		Ok(())
	}

	fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
		self.alias_roomid
			.get(alias.alias().as_bytes())?
			.map(|bytes| {
				RoomId::parse(
					utils::string_from_bytes(&bytes)
						.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
				)
				.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
			})
			.transpose()
	}

	fn local_aliases_for_room<'a>(
		&'a self, room_id: &RoomId,
	) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
		let mut prefix = room_id.as_bytes().to_vec();
		prefix.push(0xFF);
		Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
			utils::string_from_bytes(&bytes)
				.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
				.try_into()
				.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
		}))
	}

	fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
		Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| {
			let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
				.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
			let room_id = utils::string_from_bytes(&room_id_bytes)
				.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
				.try_into()
				.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
			Ok((room_id, room_alias_localpart))
		}))
	}
}

@@ -3,59 +3,47 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{database::KeyValueDatabase, service, utils, Result};

impl service::rooms::auth_chain::Data for KeyValueDatabase {
	fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
		// Check RAM cache
		if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
			return Ok(Some(Arc::clone(result)));
		}

		// We only save auth chains for single events in the db
		if key.len() == 1 {
			// Check DB cache
			let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| {
				chain
					.chunks_exact(size_of::<u64>())
					.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
					.collect()
			});

			if let Some(chain) = chain {
				let chain = Arc::new(chain);

				// Cache in RAM
				self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain));

				return Ok(Some(chain));
			}
		}

		Ok(None)
	}

	fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
		// Only persist single events in db
		if key.len() == 1 {
			self.shorteventid_authchain.insert(
				&key[0].to_be_bytes(),
				&auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::<Vec<u8>>(),
			)?;
		}

		// Cache in RAM
		self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);

		Ok(())
	}
}
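The shorteventid_authchain value is simply every u64 short event id in the set appended as 8 big-endian bytes, which is what the chunks_exact/flat_map pair above decodes and encodes. A self-contained sketch of that round trip (illustrative only; these helpers are not part of the codebase):

use std::collections::HashSet;

fn pack(chain: &HashSet<u64>) -> Vec<u8> {
	// Each id contributes exactly 8 big-endian bytes.
	chain.iter().flat_map(|id| id.to_be_bytes()).collect()
}

fn unpack(bytes: &[u8]) -> HashSet<u64> {
	bytes
		.chunks_exact(8)
		.map(|chunk| u64::from_be_bytes(chunk.try_into().expect("chunks_exact yields 8-byte chunks")))
		.collect()
}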

@@ -3,26 +3,21 @@ use ruma::{OwnedRoomId, RoomId};
use crate::{database::KeyValueDatabase, service, utils, Error, Result};

impl service::rooms::directory::Data for KeyValueDatabase {
	fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }

	fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }

	fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
		Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
	}

	fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
		Box::new(self.publicroomids.iter().map(|(bytes, _)| {
			RoomId::parse(
				utils::string_from_bytes(&bytes)
					.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
			)
			.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
		}))
	}
}

@@ -1,178 +1,155 @@
use std::time::Duration;

use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
use tracing::error;

use crate::{
	database::KeyValueDatabase,
	service::{self, rooms::edus::presence::Presence},
	services,
	utils::{self, user_id_from_bytes},
	Error, Result,
};

impl service::rooms::edus::presence::Data for KeyValueDatabase {
	fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
		let key = presence_key(room_id, user_id);
		self.roomuserid_presence
			.get(&key)?
			.map(|presence_bytes| -> Result<PresenceEvent> {
				Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)
			})
			.transpose()
	}

	fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
		let now = utils::millis_since_unix_epoch();
		let mut state_changed = false;

		for room_id in services().rooms.state_cache.rooms_joined(user_id) {
			let key = presence_key(&room_id?, user_id);
			let presence_bytes = self.roomuserid_presence.get(&key)?;

			if let Some(presence_bytes) = presence_bytes {
				let presence = Presence::from_json_bytes(&presence_bytes)?;
				if presence.state != new_state {
					state_changed = true;
					break;
				}
			}
		}

		let count = if state_changed {
			services().globals.next_count()?
		} else {
			services().globals.current_count()?
		};

		for room_id in services().rooms.state_cache.rooms_joined(user_id) {
			let key = presence_key(&room_id?, user_id);
			let presence_bytes = self.roomuserid_presence.get(&key)?;

			let new_presence = match presence_bytes {
				Some(presence_bytes) => {
					let mut presence = Presence::from_json_bytes(&presence_bytes)?;
					presence.state = new_state.clone();
					presence.currently_active = presence.state == PresenceState::Online;
					presence.last_active_ts = now;
					presence.last_count = count;
					presence
				},
				None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None),
			};

			self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?;
		}

		let timeout = match new_state {
			PresenceState::Online => services().globals.config.presence_idle_timeout_s,
			_ => services().globals.config.presence_offline_timeout_s,
		};

		self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
			error!("Failed to add presence timer: {}", e);
			Error::bad_database("Failed to add presence timer")
		})
	}

	fn set_presence(
		&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
		last_active_ago: Option<UInt>, status_msg: Option<String>,
	) -> Result<()> {
		let now = utils::millis_since_unix_epoch();
		let last_active_ts = match last_active_ago {
			Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
			None => now,
		};

		let key = presence_key(room_id, user_id);
		let presence = Presence::new(
			presence_state,
			currently_active.unwrap_or(false),
			last_active_ts,
			services().globals.next_count()?,
			status_msg,
		);

		let timeout = match presence.state {
			PresenceState::Online => services().globals.config.presence_idle_timeout_s,
			_ => services().globals.config.presence_offline_timeout_s,
		};

		self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
			error!("Failed to add presence timer: {}", e);
			Error::bad_database("Failed to add presence timer")
		})?;

		self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?;

		Ok(())
	}

	fn remove_presence(&self, user_id: &UserId) -> Result<()> {
		for room_id in services().rooms.state_cache.rooms_joined(user_id) {
			let key = presence_key(&room_id?, user_id);
			self.roomuserid_presence.remove(&key)?;
		}

		Ok(())
	}

	fn presence_since<'a>(
		&'a self, room_id: &RoomId, since: u64,
	) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
		let prefix = [room_id.as_bytes(), &[0xFF]].concat();

		Box::new(
			self.roomuserid_presence
				.scan_prefix(prefix)
				.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
					let user_id = user_id_from_bytes(
						key.rsplit(|byte| *byte == 0xFF)
							.next()
							.ok_or_else(|| Error::bad_database("No UserID bytes in presence key"))?,
					)?;

					let presence = Presence::from_json_bytes(&presence_bytes)?;
					let presence_event = presence.to_presence_event(&user_id)?;

					Ok((user_id, presence.last_count, presence_event))
				})
				.filter(move |(_, count, _)| *count > since),
		)
	}
}

#[inline]
fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec<u8> {
	[room_id.as_bytes(), &[0xFF], user_id.as_bytes()].concat()
}
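presence_key mirrors the convention used elsewhere, room_id ++ 0xFF ++ user_id, so presence_since can recover the user id by taking everything after the last separator. A tiny sketch of that inverse step (illustrative only; the helper is not part of the codebase):

fn user_part(key: &[u8]) -> Option<&[u8]> {
	// Everything after the final 0xFF separator is the user id bytes.
	key.rsplit(|&b| b == 0xFF).next()
}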

@@ -1,150 +1,113 @@
use std::mem;

use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};

use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};

impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
	fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
		let mut prefix = room_id.as_bytes().to_vec();
		prefix.push(0xFF);

		let mut last_possible_key = prefix.clone();
		last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());

		// Remove old entry
		if let Some((old, _)) = self
			.readreceiptid_readreceipt
			.iter_from(&last_possible_key, true)
			.take_while(|(key, _)| key.starts_with(&prefix))
			.find(|(key, _)| {
				key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes()
			}) {
			// This is the old room_latest
			self.readreceiptid_readreceipt.remove(&old)?;
		}

		let mut room_latest_id = prefix;
		room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
		room_latest_id.push(0xFF);
		room_latest_id.extend_from_slice(user_id.as_bytes());

		self.readreceiptid_readreceipt.insert(
			&room_latest_id,
			&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
		)?;

		Ok(())
	}

	fn readreceipts_since<'a>(
		&'a self, room_id: &RoomId, since: u64,
	) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
		let mut prefix = room_id.as_bytes().to_vec();
		prefix.push(0xFF);
		let prefix2 = prefix.clone();

		let mut first_possible_edu = prefix.clone();
		first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since

		Box::new(
			self.readreceiptid_readreceipt
				.iter_from(&first_possible_edu, false)
				.take_while(move |(k, _)| k.starts_with(&prefix2))
				.map(move |(k, v)| {
					let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
						.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
					let user_id = UserId::parse(
						utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
							.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
					)
					.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;

					let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
						.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
					json.remove("room_id");

					Ok((
						user_id,
						count,
						Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
					))
				}),
		)
	}

	fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
		let mut key = room_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(user_id.as_bytes());

		self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;

		self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes())
	}

	fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
		let mut key = room_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(user_id.as_bytes());

		self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
			Ok(Some(
				utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
			))
		})
	}

	fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
		let mut key = room_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(user_id.as_bytes());

		Ok(self
			.roomuserid_lastprivatereadupdate
			.get(&key)?
			.map(|bytes| {
				utils::u64_from_bytes(&bytes)
					.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
			})
			.transpose()?
			.unwrap_or(0))
	}
}
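Read receipts are keyed as room_id ++ 0xFF ++ count (8 bytes BE) ++ 0xFF ++ user_id; because the count is big-endian it sorts numerically, which is what lets readreceipts_since seek straight to the first entry newer than `since`. A sketch of building that seek key (illustrative only; the helper is not part of the codebase):

fn first_key_after(room_id: &[u8], since: u64) -> Vec<u8> {
	let mut key = room_id.to_vec();
	key.push(0xFF);
	// since + 1 is the first count strictly greater than `since`.
	key.extend_from_slice(&(since + 1).to_be_bytes());
	key
}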

@ -5,123 +5,111 @@ use ruma::{OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::edus::typing::Data for KeyValueDatabase { impl service::rooms::edus::typing::Data for KeyValueDatabase {
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let count = services().globals.next_count()?.to_be_bytes(); let count = services().globals.next_count()?.to_be_bytes();
let mut room_typing_id = prefix; let mut room_typing_id = prefix;
room_typing_id.extend_from_slice(&timeout.to_be_bytes()); room_typing_id.extend_from_slice(&timeout.to_be_bytes());
room_typing_id.push(0xff); room_typing_id.push(0xFF);
room_typing_id.extend_from_slice(&count); room_typing_id.extend_from_slice(&count);
self.typingid_userid self.typingid_userid.insert(&room_typing_id, user_id.as_bytes())?;
.insert(&room_typing_id, user_id.as_bytes())?;
self.roomid_lasttypingupdate self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &count)?;
.insert(room_id.as_bytes(), &count)?;
Ok(()) Ok(())
} }
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let user_id = user_id.to_string(); let user_id = user_id.to_string();
let mut found_outdated = false; let mut found_outdated = false;
// Maybe there are multiple ones from calling roomtyping_add multiple times // Maybe there are multiple ones from calling roomtyping_add multiple times
for outdated_edu in self for outdated_edu in self.typingid_userid.scan_prefix(prefix).filter(|(_, v)| &**v == user_id.as_bytes()) {
.typingid_userid self.typingid_userid.remove(&outdated_edu.0)?;
.scan_prefix(prefix) found_outdated = true;
.filter(|(_, v)| &**v == user_id.as_bytes()) }
{
self.typingid_userid.remove(&outdated_edu.0)?;
found_outdated = true;
}
if found_outdated { if found_outdated {
self.roomid_lasttypingupdate.insert( self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
room_id.as_bytes(), }
&services().globals.next_count()?.to_be_bytes(),
)?;
}
Ok(()) Ok(())
} }
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let current_timestamp = utils::millis_since_unix_epoch(); let current_timestamp = utils::millis_since_unix_epoch();
let mut found_outdated = false; let mut found_outdated = false;
// Find all outdated edus before inserting a new one // Find all outdated edus before inserting a new one
for outdated_edu in self for outdated_edu in self
.typingid_userid .typingid_userid
.scan_prefix(prefix) .scan_prefix(prefix)
.map(|(key, _)| { .map(|(key, _)| {
Ok::<_, Error>(( Ok::<_, Error>((
key.clone(), key.clone(),
utils::u64_from_bytes( utils::u64_from_bytes(
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| { &key.splitn(2, |&b| b == 0xFF)
Error::bad_database("RoomTyping has invalid timestamp or delimiters.") .nth(1)
})?[0..mem::size_of::<u64>()], .ok_or_else(|| Error::bad_database("RoomTyping has invalid timestamp or delimiters."))?[0..mem::size_of::<u64>()],
) )
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?, .map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
)) ))
}) })
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
.take_while(|&(_, timestamp)| timestamp < current_timestamp) .take_while(|&(_, timestamp)| timestamp < current_timestamp)
{ {
// This is an outdated edu (time > timestamp) // This is an outdated edu (time > timestamp)
self.typingid_userid.remove(&outdated_edu.0)?; self.typingid_userid.remove(&outdated_edu.0)?;
found_outdated = true; found_outdated = true;
} }
if found_outdated { if found_outdated {
self.roomid_lasttypingupdate.insert( self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
room_id.as_bytes(), }
&services().globals.next_count()?.to_be_bytes(),
)?;
}
Ok(()) Ok(())
} }
fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
Ok(self Ok(self
.roomid_lasttypingupdate .roomid_lasttypingupdate
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes)
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
}) })
}) .transpose()?
.transpose()? .unwrap_or(0))
.unwrap_or(0)) }
}
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> { fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
let mut user_ids = HashSet::new(); let mut user_ids = HashSet::new();
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) { for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| { let user_id = UserId::parse(
Error::bad_database("User ID in typingid_userid is invalid unicode.") utils::string_from_bytes(&user_id)
})?) .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid unicode."))?,
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?; )
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
user_ids.insert(user_id); user_ids.insert(user_id);
} }
Ok(user_ids) Ok(user_ids)
} }
} }
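
The cleanup above works because `typingid_userid` keys embed their expiry directly: `room_id ++ 0xFF ++ big-endian u64 timeout ++ ...`, so a prefix scan plus a fixed-width byte parse is enough to find stale entries. A minimal sketch of that key layout, assuming only std; the helper names and example values are illustrative and not part of this codebase:

fn typing_key(room_id: &str, timeout_millis: u64) -> Vec<u8> {
    // room_id ++ 0xFF ++ big-endian timeout (the real key carries more after this)
    let mut key = room_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(&timeout_millis.to_be_bytes());
    key
}

fn timestamp_from_key(key: &[u8]) -> Option<u64> {
    // Everything after the first 0xFF starts with the 8-byte big-endian timeout.
    let rest = key.splitn(2, |&b| b == 0xFF).nth(1)?;
    let bytes: [u8; 8] = rest.get(..8)?.try_into().ok()?;
    Some(u64::from_be_bytes(bytes))
}

fn main() {
    let now = 1_700_000_000_000_u64;
    let key = typing_key("!room:example.org", now - 1);
    // An entry is outdated once its stored timeout lies before the current time.
    assert!(timestamp_from_key(&key).unwrap() < now);
}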


@@ -3,63 +3,51 @@ use ruma::{DeviceId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, Result}; use crate::{database::KeyValueDatabase, service, Result};
impl service::rooms::lazy_loading::Data for KeyValueDatabase { impl service::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before( fn lazy_load_was_sent_before(
&self, &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
user_id: &UserId, ) -> Result<bool> {
device_id: &DeviceId, let mut key = user_id.as_bytes().to_vec();
room_id: &RoomId, key.push(0xFF);
ll_user: &UserId, key.extend_from_slice(device_id.as_bytes());
) -> Result<bool> { key.push(0xFF);
let mut key = user_id.as_bytes().to_vec(); key.extend_from_slice(room_id.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(ll_user.as_bytes());
key.push(0xff); Ok(self.lazyloadedids.get(&key)?.is_some())
key.extend_from_slice(room_id.as_bytes()); }
key.push(0xff);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
fn lazy_load_confirm_delivery( fn lazy_load_confirm_delivery(
&self, &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
user_id: &UserId, confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
device_id: &DeviceId, ) -> Result<()> {
room_id: &RoomId, let mut prefix = user_id.as_bytes().to_vec();
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>, prefix.push(0xFF);
) -> Result<()> { prefix.extend_from_slice(device_id.as_bytes());
let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF);
prefix.push(0xff); prefix.extend_from_slice(room_id.as_bytes());
prefix.extend_from_slice(device_id.as_bytes()); prefix.push(0xFF);
prefix.push(0xff);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
for ll_id in confirmed_user_ids { for ll_id in confirmed_user_ids {
let mut key = prefix.clone(); let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes()); key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?; self.lazyloadedids.insert(&key, &[])?;
} }
Ok(()) Ok(())
} }
fn lazy_load_reset( fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
&self, let mut prefix = user_id.as_bytes().to_vec();
user_id: &UserId, prefix.push(0xFF);
device_id: &DeviceId, prefix.extend_from_slice(device_id.as_bytes());
room_id: &RoomId, prefix.push(0xFF);
) -> Result<()> { prefix.extend_from_slice(room_id.as_bytes());
let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF);
prefix.push(0xff);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) { for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?; self.lazyloadedids.remove(&key)?;
} }
Ok(()) Ok(())
} }
} }
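
The lazy-loading table uses the compound-key convention seen throughout these files: identifier parts joined with a 0xFF delimiter (`user_id ++ 0xFF ++ device_id ++ 0xFF ++ room_id ++ 0xFF ++ ll_user`), which is unambiguous because Matrix identifiers never contain that byte. A small self-contained sketch of building and splitting such a key; the function name and example IDs are illustrative only:

fn lazy_load_key(user_id: &str, device_id: &str, room_id: &str, ll_user: &str) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    for part in [device_id, room_id, ll_user] {
        key.push(0xFF); // delimiter byte that cannot appear inside the identifiers
        key.extend_from_slice(part.as_bytes());
    }
    key
}

fn main() {
    let key = lazy_load_key("@alice:example.org", "DEVICE", "!room:example.org", "@bob:example.org");
    let parts: Vec<&[u8]> = key.split(|&b| b == 0xFF).collect();
    assert_eq!(parts.len(), 4);
    assert_eq!(parts[0], b"@alice:example.org".as_slice());
    assert_eq!(parts[3], b"@bob:example.org".as_slice());
}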


@@ -4,76 +4,68 @@ use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::metadata::Data for KeyValueDatabase { impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> { fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? { let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(), Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false), None => return Ok(false),
}; };
// Look for PDUs in that room. // Look for PDUs in that room.
Ok(self Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some())
.pduid_pdu }
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes(&bytes).map_err(|_| { utils::string_from_bytes(&bytes)
Error::bad_database("Room ID in publicroomids is invalid unicode.") .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
})?, )
) .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) }))
})) }
}
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
} }
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled { if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?; self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else { } else {
self.disabledroomids.remove(room_id.as_bytes())?; self.disabledroomids.remove(room_id.as_bytes())?;
} }
Ok(()) Ok(())
} }
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
}
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned { if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?; self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else { } else {
self.bannedroomids.remove(room_id.as_bytes())?; self.bannedroomids.remove(room_id.as_bytes())?;
} }
Ok(()) Ok(())
} }
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map( Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes) let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| { .map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}"); error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.") Error::bad_database("Invalid room_id in bannedroomids.")
})? })?
.try_into() .try_into()
.map_err(|e| { .map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}"); error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids") Error::bad_database("Invalid room_id in bannedroomids")
})?; })?;
Ok(room_id) Ok(room_id)
}, },
)) ))
} }
} }


@@ -3,26 +3,22 @@ use ruma::{CanonicalJsonObject, EventId};
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
impl service::rooms::outlier::Data for KeyValueDatabase { impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
.get(event_id.as_bytes())? serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
.map_or(Ok(None), |pdu| { })
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) }
})
}
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
.get(event_id.as_bytes())? serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
.map_or(Ok(None), |pdu| { })
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) }
})
}
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert( self.eventid_outlierpdu.insert(
event_id.as_bytes(), event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
) )
} }
} }


@@ -3,85 +3,78 @@ use std::{mem, sync::Arc};
use ruma::{EventId, RoomId, UserId}; use ruma::{EventId, RoomId, UserId};
use crate::{ use crate::{
database::KeyValueDatabase, database::KeyValueDatabase,
service::{self, rooms::timeline::PduCount}, service::{self, rooms::timeline::PduCount},
services, utils, Error, PduEvent, Result, services, utils, Error, PduEvent, Result,
}; };
impl service::rooms::pdu_metadata::Data for KeyValueDatabase { impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn add_relation(&self, from: u64, to: u64) -> Result<()> { fn add_relation(&self, from: u64, to: u64) -> Result<()> {
let mut key = to.to_be_bytes().to_vec(); let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes()); key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?; self.tofrom_relation.insert(&key, &[])?;
Ok(()) Ok(())
} }
fn relations_until<'a>( fn relations_until<'a>(
&'a self, &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
user_id: &'a UserId, ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
shortroomid: u64, let prefix = target.to_be_bytes().to_vec();
target: u64, let mut current = prefix.clone();
until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until { let count_raw = match until {
PduCount::Normal(x) => x - 1, PduCount::Normal(x) => x - 1,
PduCount::Backfilled(x) => { PduCount::Backfilled(x) => {
current.extend_from_slice(&0_u64.to_be_bytes()); current.extend_from_slice(&0_u64.to_be_bytes());
u64::MAX - x - 1 u64::MAX - x - 1
} },
}; };
current.extend_from_slice(&count_raw.to_be_bytes()); current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new( Ok(Box::new(
self.tofrom_relation self.tofrom_relation.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
.iter_from(&current, true) move |(tofrom, _data)| {
.take_while(move |(k, _)| k.starts_with(&prefix)) let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map(move |(tofrom, _data)| { .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let mut pduid = shortroomid.to_be_bytes().to_vec(); let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes()); pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services() let mut pdu = services()
.rooms .rooms
.timeline .timeline
.get_pdu_from_id(&pduid)? .get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id { if pdu.sender != user_id {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
} }
Ok((PduCount::Normal(from), pdu)) Ok((PduCount::Normal(from), pdu))
}), },
)) ),
} ))
}
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids { for prev in event_ids {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes()); key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?; self.referencedevents.insert(&key, &[])?;
} }
Ok(()) Ok(())
} }
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes()); key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some()) Ok(self.referencedevents.get(&key)?.is_some())
} }
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[]) self.softfailedeventids.insert(event_id.as_bytes(), &[])
} }
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
.get(event_id.as_bytes()) }
.map(|o| o.is_some())
}
} }
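
`relations_until` relies on `tofrom_relation` keys being `to.to_be_bytes() ++ from.to_be_bytes()`: iterating backwards from `to ++ until` and stopping once the 8-byte prefix changes yields the related counts newest-first. A rough sketch of that reverse prefix scan over an ordered map, assuming only std; it returns counts at or below `until`, so it approximates rather than reproduces the exact off-by-one handling above:

use std::collections::BTreeMap;

fn relations_until(db: &BTreeMap<Vec<u8>, ()>, to: u64, until: u64) -> Vec<u64> {
    let prefix = to.to_be_bytes().to_vec();
    let mut current = prefix.clone();
    current.extend_from_slice(&until.to_be_bytes());
    db.range(..=current)
        .rev() // newest first
        .take_while(|(k, _)| k.starts_with(&prefix))
        .map(|(k, _)| u64::from_be_bytes(k[8..16].try_into().unwrap()))
        .collect()
}

fn main() {
    let mut db = BTreeMap::new();
    for from in [3_u64, 7, 9] {
        let mut key = 42_u64.to_be_bytes().to_vec();
        key.extend_from_slice(&from.to_be_bytes());
        db.insert(key, ());
    }
    assert_eq!(relations_until(&db, 42, 8), vec![7, 3]);
}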


@@ -5,61 +5,55 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
impl service::rooms::search::Data for KeyValueDatabase { impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = message_body let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50) .filter(|word| word.len() <= 50)
.map(str::to_lowercase) .map(str::to_lowercase)
.map(|word| { .map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes()); key.extend_from_slice(word.as_bytes());
key.push(0xff); key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
(key, Vec::new()) (key, Vec::new())
}); });
self.tokenids.insert_batch(&mut batch) self.tokenids.insert_batch(&mut batch)
} }
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services() let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let words: Vec<_> = search_string let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
.map(str::to_lowercase) .map(str::to_lowercase)
.collect(); .collect();
let iterators = words.clone().into_iter().map(move |word| { let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone(); let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes()); prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xff); prefix2.push(0xFF);
let prefix3 = prefix2.clone(); let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone(); let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first .iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2)) .take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec()) .map(move |(key, _)| key[prefix3.len()..].to_vec())
}); });
let common_elements = match utils::common_elements(iterators, |a, b| { let common_elements = match utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier // We compare b with a because we reversed the iterator earlier
b.cmp(a) b.cmp(a)
}) { }) {
Some(it) => it, Some(it) => it,
None => return Ok(None), None => return Ok(None),
}; };
Ok(Some((Box::new(common_elements), words))) Ok(Some((Box::new(common_elements), words)))
} }
} }
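
`index_pdu` tokenizes message bodies by splitting on any non-alphanumeric character, dropping empty tokens and words longer than 50 bytes, and lowercasing the rest; `search_pdus` applies the same split and lowercasing without the length cap. A minimal sketch of the index-side tokenizer, std only:

fn tokenize(body: &str) -> Vec<String> {
    body.split_terminator(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .filter(|word| word.len() <= 50)
        .map(str::to_lowercase)
        .collect()
}

fn main() {
    assert_eq!(tokenize("Hello, Matrix world!"), ["hello", "matrix", "world"]);
}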


@@ -6,214 +6,165 @@ use tracing::warn;
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::short::Data for KeyValueDatabase { impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
return Ok(*short); return Ok(*short);
} }
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
Some(shorteventid) => utils::u64_from_bytes(&shorteventid) Some(shorteventid) => {
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
None => { },
let shorteventid = services().globals.next_count()?; None => {
self.eventid_shorteventid let shorteventid = services().globals.next_count()?;
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; shorteventid
shorteventid },
} };
};
self.eventidshort_cache self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short);
.lock()
.unwrap()
.insert(event_id.to_owned(), short);
Ok(short) Ok(short)
} }
fn get_shortstatekey( fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
&self, if let Some(short) =
event_type: &StateEventType, self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
state_key: &str, {
) -> Result<Option<u64>> { return Ok(Some(*short));
if let Some(short) = self }
.statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{
return Ok(Some(*short));
}
let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xff); statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes()); statekey_vec.extend_from_slice(state_key.as_bytes());
let short = self let short = self
.statekey_shortstatekey .statekey_shortstatekey
.get(&statekey_vec)? .get(&statekey_vec)?
.map(|shortstatekey| { .map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey) utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
.map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) })
}) .transpose()?;
.transpose()?;
if let Some(s) = short { if let Some(s) = short {
self.statekeyshort_cache self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s);
.lock() }
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), s);
}
Ok(short) Ok(short)
} }
fn get_or_create_shortstatekey( fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
&self, if let Some(short) =
event_type: &StateEventType, self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
state_key: &str, {
) -> Result<u64> { return Ok(*short);
if let Some(short) = self }
.statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{
return Ok(*short);
}
let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xff); statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes()); statekey_vec.extend_from_slice(state_key.as_bytes());
let short = match self.statekey_shortstatekey.get(&statekey_vec)? { let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => { None => {
let shortstatekey = services().globals.next_count()?; let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?; self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
self.shortstatekey_statekey shortstatekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; },
shortstatekey };
}
};
self.statekeyshort_cache self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short);
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), short);
Ok(short) Ok(short)
} }
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
if let Some(id) = self if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) {
.shorteventid_cache return Ok(Arc::clone(id));
.lock() }
.unwrap()
.get_mut(&shorteventid)
{
return Ok(Arc::clone(id));
}
let bytes = self let bytes = self
.shorteventid_eventid .shorteventid_eventid
.get(&shorteventid.to_be_bytes())? .get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { let event_id = EventId::parse_arc(
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") utils::string_from_bytes(&bytes)
})?) .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; )
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
self.shorteventid_cache self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id));
.lock()
.unwrap()
.insert(shorteventid, Arc::clone(&event_id));
Ok(event_id) Ok(event_id)
} }
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
if let Some(id) = self if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) {
.shortstatekey_cache return Ok(id.clone());
.lock() }
.unwrap()
.get_mut(&shortstatekey)
{
return Ok(id.clone());
}
let bytes = self let bytes = self
.shortstatekey_statekey .shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())? .get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xff); let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry"); let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts let statekey_bytes =
.next() parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { warn!("Event type in shortstatekey_statekey is invalid: {}", e);
warn!("Event type in shortstatekey_statekey is invalid: {}", e); Error::bad_database("Event type in shortstatekey_statekey is invalid.")
Error::bad_database("Event type in shortstatekey_statekey is invalid.") })?);
})?);
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { let state_key = utils::string_from_bytes(statekey_bytes)
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
})?;
let result = (event_type, state_key); let result = (event_type, state_key);
self.shortstatekey_cache self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone());
.lock()
.unwrap()
.insert(shortstatekey, result.clone());
Ok(result) Ok(result)
} }
/// Returns (shortstatehash, already_existed) /// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(match self.statehash_shortstatehash.get(state_hash)? { Ok(match self.statehash_shortstatehash.get(state_hash)? {
Some(shortstatehash) => ( Some(shortstatehash) => (
utils::u64_from_bytes(&shortstatehash) utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true, true,
), ),
None => { None => {
let shortstatehash = services().globals.next_count()?; let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?;
.insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false)
(shortstatehash, false) },
} })
}) }
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid self.roomid_shortroomid
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
.map(|bytes| { .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
utils::u64_from_bytes(&bytes) .transpose()
.map_err(|_| Error::bad_database("Invalid shortroomid in db.")) }
})
.transpose()
}
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
Some(short) => utils::u64_from_bytes(&short) Some(short) => {
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
None => { },
let short = services().globals.next_count()?; None => {
self.roomid_shortroomid let short = services().globals.next_count()?;
.insert(room_id.as_bytes(), &short.to_be_bytes())?; self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short short
} },
}) })
} }
} }
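
The `get_or_create_*` functions above all follow one pattern: look the value up in a forward map, otherwise allocate the next value of a global counter and record the mapping in both directions (the real code also fronts this with an in-memory cache). A simplified sketch of that pattern with plain `HashMap`s and a local counter; the type and method names here are illustrative, not from the codebase:

use std::collections::HashMap;

#[derive(Default)]
struct Shortener {
    next: u64,
    forward: HashMap<String, u64>,
    backward: HashMap<u64, String>,
}

impl Shortener {
    fn get_or_create(&mut self, id: &str) -> u64 {
        if let Some(&short) = self.forward.get(id) {
            return short;
        }
        self.next += 1; // stands in for the global monotonic counter
        self.forward.insert(id.to_owned(), self.next);
        self.backward.insert(self.next, id.to_owned());
        self.next
    }
}

fn main() {
    let mut s = Shortener::default();
    let a = s.get_or_create("$event_a:example.org");
    let b = s.get_or_create("$event_b:example.org");
    assert_eq!(a, s.get_or_create("$event_a:example.org"));
    assert_ne!(a, b);
    assert_eq!(s.backward[&a], "$event_a:example.org");
}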


@@ -1,73 +1,69 @@
use ruma::{EventId, OwnedEventId, RoomId}; use std::{collections::HashSet, sync::Arc};
use std::collections::HashSet;
use std::sync::Arc; use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard; use tokio::sync::MutexGuard;
use crate::{database::KeyValueDatabase, service, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::state::Data for KeyValueDatabase { impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| {
.get(room_id.as_bytes())? Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
.map_or(Ok(None), |bytes| { Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { })?))
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") })
})?)) }
})
}
fn set_room_state( fn set_room_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
new_shortstatehash: u64, new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
self.roomid_shortstatehash self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; Ok(())
Ok(()) }
}
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(())
Ok(()) }
}
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
self.roomid_pduleaves self.roomid_pduleaves
.scan_prefix(prefix) .scan_prefix(prefix)
.map(|(_, bytes)| { .map(|(_, bytes)| {
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { EventId::parse_arc(
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") utils::string_from_bytes(&bytes)
})?) .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) )
}) .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
.collect() })
} .collect()
}
fn set_forward_extremities( fn set_forward_extremities(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_ids: Vec<OwnedEventId>, event_ids: Vec<OwnedEventId>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?; self.roomid_pduleaves.remove(&key)?;
} }
for event_id in event_ids { for event_id in event_ids {
let mut key = prefix.clone(); let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes()); key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
} }
Ok(()) Ok(())
} }
} }


@@ -1,186 +1,144 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use async_trait::async_trait; use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
#[async_trait] #[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase { impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services() let full_state = services()
.rooms .rooms
.state_compressor .state_compressor
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
.pop() .pop()
.expect("there is always one layer") .expect("there is always one layer")
.1; .1;
let mut result = HashMap::new(); let mut result = HashMap::new();
let mut i = 0; let mut i = 0;
for compressed in full_state.iter() { for compressed in full_state.iter() {
let parsed = services() let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
.rooms result.insert(parsed.0, parsed.1);
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i += 1; i += 1;
if i % 100 == 0 { if i % 100 == 0 {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
} }
Ok(result) Ok(result)
} }
async fn state_full( async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
&self, let full_state = services()
shortstatehash: u64, .rooms
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { .state_compressor
let full_state = services() .load_shortstatehash_info(shortstatehash)?
.rooms .pop()
.state_compressor .expect("there is always one layer")
.load_shortstatehash_info(shortstatehash)? .1;
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new(); let mut result = HashMap::new();
let mut i = 0; let mut i = 0;
for compressed in full_state.iter() { for compressed in full_state.iter() {
let (_, eventid) = services() let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
.rooms if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
.state_compressor result.insert(
.parse_compressed_state_event(compressed)?; (
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { pdu.kind.to_string().into(),
result.insert( pdu.state_key
( .as_ref()
pdu.kind.to_string().into(), .ok_or_else(|| Error::bad_database("State event has no state key."))?
pdu.state_key .clone(),
.as_ref() ),
.ok_or_else(|| Error::bad_database("State event has no state key."))? pdu,
.clone(), );
), }
pdu,
);
}
i += 1; i += 1;
if i % 100 == 0 { if i % 100 == 0 {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
} }
Ok(result) Ok(result)
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
fn state_get_id( /// `state_key`).
&self, fn state_get_id(
shortstatehash: u64, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
event_type: &StateEventType, ) -> Result<Option<Arc<EventId>>> {
state_key: &str, let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? {
) -> Result<Option<Arc<EventId>>> { Some(s) => s,
let shortstatekey = match services() None => return Ok(None),
.rooms };
.short let full_state = services()
.get_shortstatekey(event_type, state_key)? .rooms
{ .state_compressor
Some(s) => s, .load_shortstatehash_info(shortstatehash)?
None => return Ok(None), .pop()
}; .expect("there is always one layer")
let full_state = services() .1;
.rooms Ok(
.state_compressor full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| {
.load_shortstatehash_info(shortstatehash)? services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id)
.pop() }),
.expect("there is always one layer") )
.1; }
Ok(full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
fn state_get( /// `state_key`).
&self, fn state_get(
shortstatehash: u64, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
event_type: &StateEventType, ) -> Result<Option<Arc<PduEvent>>> {
state_key: &str, self.state_get_id(shortstatehash, event_type, state_key)?
) -> Result<Option<Arc<PduEvent>>> { .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
self.state_get_id(shortstatehash, event_type, state_key)? }
.map_or(Ok(None), |event_id| {
services().rooms.timeline.get_pdu(&event_id)
})
}
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| {
.get(event_id.as_bytes())? self.shorteventid_shortstatehash
.map_or(Ok(None), |shorteventid| { .get(&shorteventid)?
self.shorteventid_shortstatehash .map(|bytes| {
.get(&shorteventid)? utils::u64_from_bytes(&bytes)
.map(|bytes| { .map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash"))
utils::u64_from_bytes(&bytes).map_err(|_| { })
Error::bad_database( .transpose()
"Invalid shortstatehash bytes in shorteventid_shortstatehash", })
) }
})
})
.transpose()
})
}
/// Returns the full room state. /// Returns the full room state.
async fn room_state_full( async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
&self, if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
room_id: &RoomId, self.state_full(current_shortstatehash).await
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { } else {
if let Some(current_shortstatehash) = Ok(HashMap::new())
services().rooms.state.get_room_shortstatehash(room_id)? }
{ }
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
fn room_state_get_id( /// `state_key`).
&self, fn room_state_get_id(
room_id: &RoomId, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
event_type: &StateEventType, ) -> Result<Option<Arc<EventId>>> {
state_key: &str, if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
) -> Result<Option<Arc<EventId>>> { self.state_get_id(current_shortstatehash, event_type, state_key)
if let Some(current_shortstatehash) = } else {
services().rooms.state.get_room_shortstatehash(room_id)? Ok(None)
{ }
self.state_get_id(current_shortstatehash, event_type, state_key) }
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,
fn room_state_get( /// `state_key`).
&self, fn room_state_get(
room_id: &RoomId, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
event_type: &StateEventType, ) -> Result<Option<Arc<PduEvent>>> {
state_key: &str, if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
) -> Result<Option<Arc<PduEvent>>> { self.state_get(current_shortstatehash, event_type, state_key)
if let Some(current_shortstatehash) = } else {
services().rooms.state.get_room_shortstatehash(room_id)? Ok(None)
{ }
self.state_get(current_shortstatehash, event_type, state_key) }
} else {
Ok(None)
}
}
} }

File diff suppressed because it is too large.


@@ -1,61 +1,63 @@
use std::{collections::HashSet, mem::size_of, sync::Arc}; use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{ use crate::{
database::KeyValueDatabase, database::KeyValueDatabase,
service::{self, rooms::state_compressor::data::StateDiff}, service::{self, rooms::state_compressor::data::StateDiff},
utils, Error, Result, utils, Error, Result,
}; };
impl service::rooms::state_compressor::Data for KeyValueDatabase { impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self let value = self
.shortstatehash_statediff .shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())? .get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?; .ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); let parent = if parent != 0 {
let parent = if parent != 0 { Some(parent) } else { None }; Some(parent)
} else {
None
};
let mut add_mode = true; let mut add_mode = true;
let mut added = HashSet::new(); let mut added = HashSet::new();
let mut removed = HashSet::new(); let mut removed = HashSet::new();
let mut i = size_of::<u64>(); let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) { while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false; add_mode = false;
i += size_of::<u64>(); i += size_of::<u64>();
continue; continue;
} }
if add_mode { if add_mode {
added.insert(v.try_into().expect("we checked the size above")); added.insert(v.try_into().expect("we checked the size above"));
} else { } else {
removed.insert(v.try_into().expect("we checked the size above")); removed.insert(v.try_into().expect("we checked the size above"));
} }
i += 2 * size_of::<u64>(); i += 2 * size_of::<u64>();
} }
Ok(StateDiff { Ok(StateDiff {
parent, parent,
added: Arc::new(added), added: Arc::new(added),
removed: Arc::new(removed), removed: Arc::new(removed),
}) })
} }
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() { for new in diff.added.iter() {
value.extend_from_slice(&new[..]); value.extend_from_slice(&new[..]);
} }
if !diff.removed.is_empty() { if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes()); value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() { for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]); value.extend_from_slice(&removed[..]);
} }
} }
self.shortstatehash_statediff self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value)
.insert(&shortstatehash.to_be_bytes(), &value) }
}
} }
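
`save_statediff`/`get_statediff` serialize a `StateDiff` as an 8-byte big-endian parent (0 meaning none), the fixed-width added entries, an all-zero 8-byte separator (only written when something was removed), and then the removed entries. A self-contained sketch of that byte layout with opaque 16-byte entries, assuming only std; it mirrors the parsing loop above but is illustrative rather than the production code:

use std::collections::HashSet;
use std::mem::size_of;

fn save(parent: Option<u64>, added: &[[u8; 16]], removed: &[[u8; 16]]) -> Vec<u8> {
    let mut value = parent.unwrap_or(0).to_be_bytes().to_vec();
    for e in added {
        value.extend_from_slice(e);
    }
    if !removed.is_empty() {
        value.extend_from_slice(&0_u64.to_be_bytes()); // section separator
        for e in removed {
            value.extend_from_slice(e);
        }
    }
    value
}

fn load(value: &[u8]) -> (Option<u64>, HashSet<[u8; 16]>, HashSet<[u8; 16]>) {
    let parent = u64::from_be_bytes(value[..8].try_into().unwrap());
    let parent = (parent != 0).then_some(parent);
    let (mut added, mut removed) = (HashSet::new(), HashSet::new());
    let mut add_mode = true;
    let mut i = size_of::<u64>();
    while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
        if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
            // Hit the separator: everything from here on belongs to `removed`.
            add_mode = false;
            i += size_of::<u64>();
            continue;
        }
        let entry: [u8; 16] = v.try_into().unwrap();
        if add_mode {
            added.insert(entry);
        } else {
            removed.insert(entry);
        }
        i += 2 * size_of::<u64>();
    }
    (parent, added, removed)
}

fn main() {
    let bytes = save(Some(7), &[[1_u8; 16]], &[[2_u8; 16]]);
    let (parent, added, removed) = load(&bytes);
    assert_eq!(parent, Some(7));
    assert!(added.contains(&[1_u8; 16]) && removed.contains(&[2_u8; 16]));
}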


@@ -7,74 +7,58 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>; type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
impl service::rooms::threads::Data for KeyValueDatabase { impl service::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>( fn threads_until<'a>(
&'a self, &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
user_id: &'a UserId, ) -> PduEventIterResult<'a> {
room_id: &'a RoomId, let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
until: u64,
_include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut current = prefix.clone(); let mut current = prefix.clone();
current.extend_from_slice(&(until - 1).to_be_bytes()); current.extend_from_slice(&(until - 1).to_be_bytes());
Ok(Box::new( Ok(Box::new(
self.threadid_userids self.threadid_userids.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
.iter_from(&current, true) move |(pduid, _users)| {
.take_while(move |(k, _)| k.starts_with(&prefix)) let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map(move |(pduid, _users)| { .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..]) let mut pdu = services()
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; .rooms
let mut pdu = services() .timeline
.rooms .get_pdu_from_id(&pduid)?
.timeline .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
.get_pdu_from_id(&pduid)? if pdu.sender != user_id {
.ok_or_else(|| { pdu.remove_transaction_id()?;
Error::bad_database("Invalid pduid reference in threadid_userids") }
})?; Ok((count, pdu))
if pdu.sender != user_id { },
pdu.remove_transaction_id()?; ),
} ))
Ok((count, pdu)) }
}),
))
}
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants let users = participants.iter().map(|user| user.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..]);
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xff][..]);
self.threadid_userids.insert(root_id, &users)?; self.threadid_userids.insert(root_id, &users)?;
Ok(()) Ok(())
} }
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> { fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? { if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some( Ok(Some(
users users
.split(|b| *b == 0xff) .split(|b| *b == 0xFF)
.map(|bytes| { .map(|bytes| {
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| { UserId::parse(
Error::bad_database("Invalid UserId bytes in threadid_userids.") utils::string_from_bytes(bytes)
})?) .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) )
}) .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
.filter_map(std::result::Result::ok) })
.collect(), .filter_map(std::result::Result::ok)
)) .collect(),
} else { ))
Ok(None) } else {
} Ok(None)
} }
}
} }


@@ -1,364 +1,286 @@
use std::{collections::hash_map, mem::size_of, sync::Arc}; use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{ use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, use service::rooms::timeline::PduCount;
};
use tracing::error; use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
use service::rooms::timeline::PduCount;
impl service::rooms::timeline::Data for KeyValueDatabase { impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) {
.lasttimelinecount_cache hash_map::Entry::Vacant(v) => {
.lock() if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| {
.unwrap() // Filter out buggy events
.entry(room_id.to_owned()) if r.is_err() {
{ error!("Bad pdu in pdus_since: {:?}", r);
hash_map::Entry::Vacant(v) => { }
if let Some(last_count) = self r.ok()
.pdus_until(sender_user, room_id, PduCount::max())? }) {
.find_map(|r| { Ok(*v.insert(last_count.0))
// Filter out buggy events } else {
if r.is_err() { Ok(PduCount::Normal(0))
error!("Bad pdu in pdus_since: {:?}", r); }
} },
r.ok() hash_map::Entry::Occupied(o) => Ok(*o.get()),
}) }
{ }
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
}
}
hash_map::Entry::Occupied(o) => Ok(*o.get()),
}
}
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> { fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose()
.get(event_id.as_bytes())? }
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else( self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| { || {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pdu| { .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
serde_json::from_slice(&pdu) .transpose()
.map_err(|_| Error::bad_database("Invalid PDU in db.")) },
}) |x| Ok(Some(x)),
.transpose() )
}, }
|x| Ok(Some(x)),
)
}
/// Returns the json of a pdu. /// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
self.pduid_pdu self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
.get(&pduid)? })
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) .transpose()?
}) .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()? .transpose()
.map(|pdu| { }
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
.transpose()
}
/// Returns the pdu's id. /// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
self.eventid_pduid.get(event_id.as_bytes())
}
/// Returns the pdu. /// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
self.pduid_pdu self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
.get(&pduid)? })
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) .transpose()?
}) .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()? .transpose()
.map(|pdu| { }
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
.transpose()
}
/// Returns the pdu. /// Returns the pdu.
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
return Ok(Some(Arc::clone(p))); return Ok(Some(Arc::clone(p)));
} }
if let Some(pdu) = self if let Some(pdu) = self
.get_non_outlier_pdu(event_id)? .get_non_outlier_pdu(event_id)?
.map_or_else( .map_or_else(
|| { || {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pdu| { .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
serde_json::from_slice(&pdu) .transpose()
.map_err(|_| Error::bad_database("Invalid PDU in db.")) },
}) |x| Ok(Some(x)),
.transpose() )?
}, .map(Arc::new)
|x| Ok(Some(x)), {
)? self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu));
.map(Arc::new) Ok(Some(pdu))
{ } else {
self.pdu_cache Ok(None)
.lock() }
.unwrap() }
.insert(event_id.to_owned(), Arc::clone(&pdu));
Ok(Some(pdu))
} else {
Ok(None)
}
}
/// Returns the pdu. /// Returns the pdu.
/// ///
/// This does __NOT__ check the outliers `Tree`. /// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
.map_err(|_| Error::bad_database("Invalid PDU in db."))?, ))
)) })
}) }
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
.map_err(|_| Error::bad_database("Invalid PDU in db."))?, ))
)) })
}) }
}
fn append_pdu( fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
&self, self.pduid_pdu.insert(
pdu_id: &[u8], pdu_id,
pdu: &PduEvent, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
json: &CanonicalJsonObject, )?;
count: u64,
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.lasttimelinecount_cache self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count));
.lock()
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(()) Ok(())
} }
fn prepend_backfill_pdu( fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
&self, self.pduid_pdu.insert(
pdu_id: &[u8], pdu_id,
event_id: &EventId, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
json: &CanonicalJsonObject, )?;
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?; self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(()) Ok(())
} }
/// Removes a pdu and creates a new one with the same id. /// Removes a pdu and creates a new one with the same id.
fn replace_pdu( fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
&self, if self.pduid_pdu.get(pdu_id)?.is_some() {
pdu_id: &[u8], self.pduid_pdu.insert(
pdu_json: &CanonicalJsonObject, pdu_id,
pdu: &PduEvent, &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
) -> Result<()> { )?;
if self.pduid_pdu.get(pdu_id)?.is_some() { } else {
self.pduid_pdu.insert( return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
pdu_id, }
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(
ErrorKind::NotFound,
"PDU does not exist.",
));
}
self.pdu_cache self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
.lock()
.unwrap()
.remove(&(*pdu.event_id).to_owned());
Ok(()) Ok(())
} }
/// Returns an iterator over all events and their tokens in a room that happened before the /// Returns an iterator over all events and their tokens in a room that
/// event with id `until` in reverse-chronological order. /// happened before the event with id `until` in reverse-chronological
fn pdus_until<'a>( /// order.
&'a self, fn pdus_until<'a>(
user_id: &UserId, &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
room_id: &RoomId, ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
until: PduCount, let (prefix, current) = count_to_id(room_id, until, 1, true)?;
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
Ok(Box::new( Ok(Box::new(
self.pduid_pdu self.pduid_pdu.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
.iter_from(&current, true) move |(pdu_id, v)| {
.take_while(move |(k, _)| k.starts_with(&prefix)) let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map(move |(pdu_id, v)| { .map_err(|_| Error::bad_database("PDU in db is invalid."))?;
let mut pdu = serde_json::from_slice::<PduEvent>(&v) if pdu.sender != user_id {
.map_err(|_| Error::bad_database("PDU in db is invalid."))?; pdu.remove_transaction_id()?;
if pdu.sender != user_id { }
pdu.remove_transaction_id()?; pdu.add_age()?;
} let count = pdu_count(&pdu_id)?;
pdu.add_age()?; Ok((count, pdu))
let count = pdu_count(&pdu_id)?; },
Ok((count, pdu)) ),
}), ))
)) }
}
fn pdus_after<'a>( fn pdus_after<'a>(
&'a self, &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
user_id: &UserId, ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
room_id: &RoomId, let (prefix, current) = count_to_id(room_id, from, 1, false)?;
from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
Ok(Box::new( Ok(Box::new(
self.pduid_pdu self.pduid_pdu.iter_from(&current, false).take_while(move |(k, _)| k.starts_with(&prefix)).map(
.iter_from(&current, false) move |(pdu_id, v)| {
.take_while(move |(k, _)| k.starts_with(&prefix)) let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map(move |(pdu_id, v)| { .map_err(|_| Error::bad_database("PDU in db is invalid."))?;
let mut pdu = serde_json::from_slice::<PduEvent>(&v) if pdu.sender != user_id {
.map_err(|_| Error::bad_database("PDU in db is invalid."))?; pdu.remove_transaction_id()?;
if pdu.sender != user_id { }
pdu.remove_transaction_id()?; pdu.add_age()?;
} let count = pdu_count(&pdu_id)?;
pdu.add_age()?; Ok((count, pdu))
let count = pdu_count(&pdu_id)?; },
Ok((count, pdu)) ),
}), ))
)) }
}
fn increment_notification_counts( fn increment_notification_counts(
&self, &self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
room_id: &RoomId, ) -> Result<()> {
notifies: Vec<OwnedUserId>, let mut notifies_batch = Vec::new();
highlights: Vec<OwnedUserId>, let mut highlights_batch = Vec::new();
) -> Result<()> { for user in notifies {
let mut notifies_batch = Vec::new(); let mut userroom_id = user.as_bytes().to_vec();
let mut highlights_batch = Vec::new(); userroom_id.push(0xFF);
for user in notifies { userroom_id.extend_from_slice(room_id.as_bytes());
let mut userroom_id = user.as_bytes().to_vec(); notifies_batch.push(userroom_id);
userroom_id.push(0xff); }
userroom_id.extend_from_slice(room_id.as_bytes()); for user in highlights {
notifies_batch.push(userroom_id); let mut userroom_id = user.as_bytes().to_vec();
} userroom_id.push(0xFF);
for user in highlights { userroom_id.extend_from_slice(room_id.as_bytes());
let mut userroom_id = user.as_bytes().to_vec(); highlights_batch.push(userroom_id);
userroom_id.push(0xff); }
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
}
self.userroomid_notificationcount self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?;
.increment_batch(&mut notifies_batch.into_iter())?; self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?;
self.userroomid_highlightcount Ok(())
.increment_batch(&mut highlights_batch.into_iter())?; }
Ok(())
}
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> { fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 = utils::u64_from_bytes( let second_last_u64 =
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()], utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
);
if matches!(second_last_u64, Ok(0)) { if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX - last_u64)) Ok(PduCount::Backfilled(u64::MAX - last_u64))
} else { } else {
Ok(PduCount::Normal(last_u64)) Ok(PduCount::Normal(last_u64))
} }
} }
fn count_to_id( fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
room_id: &RoomId, let prefix = services()
count: PduCount, .rooms
offset: u64, .short
subtract: bool, .get_shortroomid(room_id)?
) -> Result<(Vec<u8>, Vec<u8>)> { .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
let prefix = services() .to_be_bytes()
.rooms .to_vec();
.short let mut pdu_id = prefix.clone();
.get_shortroomid(room_id)? // +1 so we don't send the base event
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? let count_raw = match count {
.to_be_bytes() PduCount::Normal(x) => {
.to_vec(); if subtract {
let mut pdu_id = prefix.clone(); x - offset
// +1 so we don't send the base event } else {
let count_raw = match count { x + offset
PduCount::Normal(x) => { }
if subtract { },
x - offset PduCount::Backfilled(x) => {
} else { pdu_id.extend_from_slice(&0_u64.to_be_bytes());
x + offset let num = u64::MAX - x;
} if subtract {
} if num > 0 {
PduCount::Backfilled(x) => { num - offset
pdu_id.extend_from_slice(&0_u64.to_be_bytes()); } else {
let num = u64::MAX - x; num
if subtract { }
if num > 0 { } else {
num - offset num + offset
} else { }
num },
} };
} else { pdu_id.extend_from_slice(&count_raw.to_be_bytes());
num + offset
}
}
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id)) Ok((prefix, pdu_id))
} }


@@ -3,147 +3,122 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::user::Data for KeyValueDatabase { impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
.insert(&userroom_id, &0_u64.to_be_bytes())?; self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread.insert( self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
&roomuser_id,
&services().globals.next_count()?.to_be_bytes(),
)?;
Ok(()) Ok(())
} }
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount self.userroomid_notificationcount
.get(&userroom_id)? .get(&userroom_id)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
.map_err(|_| Error::bad_database("Invalid notification count in db.")) })
}) .unwrap_or(Ok(0))
.unwrap_or(Ok(0)) }
}
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount self.userroomid_highlightcount
.get(&userroom_id)? .get(&userroom_id)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
.map_err(|_| Error::bad_database("Invalid highlight count in db.")) })
}) .unwrap_or(Ok(0))
.unwrap_or(Ok(0)) }
}
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();
key.push(0xff); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
Ok(self Ok(self
.roomuserid_lastnotificationread .roomuserid_lastnotificationread
.get(&key)? .get(&key)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes)
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
}) })
}) .transpose()?
.transpose()? .unwrap_or(0))
.unwrap_or(0)) }
}
fn associate_token_shortstatehash( fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
&self, let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
room_id: &RoomId,
token: u64,
shortstatehash: u64,
) -> Result<()> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes()); key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes())
.insert(&key, &shortstatehash.to_be_bytes()) }
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services() let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes()); key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash self.roomsynctoken_shortstatehash
.get(&key)? .get(&key)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| { utils::u64_from_bytes(&bytes)
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
}) })
}) .transpose()
.transpose() }
}
fn get_shared_rooms<'a>( fn get_shared_rooms<'a>(
&'a self, &'a self, users: Vec<OwnedUserId>,
users: Vec<OwnedUserId>, ) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> { let iterators = users.into_iter().map(move |user_id| {
let iterators = users.into_iter().map(move |user_id| { let mut prefix = user_id.as_bytes().to_vec();
let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF);
prefix.push(0xff);
self.userroomid_joined self.userroomid_joined
.scan_prefix(prefix) .scan_prefix(prefix)
.map(|(key, _)| { .map(|(key, _)| {
let roomid_index = key let roomid_index = key
.iter() .iter()
.enumerate() .enumerate()
.find(|(_, &b)| b == 0xff) .find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0 .0 + 1; // +1 because the room id starts AFTER the separator
+ 1; // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec(); let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id) Ok::<_, Error>(room_id)
}) })
.filter_map(std::result::Result::ok) .filter_map(std::result::Result::ok)
}); });
// We use the default compare function because keys are sorted correctly (not reversed) // We use the default compare function because keys are sorted correctly (not
Ok(Box::new( // reversed)
utils::common_elements(iterators, Ord::cmp) Ok(Box::new(
.expect("users is not empty") utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| {
.map(|bytes| { RoomId::parse(
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { utils::string_from_bytes(&bytes)
Error::bad_database("Invalid RoomId bytes in userroomid_joined") .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
})?) )
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}), }),
)) ))
} }
} }
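All of the per-user, per-room counters above share one key scheme: the two identifiers joined by a 0xFF separator byte, with the counter itself stored as a big-endian u64. A minimal sketch of that pattern, std only, using a BTreeMap as a stand-in for the real key-value tree (the helper name is illustrative):

use std::collections::BTreeMap;

// userroom_id = user_id bytes ++ 0xFF ++ room_id bytes
fn userroom_key(user_id: &str, room_id: &str) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF); // 0xFF never occurs in valid UTF-8, so it is a safe separator
    key.extend_from_slice(room_id.as_bytes());
    key
}

fn main() {
    let mut notification_counts: BTreeMap<Vec<u8>, Vec<u8>> = BTreeMap::new();
    let key = userroom_key("@alice:example.org", "!room:example.org");

    // Store the counter as big-endian bytes, as the code above does.
    notification_counts.insert(key.clone(), 5_u64.to_be_bytes().to_vec());

    // Reading it back mirrors utils::u64_from_bytes, defaulting to 0 when absent.
    let count = notification_counts
        .get(&key)
        .map(|bytes| u64::from_be_bytes(bytes.as_slice().try_into().unwrap()))
        .unwrap_or(0);
    assert_eq!(count, 5);
}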


@@ -1,205 +1,181 @@
use ruma::{ServerName, UserId}; use ruma::{ServerName, UserId};
use crate::{ use crate::{
database::KeyValueDatabase, database::KeyValueDatabase,
service::{ service::{
self, self,
sending::{OutgoingKind, SendingEventType}, sending::{OutgoingKind, SendingEventType},
}, },
services, utils, Error, Result, services, utils, Error, Result,
}; };
impl service::sending::Data for KeyValueDatabase { impl service::sending::Data for KeyValueDatabase {
fn active_requests<'a>( fn active_requests<'a>(
&'a self, &'a self,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> { ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
Box::new( Box::new(
self.servercurrentevent_data self.servercurrentevent_data
.iter() .iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
) )
} }
fn active_requests_for<'a>( fn active_requests_for<'a>(
&'a self, &'a self, outgoing_kind: &OutgoingKind,
outgoing_kind: &OutgoingKind, ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> { let prefix = outgoing_kind.get_prefix();
let prefix = outgoing_kind.get_prefix(); Box::new(
Box::new( self.servercurrentevent_data
self.servercurrentevent_data .scan_prefix(prefix)
.scan_prefix(prefix) .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), )
) }
}
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
self.servercurrentevent_data.remove(&key)
}
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?; self.servercurrentevent_data.remove(&key)?;
} }
Ok(()) Ok(())
} }
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap(); self.servercurrentevent_data.remove(&key).unwrap();
} }
for (key, _) in self.servernameevent_data.scan_prefix(prefix) { for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
self.servernameevent_data.remove(&key).unwrap(); self.servernameevent_data.remove(&key).unwrap();
} }
Ok(()) Ok(())
} }
fn queue_requests( fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result<Vec<Vec<u8>>> {
&self, let mut batch = Vec::new();
requests: &[(&OutgoingKind, SendingEventType)], let mut keys = Vec::new();
) -> Result<Vec<Vec<u8>>> { for (outgoing_kind, event) in requests {
let mut batch = Vec::new(); let mut key = outgoing_kind.get_prefix();
let mut keys = Vec::new(); if let SendingEventType::Pdu(value) = &event {
for (outgoing_kind, event) in requests { key.extend_from_slice(value);
let mut key = outgoing_kind.get_prefix(); } else {
if let SendingEventType::Pdu(value) = &event { key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
key.extend_from_slice(value); }
} else { let value = if let SendingEventType::Edu(value) = &event {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); &**value
} } else {
let value = if let SendingEventType::Edu(value) = &event { &[]
&**value };
} else { batch.push((key.clone(), value.to_owned()));
&[] keys.push(key);
}; }
batch.push((key.clone(), value.to_owned())); self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
keys.push(key); Ok(keys)
} }
self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
Ok(keys)
}
fn queued_requests<'a>( fn queued_requests<'a>(
&'a self, &'a self, outgoing_kind: &OutgoingKind,
outgoing_kind: &OutgoingKind, ) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> { let prefix = outgoing_kind.get_prefix();
let prefix = outgoing_kind.get_prefix(); return Box::new(
return Box::new( self.servernameevent_data
self.servernameevent_data .scan_prefix(prefix)
.scan_prefix(prefix) .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), );
); }
}
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> { fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
for (e, key) in events { for (e, key) in events {
let value = if let SendingEventType::Edu(value) = &e { let value = if let SendingEventType::Edu(value) = &e {
&**value &**value
} else { } else {
&[] &[]
}; };
self.servercurrentevent_data.insert(key, value)?; self.servercurrentevent_data.insert(key, value)?;
self.servernameevent_data.remove(key)?; self.servernameevent_data.remove(key)?;
} }
Ok(()) Ok(())
} }
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes())
.insert(server_name.as_bytes(), &last_count.to_be_bytes()) }
}
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| {
.get(server_name.as_bytes())? utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
.map_or(Ok(0), |bytes| { })
utils::u64_from_bytes(&bytes) }
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
})
}
} }
#[tracing::instrument(skip(key))] #[tracing::instrument(skip(key))]
fn parse_servercurrentevent( fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind, SendingEventType)> {
key: &[u8], // Appservices start with a plus
value: Vec<u8>, Ok::<_, Error>(if key.starts_with(b"+") {
) -> Result<(OutgoingKind, SendingEventType)> { let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
let server = parts.next().expect("splitn always returns one element"); let server = parts.next().expect("splitn always returns one element");
let event = parts let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server).map_err(|_| { let server = utils::string_from_bytes(server)
Error::bad_database("Invalid server bytes in server_currenttransaction") .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
})?;
( (
OutgoingKind::Appservice(server), OutgoingKind::Appservice(server),
if value.is_empty() { if value.is_empty() {
SendingEventType::Pdu(event.to_vec()) SendingEventType::Pdu(event.to_vec())
} else { } else {
SendingEventType::Edu(value) SendingEventType::Edu(value)
}, },
) )
} else if key.starts_with(b"$") { } else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xff); let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element"); let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user) let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id = UserId::parse(user_string) let user_id =
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
.next() let pushkey_string = utils::string_from_bytes(pushkey)
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let event = parts let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
( (
OutgoingKind::Push(user_id, pushkey_string), OutgoingKind::Push(user_id, pushkey_string),
if value.is_empty() { if value.is_empty() {
SendingEventType::Pdu(event.to_vec()) SendingEventType::Pdu(event.to_vec())
} else { } else {
// I'm pretty sure this should never be called // I'm pretty sure this should never be called
SendingEventType::Edu(value) SendingEventType::Edu(value)
}, },
) )
} else { } else {
let mut parts = key.splitn(2, |&b| b == 0xff); let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element"); let server = parts.next().expect("splitn always returns one element");
let event = parts let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server).map_err(|_| { let server = utils::string_from_bytes(server)
Error::bad_database("Invalid server bytes in server_currenttransaction") .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
})?;
( (
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { OutgoingKind::Normal(
Error::bad_database("Invalid server string in server_currenttransaction") ServerName::parse(server)
})?), .map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
if value.is_empty() { ),
SendingEventType::Pdu(event.to_vec()) if value.is_empty() {
} else { SendingEventType::Pdu(event.to_vec())
SendingEventType::Edu(value) } else {
}, SendingEventType::Edu(value)
) },
}) )
})
} }
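parse_servercurrentevent distinguishes the three destination kinds purely by the key prefix: a leading '+' for appservices, '$' for push gateways, and a bare server name for normal federation sends, with 0xFF bytes separating the fields. A standalone sketch of that dispatch, std only; the Kind enum and function name are illustrative, not the crate's OutgoingKind:

#[derive(Debug, PartialEq)]
enum Kind {
    Appservice(String),
    Push(String, String), // user id, push key
    Normal(String),       // server name
}

fn parse_kind(key: &[u8]) -> Option<Kind> {
    let utf8 = |bytes: &[u8]| String::from_utf8(bytes.to_vec()).ok();
    if let Some(rest) = key.strip_prefix(b"+") {
        // "+" <appservice id> 0xFF <event>
        let mut parts = rest.splitn(2, |&b| b == 0xFF);
        Some(Kind::Appservice(utf8(parts.next()?)?))
    } else if let Some(rest) = key.strip_prefix(b"$") {
        // "$" <user id> 0xFF <push key> 0xFF <event>
        let mut parts = rest.splitn(3, |&b| b == 0xFF);
        Some(Kind::Push(utf8(parts.next()?)?, utf8(parts.next()?)?))
    } else {
        // <server name> 0xFF <event>
        let mut parts = key.splitn(2, |&b| b == 0xFF);
        Some(Kind::Normal(utf8(parts.next()?)?))
    }
}

fn main() {
    let mut key = b"$@alice:example.org".to_vec();
    key.push(0xFF);
    key.extend_from_slice(b"pushkey_1");
    key.push(0xFF);
    key.extend_from_slice(b"$some_event_id");
    assert_eq!(
        parse_kind(&key),
        Some(Kind::Push("@alice:example.org".into(), "pushkey_1".into()))
    );
}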


@@ -3,37 +3,30 @@ use ruma::{DeviceId, TransactionId, UserId};
use crate::{database::KeyValueDatabase, service, Result}; use crate::{database::KeyValueDatabase, service, Result};
impl service::transaction_ids::Data for KeyValueDatabase { impl service::transaction_ids::Data for KeyValueDatabase {
fn add_txnid( fn add_txnid(
&self, &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
user_id: &UserId, ) -> Result<()> {
device_id: Option<&DeviceId>, let mut key = user_id.as_bytes().to_vec();
txn_id: &TransactionId, key.push(0xFF);
data: &[u8], key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
) -> Result<()> { key.push(0xFF);
let mut key = user_id.as_bytes().to_vec(); key.extend_from_slice(txn_id.as_bytes());
key.push(0xff);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xff);
key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?; self.userdevicetxnid_response.insert(&key, data)?;
Ok(()) Ok(())
} }
fn existing_txnid( fn existing_txnid(
&self, &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
user_id: &UserId, ) -> Result<Option<Vec<u8>>> {
device_id: Option<&DeviceId>, let mut key = user_id.as_bytes().to_vec();
txn_id: &TransactionId, key.push(0xFF);
) -> Result<Option<Vec<u8>>> { key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
let mut key = user_id.as_bytes().to_vec(); key.push(0xFF);
key.push(0xff); key.extend_from_slice(txn_id.as_bytes());
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xff);
key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction // If there's no entry, this is a new transaction
self.userdevicetxnid_response.get(&key) self.userdevicetxnid_response.get(&key)
} }
} }
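The transaction-id tree above is what makes client requests idempotent: the key is the user, device and transaction id joined by 0xFF bytes, and a hit in existing_txnid means the previously cached response should be returned instead of re-processing the request. A minimal sketch of that flow, std only, with a HashMap standing in for the tree and illustrative names:

use std::collections::HashMap;

// key = user_id ++ 0xFF ++ device_id (may be empty) ++ 0xFF ++ txn_id
fn txn_key(user_id: &str, device_id: Option<&str>, txn_id: &str) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(device_id.unwrap_or_default().as_bytes());
    key.push(0xFF);
    key.extend_from_slice(txn_id.as_bytes());
    key
}

fn main() {
    let mut responses: HashMap<Vec<u8>, Vec<u8>> = HashMap::new();
    let key = txn_key("@alice:example.org", Some("DEVICEID"), "m.txn.1");

    // First attempt: nothing cached, so the request is handled and the response stored.
    assert!(responses.get(&key).is_none());
    responses.insert(key.clone(), br#"{"event_id":"$example"}"#.to_vec());

    // A retry with the same transaction id finds the cached response and returns it.
    assert!(responses.get(&key).is_some());
}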


@@ -1,89 +1,64 @@
use ruma::{ use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo}, api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, UserId, CanonicalJsonValue, DeviceId, UserId,
}; };
use crate::{database::KeyValueDatabase, service, Error, Result}; use crate::{database::KeyValueDatabase, service, Error, Result};
impl service::uiaa::Data for KeyValueDatabase { impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request( fn set_uiaa_request(
&self, &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
user_id: &UserId, ) -> Result<()> {
device_id: &DeviceId, self.userdevicesessionid_uiaarequest.write().unwrap().insert(
session: &str, (user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request: &CanonicalJsonValue, request.to_owned(),
) -> Result<()> { );
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(()) Ok(())
} }
fn get_uiaa_request( fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
&self, self.userdevicesessionid_uiaarequest
user_id: &UserId, .read()
device_id: &DeviceId, .unwrap()
session: &str, .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
) -> Option<CanonicalJsonValue> { .map(std::borrow::ToOwned::to_owned)
self.userdevicesessionid_uiaarequest }
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(std::borrow::ToOwned::to_owned)
}
fn update_uiaa_session( fn update_uiaa_session(
&self, &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
user_id: &UserId, ) -> Result<()> {
device_id: &DeviceId, let mut userdevicesessionid = user_id.as_bytes().to_vec();
session: &str, userdevicesessionid.push(0xFF);
uiaainfo: Option<&UiaaInfo>, userdevicesessionid.extend_from_slice(device_id.as_bytes());
) -> Result<()> { userdevicesessionid.push(0xFF);
let mut userdevicesessionid = user_id.as_bytes().to_vec(); userdevicesessionid.extend_from_slice(session.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo { if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert( self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid, &userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?; )?;
} else { } else {
self.userdevicesessionid_uiaainfo self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
.remove(&userdevicesessionid)?; }
}
Ok(()) Ok(())
} }
fn get_uiaa_session( fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
&self, let mut userdevicesessionid = user_id.as_bytes().to_vec();
user_id: &UserId, userdevicesessionid.push(0xFF);
device_id: &DeviceId, userdevicesessionid.extend_from_slice(device_id.as_bytes());
session: &str, userdevicesessionid.push(0xFF);
) -> Result<UiaaInfo> { userdevicesessionid.extend_from_slice(session.as_bytes());
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice( serde_json::from_slice(
&self &self
.userdevicesessionid_uiaainfo .userdevicesessionid_uiaainfo
.get(&userdevicesessionid)? .get(&userdevicesessionid)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(ErrorKind::Forbidden, "UIAA session does not exist."))?,
ErrorKind::Forbidden, )
"UIAA session does not exist.", .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
))?, }
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
}
} }

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -15,8 +15,5 @@ pub use utils::error::{Error, Result};
pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None); pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
pub fn services() -> &'static Services<'static> { pub fn services() -> &'static Services<'static> {
SERVICES SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called")
.read()
.unwrap()
.expect("SERVICES should be initialized when this is called")
} }
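The services() accessor above exposes one process-wide Services value: a RwLock'd Option that is written during startup and then read through a panicking getter everywhere else. A stripped-down sketch of that pattern, std only, with an illustrative Services struct and an assumed leak-at-startup step that is not taken from this repository:

use std::sync::RwLock;

struct Services {
    server_name: &'static str,
}

static SERVICES: RwLock<Option<&'static Services>> = RwLock::new(None);

fn services() -> &'static Services {
    SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called")
}

fn main() {
    // At startup the value is leaked to obtain a 'static reference, then published once.
    let built: &'static Services = Box::leak(Box::new(Services {
        server_name: "example.org",
    }));
    *SERVICES.write().unwrap() = Some(built);

    // Any later caller can grab the global without threading it through every function.
    assert_eq!(services().server_name, "example.org");
}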

File diff suppressed because it is too large


@@ -1,35 +1,28 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::Result;
use ruma::{ use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw, serde::Raw,
RoomId, UserId, RoomId, UserId,
}; };
use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the
fn update( /// previous entry.
&self, fn update(
room_id: Option<&RoomId>, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
user_id: &UserId, data: &serde_json::Value,
event_type: RoomAccountDataEventType, ) -> Result<()>;
data: &serde_json::Value,
) -> Result<()>;
/// Searches the account data for a specific kind. /// Searches the account data for a specific kind.
fn get( fn get(
&self, &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
room_id: Option<&RoomId>, ) -> Result<Option<Box<serde_json::value::RawValue>>>;
user_id: &UserId,
kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>>;
/// Returns all changes to the account data that happened after `since`. /// Returns all changes to the account data that happened after `since`.
fn changes_since( fn changes_since(
&self, &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
room_id: Option<&RoomId>, ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
user_id: &UserId,
since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
} }


@@ -1,53 +1,44 @@
mod data; mod data;
pub(crate) use data::Data;
use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use std::collections::HashMap; use std::collections::HashMap;
pub(crate) use data::Data;
use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))] /// previous entry.
pub fn update( #[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
&self, pub fn update(
room_id: Option<&RoomId>, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
user_id: &UserId, data: &serde_json::Value,
event_type: RoomAccountDataEventType, ) -> Result<()> {
data: &serde_json::Value, self.db.update(room_id, user_id, event_type, data)
) -> Result<()> { }
self.db.update(room_id, user_id, event_type, data)
}
/// Searches the account data for a specific kind. /// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, event_type))] #[tracing::instrument(skip(self, room_id, user_id, event_type))]
pub fn get( pub fn get(
&self, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
room_id: Option<&RoomId>, ) -> Result<Option<Box<serde_json::value::RawValue>>> {
user_id: &UserId, self.db.get(room_id, user_id, event_type)
event_type: RoomAccountDataEventType, }
) -> Result<Option<Box<serde_json::value::RawValue>>> {
self.db.get(room_id, user_id, event_type)
}
/// Returns all changes to the account data that happened after `since`. /// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))] #[tracing::instrument(skip(self, room_id, user_id, since))]
pub fn changes_since( pub fn changes_since(
&self, &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
room_id: Option<&RoomId>, ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
user_id: &UserId, self.db.changes_since(room_id, user_id, since)
since: u64, }
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
self.db.changes_since(room_id, user_id, since)
}
} }

File diff suppressed because it is too large


@@ -3,19 +3,19 @@ use ruma::api::appservice::Registration;
use crate::Result; use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String>; fn register_appservice(&self, yaml: Registration) -> Result<String>;
/// Remove an appservice registration /// Remove an appservice registration
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()>; fn unregister_appservice(&self, service_name: &str) -> Result<()>;
fn get_registration(&self, id: &str) -> Result<Option<Registration>>; fn get_registration(&self, id: &str) -> Result<Option<Registration>>;
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>; fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
fn all(&self) -> Result<Vec<(String, Registration)>>; fn all(&self) -> Result<Vec<(String, Registration)>>;
} }


@@ -6,33 +6,25 @@ use ruma::api::appservice::Registration;
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
pub fn register_appservice(&self, yaml: Registration) -> Result<String> { pub fn register_appservice(&self, yaml: Registration) -> Result<String> { self.db.register_appservice(yaml) }
self.db.register_appservice(yaml)
}
/// Remove an appservice registration /// Remove an appservice registration
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { pub fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.db.unregister_appservice(service_name) self.db.unregister_appservice(service_name)
} }
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) }
self.db.get_registration(id)
}
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() }
self.db.iter_ids()
}
pub fn all(&self) -> Result<Vec<(String, Registration)>> { pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() }
self.db.all()
}
} }
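Like most of the modules touched here, the appservice service is a thin facade: Service holds a &'static dyn Data and every method is a short delegation, so the storage backend can be swapped without changing callers. A stripped-down sketch of that split, std only, with an illustrative in-memory backend:

trait Data: Send + Sync {
    fn get_registration(&self, id: &str) -> Option<String>;
}

struct InMemory;

impl Data for InMemory {
    fn get_registration(&self, id: &str) -> Option<String> {
        (id == "my_bridge").then(|| "id: my_bridge".to_owned())
    }
}

struct Service {
    db: &'static dyn Data,
}

impl Service {
    // Thin delegation, mirroring the single-line style used throughout this commit.
    fn get_registration(&self, id: &str) -> Option<String> { self.db.get_registration(id) }
}

fn main() {
    static DB: InMemory = InMemory;
    let service = Service { db: &DB };
    assert_eq!(service.get_registration("my_bridge").as_deref(), Some("id: my_bridge"));
}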


@@ -2,36 +2,32 @@ use std::collections::BTreeMap;
use async_trait::async_trait; use async_trait::async_trait;
use ruma::{ use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey}, api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair, signatures::Ed25519KeyPair,
DeviceId, OwnedServerSigningKeyId, ServerName, UserId, DeviceId, OwnedServerSigningKeyId, ServerName, UserId,
}; };
use crate::Result; use crate::Result;
#[async_trait] #[async_trait]
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn next_count(&self) -> Result<u64>; fn next_count(&self) -> Result<u64>;
fn current_count(&self) -> Result<u64>; fn current_count(&self) -> Result<u64>;
fn last_check_for_updates_id(&self) -> Result<u64>; fn last_check_for_updates_id(&self) -> Result<u64>;
fn update_check_for_updates_id(&self, id: u64) -> Result<()>; fn update_check_for_updates_id(&self, id: u64) -> Result<()>;
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
fn cleanup(&self) -> Result<()>; fn cleanup(&self) -> Result<()>;
fn memory_usage(&self) -> String; fn memory_usage(&self) -> String;
fn clear_caches(&self, amount: u32); fn clear_caches(&self, amount: u32);
fn load_keypair(&self) -> Result<Ed25519KeyPair>; fn load_keypair(&self) -> Result<Ed25519KeyPair>;
fn remove_keypair(&self) -> Result<()>; fn remove_keypair(&self) -> Result<()>;
fn add_signing_key( fn add_signing_key(
&self, &self, origin: &ServerName, new_keys: ServerSigningKeys,
origin: &ServerName, ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
fn signing_keys_for( /// for the server.
&self, fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
origin: &ServerName, fn database_version(&self) -> Result<u64>;
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>; fn bump_database_version(&self, new_version: u64) -> Result<()>;
fn database_version(&self) -> Result<u64>;
fn bump_database_version(&self, new_version: u64) -> Result<()>;
} }

File diff suppressed because it is too large


@@ -1,78 +1,47 @@ use std::collections::BTreeMap;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::Result;
use ruma::{ use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw, serde::Raw,
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn create_backup( fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
&self,
user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String>;
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
fn update_backup( fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
&self,
user_id: &UserId,
version: &str,
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String>;
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>; fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
fn get_latest_backup(&self, user_id: &UserId) fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
-> Result<Option<(String, Raw<BackupAlgorithm>)>>;
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>; fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
fn add_key( fn add_key(
&self, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
user_id: &UserId, ) -> Result<()>;
version: &str,
room_id: &RoomId,
session_id: &str,
key_data: &Raw<KeyBackupData>,
) -> Result<()>;
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>; fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>; fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
fn get_all( fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
&self,
user_id: &UserId,
version: &str,
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
fn get_room( fn get_room(
&self, &self, user_id: &UserId, version: &str, room_id: &RoomId,
user_id: &UserId, ) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
version: &str,
room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
fn get_session( fn get_session(
&self, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
user_id: &UserId, ) -> Result<Option<Raw<KeyBackupData>>>;
version: &str,
room_id: &RoomId,
session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>>;
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
fn delete_room_key( fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>;
&self,
user_id: &UserId,
version: &str,
room_id: &RoomId,
session_id: &str,
) -> Result<()>;
} }


@@ -1,127 +1,81 @@
mod data; mod data;
pub(crate) use data::Data;
use crate::Result;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use std::collections::BTreeMap; use std::collections::BTreeMap;
pub(crate) use data::Data;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::Result;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
pub fn create_backup( pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
&self, self.db.create_backup(user_id, backup_metadata)
user_id: &UserId, }
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
self.db.create_backup(user_id, backup_metadata)
}
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_backup(user_id, version) self.db.delete_backup(user_id, version)
} }
pub fn update_backup( pub fn update_backup(
&self, &self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
user_id: &UserId, ) -> Result<String> {
version: &str, self.db.update_backup(user_id, version, backup_metadata)
backup_metadata: &Raw<BackupAlgorithm>, }
) -> Result<String> {
self.db.update_backup(user_id, version, backup_metadata)
}
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
self.db.get_latest_backup_version(user_id) self.db.get_latest_backup_version(user_id)
} }
pub fn get_latest_backup( pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
&self, self.db.get_latest_backup(user_id)
user_id: &UserId, }
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
self.db.get_latest_backup(user_id)
}
pub fn get_backup( pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
&self, self.db.get_backup(user_id, version)
user_id: &UserId, }
version: &str,
) -> Result<Option<Raw<BackupAlgorithm>>> {
self.db.get_backup(user_id, version)
}
pub fn add_key( pub fn add_key(
&self, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
user_id: &UserId, ) -> Result<()> {
version: &str, self.db.add_key(user_id, version, room_id, session_id, key_data)
room_id: &RoomId, }
session_id: &str,
key_data: &Raw<KeyBackupData>,
) -> Result<()> {
self.db
.add_key(user_id, version, room_id, session_id, key_data)
}
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
self.db.count_keys(user_id, version)
}
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) }
self.db.get_etag(user_id, version)
}
pub fn get_all( pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
&self, self.db.get_all(user_id, version)
user_id: &UserId, }
version: &str,
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
self.db.get_all(user_id, version)
}
pub fn get_room( pub fn get_room(
&self, &self, user_id: &UserId, version: &str, room_id: &RoomId,
user_id: &UserId, ) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
version: &str, self.db.get_room(user_id, version, room_id)
room_id: &RoomId, }
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
self.db.get_room(user_id, version, room_id)
}
pub fn get_session( pub fn get_session(
&self, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
user_id: &UserId, ) -> Result<Option<Raw<KeyBackupData>>> {
version: &str, self.db.get_session(user_id, version, room_id, session_id)
room_id: &RoomId, }
session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
self.db.get_session(user_id, version, room_id, session_id)
}
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_all_keys(user_id, version) self.db.delete_all_keys(user_id, version)
} }
pub fn delete_room_keys( pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
&self, self.db.delete_room_keys(user_id, version, room_id)
user_id: &UserId, }
version: &str,
room_id: &RoomId,
) -> Result<()> {
self.db.delete_room_keys(user_id, version, room_id)
}
pub fn delete_room_key( pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
&self, self.db.delete_room_key(user_id, version, room_id, session_id)
user_id: &UserId, }
version: &str,
room_id: &RoomId,
session_id: &str,
) -> Result<()> {
self.db
.delete_room_key(user_id, version, room_id, session_id)
}
} }


@@ -1,37 +1,24 @@
use crate::Result; use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn create_file_metadata( fn create_file_metadata(
&self, &self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
mxc: String, ) -> Result<Vec<u8>>;
width: u32,
height: u32,
content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>>;
fn delete_file_mxc(&self, mxc: String) -> Result<()>; fn delete_file_mxc(&self, mxc: String) -> Result<()>;
/// Returns content_disposition, content_type and the metadata key. /// Returns content_disposition, content_type and the metadata key.
fn search_file_metadata( fn search_file_metadata(
&self, &self, mxc: String, width: u32, height: u32,
mxc: String, ) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
width: u32,
height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>; fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>>; fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>>;
fn remove_url_preview(&self, url: &str) -> Result<()>; fn remove_url_preview(&self, url: &str) -> Result<()>;
fn set_url_preview( fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
&self,
url: &str,
data: &super::UrlPreviewData,
timestamp: std::time::Duration,
) -> Result<()>;
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>; fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
} }


@@ -1,579 +1,480 @@
mod data; mod data;
use std::{ use std::{
collections::HashMap, collections::HashMap,
io::Cursor, io::Cursor,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
time::SystemTime, time::SystemTime,
}; };
pub(crate) use data::Data; pub(crate) use data::Data;
use image::imageops::FilterType;
use ruma::OwnedMxcUri; use ruma::OwnedMxcUri;
use serde::Serialize; use serde::Serialize;
use tokio::{
fs::{self, File},
io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::Mutex,
};
use tracing::{debug, error}; use tracing::{debug, error};
use crate::{services, utils, Error, Result}; use crate::{services, utils, Error, Result};
use image::imageops::FilterType;
use tokio::{
fs::{self, File},
io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::Mutex,
};
#[derive(Debug)] #[derive(Debug)]
pub struct FileMeta { pub struct FileMeta {
pub content_disposition: Option<String>, pub content_disposition: Option<String>,
pub content_type: Option<String>, pub content_type: Option<String>,
pub file: Vec<u8>, pub file: Vec<u8>,
} }
#[derive(Serialize, Default)] #[derive(Serialize, Default)]
pub struct UrlPreviewData { pub struct UrlPreviewData {
#[serde( #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
skip_serializing_if = "Option::is_none", pub title: Option<String>,
rename(serialize = "og:title") #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
)] pub description: Option<String>,
pub title: Option<String>, #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
#[serde( pub image: Option<String>,
skip_serializing_if = "Option::is_none", #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
rename(serialize = "og:description") pub image_size: Option<usize>,
)] #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
pub description: Option<String>, pub image_width: Option<u32>,
#[serde( #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
skip_serializing_if = "Option::is_none", pub image_height: Option<u32>,
rename(serialize = "og:image")
)]
pub image: Option<String>,
#[serde(
skip_serializing_if = "Option::is_none",
rename(serialize = "matrix:image:size")
)]
pub image_size: Option<usize>,
#[serde(
skip_serializing_if = "Option::is_none",
rename(serialize = "og:image:width")
)]
pub image_width: Option<u32>,
#[serde(
skip_serializing_if = "Option::is_none",
rename(serialize = "og:image:height")
)]
pub image_height: Option<u32>,
} }
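The serde attributes on UrlPreviewData rename each field to its OpenGraph-style key and skip unset options, so the struct serializes directly into the JSON shape returned for URL previews. A trimmed-down standalone sketch of that behaviour; it redefines a smaller struct purely for illustration and assumes the serde (with derive) and serde_json crates:

use serde::Serialize;

#[derive(Serialize, Default)]
struct Preview {
    #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
    title: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
    image_width: Option<u32>,
}

fn main() {
    let preview = Preview {
        title: Some("Example page".to_owned()),
        ..Default::default()
    };
    // image_width is None, so it is omitted from the output entirely.
    assert_eq!(
        serde_json::to_string(&preview).unwrap(),
        r#"{"og:title":"Example page"}"#
    );
}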
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>, pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
} }
impl Service { impl Service {
/// Uploads a file. /// Uploads a file.
pub async fn create( pub async fn create(
&self, &self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8],
mxc: String, ) -> Result<()> {
content_disposition: Option<&str>, // Width, Height = 0 if it's not a thumbnail
content_type: Option<&str>, let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
file: &[u8],
) -> Result<()> {
// Width, Height = 0 if it's not a thumbnail
let key = self
.db
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
let path = if cfg!(feature = "sha256_media") { let path = if cfg!(feature = "sha256_media") {
services().globals.get_media_file_new(&key) services().globals.get_media_file_new(&key)
} else { } else {
#[allow(deprecated)] #[allow(deprecated)]
services().globals.get_media_file(&key) services().globals.get_media_file(&key)
}; };
let mut f = File::create(path).await?; let mut f = File::create(path).await?;
f.write_all(file).await?; f.write_all(file).await?;
Ok(()) Ok(())
} }
/// Deletes a file in the database and from the media directory via an MXC /// Deletes a file in the database and from the media directory via an MXC
pub async fn delete(&self, mxc: String) -> Result<()> { pub async fn delete(&self, mxc: String) -> Result<()> {
if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) { if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) {
for key in keys { for key in keys {
let file_path = if cfg!(feature = "sha256_media") { let file_path = if cfg!(feature = "sha256_media") {
services().globals.get_media_file_new(&key) services().globals.get_media_file_new(&key)
} else { } else {
#[allow(deprecated)] #[allow(deprecated)]
services().globals.get_media_file(&key) services().globals.get_media_file(&key)
}; };
debug!("Got local file path: {:?}", file_path); debug!("Got local file path: {:?}", file_path);
debug!( debug!("Deleting local file {:?} from filesystem, original MXC: {}", file_path, mxc);
"Deleting local file {:?} from filesystem, original MXC: {}", tokio::fs::remove_file(file_path).await?;
file_path, mxc
);
tokio::fs::remove_file(file_path).await?;
debug!("Deleting MXC {mxc} from database"); debug!("Deleting MXC {mxc} from database");
self.db.delete_file_mxc(mxc.clone())?; self.db.delete_file_mxc(mxc.clone())?;
} }
Ok(()) Ok(())
} else { } else {
error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)"); error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)");
Err(Error::bad_database("Failed to find any media keys for the provided MXC in our database (MXC does not exist)")) Err(Error::bad_database(
} "Failed to find any media keys for the provided MXC in our database (MXC does not exist)",
} ))
}
}
/// Uploads or replaces a file thumbnail. /// Uploads or replaces a file thumbnail.
pub async fn upload_thumbnail( pub async fn upload_thumbnail(
&self, &self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32,
mxc: String, file: &[u8],
content_disposition: Option<&str>, ) -> Result<()> {
content_type: Option<&str>, let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
width: u32,
height: u32,
file: &[u8],
) -> Result<()> {
let key =
self.db
.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
let path = if cfg!(feature = "sha256_media") { let path = if cfg!(feature = "sha256_media") {
services().globals.get_media_file_new(&key) services().globals.get_media_file_new(&key)
} else { } else {
#[allow(deprecated)] #[allow(deprecated)]
services().globals.get_media_file(&key) services().globals.get_media_file(&key)
}; };
let mut f = File::create(path).await?; let mut f = File::create(path).await?;
f.write_all(file).await?; f.write_all(file).await?;
Ok(()) Ok(())
} }
/// Downloads a file. /// Downloads a file.
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> { pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
if let Ok((content_disposition, content_type, key)) = if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) {
self.db.search_file_metadata(mxc, 0, 0) let path = if cfg!(feature = "sha256_media") {
{ services().globals.get_media_file_new(&key)
let path = if cfg!(feature = "sha256_media") { } else {
services().globals.get_media_file_new(&key) #[allow(deprecated)]
} else { services().globals.get_media_file(&key)
#[allow(deprecated)] };
services().globals.get_media_file(&key)
};
let mut file = Vec::new(); let mut file = Vec::new();
BufReader::new(File::open(path).await?) BufReader::new(File::open(path).await?).read_to_end(&mut file).await?;
.read_to_end(&mut file)
.await?;
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file, file,
})) }))
} else { } else {
Ok(None) Ok(None)
} }
} }
/// Deletes all remote-only media files created at or after the given time/duration. Returns a u32 /// Deletes all remote-only media files created at or after the given
/// with the number of media files deleted. /// time/duration. Returns a u32 with the number of media files deleted.
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> { pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
if let Ok(all_keys) = self.db.get_all_media_keys() { if let Ok(all_keys) = self.db.get_all_media_keys() {
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) { let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
Ok(duration) => { Ok(duration) => {
debug!("Parsed duration: {:?}", duration); debug!("Parsed duration: {:?}", duration);
debug!("System time now: {:?}", SystemTime::now()); debug!("System time now: {:?}", SystemTime::now());
SystemTime::now() - duration SystemTime::now() - duration
} },
Err(e) => { Err(e) => {
error!("Failed to parse user-specified time duration: {}", e); error!("Failed to parse user-specified time duration: {}", e);
return Err(Error::bad_database( return Err(Error::bad_database("Failed to parse user-specified time duration."));
"Failed to parse user-specified time duration.", },
)); };
}
};
let mut remote_mxcs: Vec<String> = vec![]; let mut remote_mxcs: Vec<String> = vec![];
for key in all_keys { for key in all_keys {
debug!("Full MXC key from database: {:?}", key); debug!("Full MXC key from database: {:?}", key);
// we need to get the MXC URL from the first part of the key (the first 0xff / 255 push) // we need to get the MXC URL from the first part of the key (the first 0xff /
// this code does look kinda crazy but blame conduit for using magic keys // 255 push) this code does look kinda crazy but blame conduit for using magic
let mut parts = key.split(|&b| b == 0xff); // keys
let mxc = parts let mut parts = key.split(|&b| b == 0xFF);
.next() let mxc = parts
.map(|bytes| { .next()
utils::string_from_bytes(bytes).map_err(|e| { .map(|bytes| {
error!("Failed to parse MXC unicode bytes from our database: {}", e); utils::string_from_bytes(bytes).map_err(|e| {
Error::bad_database( error!("Failed to parse MXC unicode bytes from our database: {}", e);
"Failed to parse MXC unicode bytes from our database", Error::bad_database("Failed to parse MXC unicode bytes from our database")
) })
}) })
}) .transpose()?;
.transpose()?;
let mxc_s = match mxc { let mxc_s = match mxc {
Some(mxc) => mxc, Some(mxc) => mxc,
None => { None => {
return Err(Error::bad_database( return Err(Error::bad_database(
"Parsed MXC URL unicode bytes from database but still is None", "Parsed MXC URL unicode bytes from database but still is None",
)); ));
} },
}; };
debug!("Parsed MXC key to URL: {}", mxc_s); debug!("Parsed MXC key to URL: {}", mxc_s);
let mxc = OwnedMxcUri::from(mxc_s); let mxc = OwnedMxcUri::from(mxc_s);
if mxc.server_name() == Ok(services().globals.server_name()) { if mxc.server_name() == Ok(services().globals.server_name()) {
debug!("Ignoring local media MXC: {}", mxc); debug!("Ignoring local media MXC: {}", mxc);
// ignore our own MXC URLs as this would be local media. // ignore our own MXC URLs as this would be local media.
continue; continue;
} }
let path = if cfg!(feature = "sha256_media") { let path = if cfg!(feature = "sha256_media") {
services().globals.get_media_file_new(&key) services().globals.get_media_file_new(&key)
} else { } else {
#[allow(deprecated)] #[allow(deprecated)]
services().globals.get_media_file(&key) services().globals.get_media_file(&key)
}; };
debug!("MXC path: {:?}", path); debug!("MXC path: {:?}", path);
let file_metadata = fs::metadata(path.clone()).await?; let file_metadata = fs::metadata(path.clone()).await?;
debug!("File metadata: {:?}", file_metadata); debug!("File metadata: {:?}", file_metadata);
let file_created_at = file_metadata.created()?; let file_created_at = file_metadata.created()?;
debug!("File created at: {:?}", file_created_at); debug!("File created at: {:?}", file_created_at);
if file_created_at >= user_duration { if file_created_at >= user_duration {
debug!("File is within user duration, pushing to list of file paths and keys to delete."); debug!("File is within user duration, pushing to list of file paths and keys to delete.");
remote_mxcs.push(mxc.to_string()); remote_mxcs.push(mxc.to_string());
} }
} }
debug!("Finished going through all our media in database for eligible keys to delete, checking if these are empty"); debug!(
"Finished going through all our media in database for eligible keys to delete, checking if these are \
empty"
);
if remote_mxcs.is_empty() { if remote_mxcs.is_empty() {
return Err(Error::bad_database( return Err(Error::bad_database("Did not find any eligible MXCs to delete."));
"Did not find any eligible MXCs to delete.", }
));
}
debug!("Deleting media now in the past \"{:?}\".", user_duration); debug!("Deleting media now in the past \"{:?}\".", user_duration);
let mut deletion_count = 0; let mut deletion_count = 0;
for mxc in remote_mxcs { for mxc in remote_mxcs {
debug!("Deleting MXC {mxc} from database and filesystem"); debug!("Deleting MXC {mxc} from database and filesystem");
self.delete(mxc).await?; self.delete(mxc).await?;
deletion_count += 1; deletion_count += 1;
} }
Ok(deletion_count) Ok(deletion_count)
} else { } else {
Err(Error::bad_database( Err(Error::bad_database(
"Failed to get all our media keys (filesystem or database issue?).", "Failed to get all our media keys (filesystem or database issue?).",
)) ))
} }
} }
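Two details above are easy to miss: the MXC URL is just the prefix of the media key up to the first 0xFF separator, and the cutoff is computed as `SystemTime::now() - duration` before being compared against each file's creation time. A standalone sketch of that logic using only the standard library; the helper names are made up and the key layout mirrors the mocked database in the tests below.

	// Sketch only: pull the MXC prefix out of a 0xFF-separated media key and
	// check whether a file created at `created` falls inside the deletion window.
	use std::time::{Duration, SystemTime};

	fn mxc_prefix(key: &[u8]) -> Option<String> {
		let first = key.split(|&b| b == 0xFF).next()?;
		String::from_utf8(first.to_vec()).ok()
	}

	fn is_eligible_for_deletion(created: SystemTime, window: Duration) -> bool {
		// Same comparison as above: anything created at or after `now - window`
		// is recent enough to be deleted as remote media.
		created >= SystemTime::now() - window
	}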
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when /// Returns width, height of the thumbnail and whether it should be cropped.
/// the server should send the original file. /// Returns None when the server should send the original file.
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
match (width, height) { match (width, height) {
(0..=32, 0..=32) => Some((32, 32, true)), (0..=32, 0..=32) => Some((32, 32, true)),
(0..=96, 0..=96) => Some((96, 96, true)), (0..=96, 0..=96) => Some((96, 96, true)),
(0..=320, 0..=240) => Some((320, 240, false)), (0..=320, 0..=240) => Some((320, 240, false)),
(0..=640, 0..=480) => Some((640, 480, false)), (0..=640, 0..=480) => Some((640, 480, false)),
(0..=800, 0..=600) => Some((800, 600, false)), (0..=800, 0..=600) => Some((800, 600, false)),
_ => None, _ => None,
} }
} }
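The match above widens any requested size to the next fixed bucket, or falls through to None so the original file is served. A standalone sketch of the same mapping with a few spot checks; the function name is made up so it can be exercised without a Service.

	// Sketch only: mirrors the thumbnail_properties buckets for spot-checking.
	fn thumbnail_bucket(width: u32, height: u32) -> Option<(u32, u32, bool)> {
		match (width, height) {
			(0..=32, 0..=32) => Some((32, 32, true)),
			(0..=96, 0..=96) => Some((96, 96, true)),
			(0..=320, 0..=240) => Some((320, 240, false)),
			(0..=640, 0..=480) => Some((640, 480, false)),
			(0..=800, 0..=600) => Some((800, 600, false)),
			_ => None,
		}
	}

	#[test]
	fn thumbnail_bucket_examples() {
		assert_eq!(thumbnail_bucket(567, 567), Some((800, 600, false))); // widened, not cropped
		assert_eq!(thumbnail_bucket(64, 64), Some((96, 96, true))); // small sizes get cropped
		assert_eq!(thumbnail_bucket(1920, 1080), None); // serve the original file instead
	}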
/// Downloads a file's thumbnail. /// Downloads a file's thumbnail.
/// ///
/// Here's an example on how it works: /// Here's an example on how it works:
/// ///
/// - Client requests an image with width=567, height=567 /// - Client requests an image with width=567, height=567
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails /// - Server rounds that up to (800, 600), so it doesn't have to save too
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) /// many thumbnails
/// - Server creates the thumbnail and sends it to the user /// - Server rounds that up again to (958, 600) to fix the aspect ratio
/// /// (only for width,height>96)
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. /// - Server creates the thumbnail and sends it to the user
pub async fn get_thumbnail( ///
&self, /// For width,height <= 96 the server uses another thumbnailing algorithm
mxc: String, /// which crops the image afterwards.
width: u32, pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
height: u32, let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file
) -> Result<Option<FileMeta>> {
let (width, height, crop) = self
.thumbnail_properties(width, height)
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
if let Ok((content_disposition, content_type, key)) = if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) {
self.db.search_file_metadata(mxc.clone(), width, height) // Using saved thumbnail
{ let path = if cfg!(feature = "sha256_media") {
// Using saved thumbnail services().globals.get_media_file_new(&key)
let path = if cfg!(feature = "sha256_media") { } else {
services().globals.get_media_file_new(&key) #[allow(deprecated)]
} else { services().globals.get_media_file(&key)
#[allow(deprecated)] };
services().globals.get_media_file(&key)
};
let mut file = Vec::new(); let mut file = Vec::new();
File::open(path).await?.read_to_end(&mut file).await?; File::open(path).await?.read_to_end(&mut file).await?;
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file: file.clone(), file: file.clone(),
})) }))
} else if let Ok((content_disposition, content_type, key)) = } else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0) {
self.db.search_file_metadata(mxc.clone(), 0, 0) // Generate a thumbnail
{ let path = if cfg!(feature = "sha256_media") {
// Generate a thumbnail services().globals.get_media_file_new(&key)
let path = if cfg!(feature = "sha256_media") { } else {
services().globals.get_media_file_new(&key) #[allow(deprecated)]
} else { services().globals.get_media_file(&key)
#[allow(deprecated)] };
services().globals.get_media_file(&key)
};
let mut file = Vec::new(); let mut file = Vec::new();
File::open(path).await?.read_to_end(&mut file).await?; File::open(path).await?.read_to_end(&mut file).await?;
if let Ok(image) = image::load_from_memory(&file) { if let Ok(image) = image::load_from_memory(&file) {
let original_width = image.width(); let original_width = image.width();
let original_height = image.height(); let original_height = image.height();
if width > original_width || height > original_height { if width > original_width || height > original_height {
return Ok(Some(FileMeta { return Ok(Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file: file.clone(), file: file.clone(),
})); }));
} }
let thumbnail = if crop { let thumbnail = if crop {
image.resize_to_fill(width, height, FilterType::CatmullRom) image.resize_to_fill(width, height, FilterType::CatmullRom)
} else { } else {
let (exact_width, exact_height) = { let (exact_width, exact_height) = {
// Copied from image::dynimage::resize_dimensions // Copied from image::dynimage::resize_dimensions
let ratio = u64::from(original_width) * u64::from(height); let ratio = u64::from(original_width) * u64::from(height);
let nratio = u64::from(width) * u64::from(original_height); let nratio = u64::from(width) * u64::from(original_height);
let use_width = nratio <= ratio; let use_width = nratio <= ratio;
let intermediate = if use_width { let intermediate = if use_width {
u64::from(original_height) * u64::from(width) u64::from(original_height) * u64::from(width) / u64::from(original_width)
/ u64::from(original_width) } else {
} else { u64::from(original_width) * u64::from(height) / u64::from(original_height)
u64::from(original_width) * u64::from(height) };
/ u64::from(original_height) if use_width {
}; if intermediate <= u64::from(::std::u32::MAX) {
if use_width { (width, intermediate as u32)
if intermediate <= u64::from(::std::u32::MAX) { } else {
(width, intermediate as u32) (
} else { (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) as u32,
( ::std::u32::MAX,
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate) )
as u32, }
::std::u32::MAX, } else if intermediate <= u64::from(::std::u32::MAX) {
) (intermediate as u32, height)
} } else {
} else if intermediate <= u64::from(::std::u32::MAX) { (
(intermediate as u32, height) ::std::u32::MAX,
} else { (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) as u32,
( )
::std::u32::MAX, }
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate) };
as u32,
)
}
};
image.thumbnail_exact(exact_width, exact_height) image.thumbnail_exact(exact_width, exact_height)
}; };
let mut thumbnail_bytes = Vec::new(); let mut thumbnail_bytes = Vec::new();
thumbnail.write_to( thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageOutputFormat::Png)?;
&mut Cursor::new(&mut thumbnail_bytes),
image::ImageOutputFormat::Png,
)?;
// Save thumbnail in database so we don't have to generate it again next time // Save thumbnail in database so we don't have to generate it again next time
let thumbnail_key = self.db.create_file_metadata( let thumbnail_key = self.db.create_file_metadata(
mxc, mxc,
width, width,
height, height,
content_disposition.as_deref(), content_disposition.as_deref(),
content_type.as_deref(), content_type.as_deref(),
)?; )?;
let path = if cfg!(feature = "sha256_media") { let path = if cfg!(feature = "sha256_media") {
services().globals.get_media_file_new(&thumbnail_key) services().globals.get_media_file_new(&thumbnail_key)
} else { } else {
#[allow(deprecated)] #[allow(deprecated)]
services().globals.get_media_file(&thumbnail_key) services().globals.get_media_file(&thumbnail_key)
}; };
let mut f = File::create(path).await?; let mut f = File::create(path).await?;
f.write_all(&thumbnail_bytes).await?; f.write_all(&thumbnail_bytes).await?;
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file: thumbnail_bytes.clone(), file: thumbnail_bytes.clone(),
})) }))
} else { } else {
// Couldn't parse file to generate thumbnail, send original // Couldn't parse file to generate thumbnail, send original
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file: file.clone(), file: file.clone(),
})) }))
} }
} else { } else {
Ok(None) Ok(None)
} }
} }
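The exact_width/exact_height block above fits the original image inside the requested bucket while keeping its aspect ratio (the u32::MAX clamp only matters for pathological sizes). A worked sketch of that arithmetic with the overflow clamp omitted; the function name and the 1920x1080 example are made up.

	// Sketch only: aspect-ratio fit used when crop == false. A 1920x1080 original
	// requested at the 800x600 bucket comes out as 800x450.
	fn fit_dimensions((ow, oh): (u32, u32), (w, h): (u32, u32)) -> (u32, u32) {
		let ratio = u64::from(ow) * u64::from(h);
		let nratio = u64::from(w) * u64::from(oh);
		if nratio <= ratio {
			// Width is the limiting side: scale the height proportionally.
			(w, (u64::from(oh) * u64::from(w) / u64::from(ow)) as u32)
		} else {
			// Height is the limiting side: scale the width proportionally.
			((u64::from(ow) * u64::from(h) / u64::from(oh)) as u32, h)
		}
	}

	#[test]
	fn fit_dimensions_example() {
		assert_eq!(fit_dimensions((1920, 1080), (800, 600)), (800, 450));
	}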
pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }
self.db.get_url_preview(url)
}
pub async fn remove_url_preview(&self, url: &str) -> Result<()> { pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
// TODO: also remove the downloaded image // TODO: also remove the downloaded image
self.db.remove_url_preview(url) self.db.remove_url_preview(url)
} }
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
let now = SystemTime::now() let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time");
.duration_since(SystemTime::UNIX_EPOCH) self.db.set_url_preview(url, data, now)
.expect("valid system time"); }
self.db.set_url_preview(url, data, now)
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::path::PathBuf; use std::path::PathBuf;
use sha2::Digest; use base64::{engine::general_purpose, Engine as _};
use sha2::Digest;
use base64::{engine::general_purpose, Engine as _}; use super::*;
use super::*; struct MockedKVDatabase;
struct MockedKVDatabase; impl Data for MockedKVDatabase {
fn create_file_metadata(
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
) -> Result<Vec<u8>> {
// copied from src/database/key_value/media.rs
let mut key = mxc.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
impl Data for MockedKVDatabase { Ok(key)
fn create_file_metadata( }
&self,
mxc: String,
width: u32,
height: u32,
content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
// copied from src/database/key_value/media.rs
let mut key = mxc.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xff);
key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xff);
key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
Ok(key) fn delete_file_mxc(&self, _mxc: String) -> Result<()> { todo!() }
}
fn delete_file_mxc(&self, _mxc: String) -> Result<()> { fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> { todo!() }
todo!()
}
fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> { fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> { todo!() }
todo!()
}
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> { fn search_file_metadata(
todo!() &self, _mxc: String, _width: u32, _height: u32,
} ) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
todo!()
}
fn search_file_metadata( fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() }
&self,
_mxc: String,
_width: u32,
_height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
todo!()
}
fn remove_url_preview(&self, _url: &str) -> Result<()> { fn set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> {
todo!() todo!()
} }
fn set_url_preview( fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> { todo!() }
&self, }
_url: &str,
_data: &UrlPreviewData,
_timestamp: std::time::Duration,
) -> Result<()> {
todo!()
}
fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> { #[tokio::test]
todo!() async fn long_file_names_works() {
} static DB: MockedKVDatabase = MockedKVDatabase;
} let media = Service {
db: &DB,
url_preview_mutex: RwLock::new(HashMap::new()),
};
#[tokio::test] let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
async fn long_file_names_works() { let width = 100;
static DB: MockedKVDatabase = MockedKVDatabase; let height = 100;
let media = Service { let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \
db: &DB, characters like äöüß and even emoji like 🦀.png\"";
url_preview_mutex: RwLock::new(HashMap::new()), let content_type = "image/png";
}; let key =
media.db.create_file_metadata(mxc, width, height, Some(content_disposition), Some(content_type)).unwrap();
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned(); let mut r = PathBuf::new();
let width = 100; r.push("/tmp");
let height = 100; r.push("media");
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special characters like äöüß and even emoji like 🦀.png\""; // r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
let content_type = "image/png"; // use the sha256 hash of the key as the file name instead of the key itself
let key = media // this is because the base64 encoded key can be longer than 255 characters.
.db r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
.create_file_metadata( // Check that the file path is not longer than 255 characters
mxc, // (255 is the maximum length of a file path on most file systems)
width, assert!(
height, r.to_str().unwrap().len() <= 255,
Some(content_disposition), "File path is too long: {}",
Some(content_type), r.to_str().unwrap().len()
) );
.unwrap(); }
let mut r = PathBuf::new();
r.push("/tmp");
r.push("media");
// r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
// use the sha256 hash of the key as the file name instead of the key itself
// this is because the base64 encoded key can be longer than 255 characters.
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
// Check that the file path is not longer than 255 characters
// (255 is the maximum length of a file path on most file systems)
assert!(
r.to_str().unwrap().len() <= 255,
"File path is too long: {}",
r.to_str().unwrap().len()
);
}
} }
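The test above works because hashing collapses the key to a fixed length: a SHA-256 digest is 32 bytes, so its URL-safe unpadded base64 form is always 43 characters, no matter how long the MXC key is. A small sketch using the same sha2 and base64 calls as the test; the test name is made up.

	// Sketch only: the hashed file name always has 43 characters, which is what
	// keeps very long media keys under typical filesystem name limits.
	#[test]
	fn hashed_key_length_is_constant() {
		use base64::{engine::general_purpose, Engine as _};
		use sha2::Digest;

		let short_key = b"mxc://example.com/a".to_vec();
		let long_key = vec![0xFF_u8; 4096];
		for key in [short_key, long_key] {
			let name = general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(&key));
			assert_eq!(name.len(), 43);
		}
	}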
@ -1,6 +1,6 @@
use std::{ use std::{
collections::{BTreeMap, HashMap}, collections::{BTreeMap, HashMap},
sync::{Arc, Mutex, RwLock}, sync::{Arc, Mutex, RwLock},
}; };
use lru_cache::LruCache; use lru_cache::LruCache;
@ -22,210 +22,186 @@ pub(crate) mod uiaa;
pub(crate) mod users; pub(crate) mod users;
pub struct Services<'a> { pub struct Services<'a> {
pub appservice: appservice::Service, pub appservice: appservice::Service,
pub pusher: pusher::Service, pub pusher: pusher::Service,
pub rooms: rooms::Service, pub rooms: rooms::Service,
pub transaction_ids: transaction_ids::Service, pub transaction_ids: transaction_ids::Service,
pub uiaa: uiaa::Service, pub uiaa: uiaa::Service,
pub users: users::Service, pub users: users::Service,
pub account_data: account_data::Service, pub account_data: account_data::Service,
pub admin: Arc<admin::Service>, pub admin: Arc<admin::Service>,
pub globals: globals::Service<'a>, pub globals: globals::Service<'a>,
pub key_backups: key_backups::Service, pub key_backups: key_backups::Service,
pub media: media::Service, pub media: media::Service,
pub sending: Arc<sending::Service>, pub sending: Arc<sending::Service>,
} }
impl Services<'_> { impl Services<'_> {
pub fn build< pub fn build<
D: appservice::Data D: appservice::Data
+ pusher::Data + pusher::Data
+ rooms::Data + rooms::Data
+ transaction_ids::Data + transaction_ids::Data
+ uiaa::Data + uiaa::Data
+ users::Data + users::Data
+ account_data::Data + account_data::Data
+ globals::Data + globals::Data
+ key_backups::Data + key_backups::Data
+ media::Data + media::Data
+ sending::Data + sending::Data
+ 'static, + 'static,
>( >(
db: &'static D, db: &'static D, config: Config,
config: Config, ) -> Result<Self> {
) -> Result<Self> { Ok(Self {
Ok(Self { appservice: appservice::Service {
appservice: appservice::Service { db }, db,
pusher: pusher::Service { db }, },
rooms: rooms::Service { pusher: pusher::Service {
alias: rooms::alias::Service { db }, db,
auth_chain: rooms::auth_chain::Service { db }, },
directory: rooms::directory::Service { db }, rooms: rooms::Service {
edus: rooms::edus::Service { alias: rooms::alias::Service {
presence: rooms::edus::presence::Service { db }, db,
read_receipt: rooms::edus::read_receipt::Service { db }, },
typing: rooms::edus::typing::Service { db }, auth_chain: rooms::auth_chain::Service {
}, db,
event_handler: rooms::event_handler::Service, },
lazy_loading: rooms::lazy_loading::Service { directory: rooms::directory::Service {
db, db,
lazy_load_waiting: Mutex::new(HashMap::new()), },
}, edus: rooms::edus::Service {
metadata: rooms::metadata::Service { db }, presence: rooms::edus::presence::Service {
outlier: rooms::outlier::Service { db }, db,
pdu_metadata: rooms::pdu_metadata::Service { db }, },
search: rooms::search::Service { db }, read_receipt: rooms::edus::read_receipt::Service {
short: rooms::short::Service { db }, db,
state: rooms::state::Service { db }, },
state_accessor: rooms::state_accessor::Service { typing: rooms::edus::typing::Service {
db, db,
server_visibility_cache: Mutex::new(LruCache::new( },
(100.0 * config.conduit_cache_capacity_modifier) as usize, },
)), event_handler: rooms::event_handler::Service,
user_visibility_cache: Mutex::new(LruCache::new( lazy_loading: rooms::lazy_loading::Service {
(100.0 * config.conduit_cache_capacity_modifier) as usize, db,
)), lazy_load_waiting: Mutex::new(HashMap::new()),
}, },
state_cache: rooms::state_cache::Service { db }, metadata: rooms::metadata::Service {
state_compressor: rooms::state_compressor::Service { db,
db, },
stateinfo_cache: Mutex::new(LruCache::new( outlier: rooms::outlier::Service {
(100.0 * config.conduit_cache_capacity_modifier) as usize, db,
)), },
}, pdu_metadata: rooms::pdu_metadata::Service {
timeline: rooms::timeline::Service { db,
db, },
lasttimelinecount_cache: Mutex::new(HashMap::new()), search: rooms::search::Service {
}, db,
threads: rooms::threads::Service { db }, },
spaces: rooms::spaces::Service { short: rooms::short::Service {
roomid_spacechunk_cache: Mutex::new(LruCache::new( db,
(100.0 * config.conduit_cache_capacity_modifier) as usize, },
)), state: rooms::state::Service {
}, db,
user: rooms::user::Service { db }, },
}, state_accessor: rooms::state_accessor::Service {
transaction_ids: transaction_ids::Service { db }, db,
uiaa: uiaa::Service { db }, server_visibility_cache: Mutex::new(LruCache::new(
users: users::Service { (100.0 * config.conduit_cache_capacity_modifier) as usize,
db, )),
connections: Mutex::new(BTreeMap::new()), user_visibility_cache: Mutex::new(LruCache::new(
}, (100.0 * config.conduit_cache_capacity_modifier) as usize,
account_data: account_data::Service { db }, )),
admin: admin::Service::build(), },
key_backups: key_backups::Service { db }, state_cache: rooms::state_cache::Service {
media: media::Service { db,
db, },
url_preview_mutex: RwLock::new(HashMap::new()), state_compressor: rooms::state_compressor::Service {
}, db,
sending: sending::Service::build(db, &config), stateinfo_cache: Mutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
},
timeline: rooms::timeline::Service {
db,
lasttimelinecount_cache: Mutex::new(HashMap::new()),
},
threads: rooms::threads::Service {
db,
},
spaces: rooms::spaces::Service {
roomid_spacechunk_cache: Mutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
},
user: rooms::user::Service {
db,
},
},
transaction_ids: transaction_ids::Service {
db,
},
uiaa: uiaa::Service {
db,
},
users: users::Service {
db,
connections: Mutex::new(BTreeMap::new()),
},
account_data: account_data::Service {
db,
},
admin: admin::Service::build(),
key_backups: key_backups::Service {
db,
},
media: media::Service {
db,
url_preview_mutex: RwLock::new(HashMap::new()),
},
sending: sending::Service::build(db, &config),
globals: globals::Service::load(db, config)?, globals: globals::Service::load(db, config)?,
}) })
} }
fn memory_usage(&self) -> String {
let lazy_load_waiting = self
.rooms
.lazy_loading
.lazy_load_waiting
.lock()
.unwrap()
.len();
let server_visibility_cache = self
.rooms
.state_accessor
.server_visibility_cache
.lock()
.unwrap()
.len();
let user_visibility_cache = self
.rooms
.state_accessor
.user_visibility_cache
.lock()
.unwrap()
.len();
let stateinfo_cache = self
.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.len();
let lasttimelinecount_cache = self
.rooms
.timeline
.lasttimelinecount_cache
.lock()
.unwrap()
.len();
let roomid_spacechunk_cache = self
.rooms
.spaces
.roomid_spacechunk_cache
.lock()
.unwrap()
.len();
format!( fn memory_usage(&self) -> String {
"\ let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len();
let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len();
let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len();
let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len();
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len();
format!(
"\
lazy_load_waiting: {lazy_load_waiting} lazy_load_waiting: {lazy_load_waiting}
server_visibility_cache: {server_visibility_cache} server_visibility_cache: {server_visibility_cache}
user_visibility_cache: {user_visibility_cache} user_visibility_cache: {user_visibility_cache}
stateinfo_cache: {stateinfo_cache} stateinfo_cache: {stateinfo_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} lasttimelinecount_cache: {lasttimelinecount_cache}
roomid_spacechunk_cache: {roomid_spacechunk_cache}\ roomid_spacechunk_cache: {roomid_spacechunk_cache}"
" )
) }
}
fn clear_caches(&self, amount: u32) { fn clear_caches(&self, amount: u32) {
if amount > 0 { if amount > 0 {
self.rooms self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear();
.lazy_loading }
.lazy_load_waiting if amount > 1 {
.lock() self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear();
.unwrap() }
.clear(); if amount > 2 {
} self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear();
if amount > 1 { }
self.rooms if amount > 3 {
.state_accessor self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
.server_visibility_cache }
.lock() if amount > 4 {
.unwrap() self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear();
.clear(); }
} if amount > 5 {
if amount > 2 { self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear();
self.rooms }
.state_accessor }
.user_visibility_cache
.lock()
.unwrap()
.clear();
}
if amount > 3 {
self.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.clear();
}
if amount > 4 {
self.rooms
.timeline
.lasttimelinecount_cache
.lock()
.unwrap()
.clear();
}
if amount > 5 {
self.rooms
.spaces
.roomid_spacechunk_cache
.lock()
.unwrap()
.clear();
}
}
} }
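Every LruCache above is sized from the same `(100.0 * config.conduit_cache_capacity_modifier) as usize` expression, so the config value scales all of these caches linearly. A trivial sketch of that sizing; the helper name is made up.

	// Sketch only: how conduit_cache_capacity_modifier maps to cache capacities.
	fn cache_capacity(conduit_cache_capacity_modifier: f64) -> usize {
		(100.0 * conduit_cache_capacity_modifier) as usize
	}

	#[test]
	fn cache_capacity_examples() {
		assert_eq!(cache_capacity(1.0), 100); // 1.0 -> 100 entries per cache
		assert_eq!(cache_capacity(2.5), 250); // larger modifiers scale linearly
	}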
@ -1,410 +1,372 @@
use crate::Error; use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use ruma::{ use ruma::{
canonical_json::redact_content_in_place, canonical_json::redact_content_in_place,
events::{ events::{
room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent,
AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent,
AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, AnyTimelineEvent, StateEvent, TimelineEventType,
}, },
serde::Raw, serde::Raw,
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{ use serde_json::{
json, json,
value::{to_raw_value, RawValue as RawJsonValue}, value::{to_raw_value, RawValue as RawJsonValue},
}; };
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use tracing::warn; use tracing::warn;
use crate::Error;
/// Content hashes of a PDU. /// Content hashes of a PDU.
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EventHash { pub struct EventHash {
/// The SHA-256 hash. /// The SHA-256 hash.
pub sha256: String, pub sha256: String,
} }
#[derive(Clone, Deserialize, Serialize, Debug)] #[derive(Clone, Deserialize, Serialize, Debug)]
pub struct PduEvent { pub struct PduEvent {
pub event_id: Arc<EventId>, pub event_id: Arc<EventId>,
pub room_id: OwnedRoomId, pub room_id: OwnedRoomId,
pub sender: OwnedUserId, pub sender: OwnedUserId,
pub origin_server_ts: UInt, pub origin_server_ts: UInt,
#[serde(rename = "type")] #[serde(rename = "type")]
pub kind: TimelineEventType, pub kind: TimelineEventType,
pub content: Box<RawJsonValue>, pub content: Box<RawJsonValue>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub state_key: Option<String>, pub state_key: Option<String>,
pub prev_events: Vec<Arc<EventId>>, pub prev_events: Vec<Arc<EventId>>,
pub depth: UInt, pub depth: UInt,
pub auth_events: Vec<Arc<EventId>>, pub auth_events: Vec<Arc<EventId>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub redacts: Option<Arc<EventId>>, pub redacts: Option<Arc<EventId>>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub unsigned: Option<Box<RawJsonValue>>, pub unsigned: Option<Box<RawJsonValue>>,
pub hashes: EventHash, pub hashes: EventHash,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub signatures: Option<Box<RawJsonValue>>, // BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, String>> pub signatures: Option<Box<RawJsonValue>>, // BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, String>>
} }
impl PduEvent { impl PduEvent {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn redact( pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> {
&mut self, self.unsigned = None;
room_version_id: RoomVersionId,
reason: &PduEvent,
) -> crate::Result<()> {
self.unsigned = None;
let mut content = serde_json::from_str(self.content.get()) let mut content = serde_json::from_str(self.content.get())
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?; .map_err(|_| Error::bad_database("PDU in db has invalid content."))?;
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
.map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?; .map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?;
self.unsigned = Some(to_raw_value(&json!({ self.unsigned = Some(
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") to_raw_value(&json!({
})).expect("to string always works")); "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
}))
.expect("to string always works"),
);
self.content = to_raw_value(&content).expect("to string always works"); self.content = to_raw_value(&content).expect("to string always works");
Ok(()) Ok(())
} }
pub fn remove_transaction_id(&mut self) -> crate::Result<()> { pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get())
serde_json::from_str(unsigned.get()) .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; unsigned.remove("transaction_id");
unsigned.remove("transaction_id"); self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); }
}
Ok(()) Ok(())
} }
pub fn add_age(&mut self) -> crate::Result<()> { pub fn add_age(&mut self) -> crate::Result<()> {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
.unsigned .unsigned
.as_ref() .as_ref()
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap()); unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap());
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> { pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
if let Some(redacts) = &self.redacts { if let Some(redacts) = &self.redacts {
json["redacts"] = json!(redacts); json["redacts"] = json!(redacts);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
/// This only works for events that are also AnyRoomEvents. /// This only works for events that are also AnyRoomEvents.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> { pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"room_id": self.room_id, "room_id": self.room_id,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
if let Some(redacts) = &self.redacts { if let Some(redacts) = &self.redacts {
json["redacts"] = json!(redacts); json["redacts"] = json!(redacts);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> { pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"room_id": self.room_id, "room_id": self.room_id,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
if let Some(redacts) = &self.redacts { if let Some(redacts) = &self.redacts {
json["redacts"] = json!(redacts); json["redacts"] = json!(redacts);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> { pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"room_id": self.room_id, "room_id": self.room_id,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
if let Some(state_key) = &self.state_key { if let Some(state_key) = &self.state_key {
json["state_key"] = json!(state_key); json["state_key"] = json!(state_key);
} }
if let Some(redacts) = &self.redacts { if let Some(redacts) = &self.redacts {
json["redacts"] = json!(redacts); json["redacts"] = json!(redacts);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_state_event(&self) -> Raw<AnyStateEvent> { pub fn to_state_event(&self) -> Raw<AnyStateEvent> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"room_id": self.room_id, "room_id": self.room_id,
"state_key": self.state_key, "state_key": self.state_key,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> { pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"state_key": self.state_key, "state_key": self.state_key,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> { pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> {
let json = json!({ let json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"sender": self.sender, "sender": self.sender,
"state_key": self.state_key, "state_key": self.state_key,
}); });
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> { pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> {
let json = json!({ let json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"sender": self.sender, "sender": self.sender,
"state_key": self.state_key, "state_key": self.state_key,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
}); });
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> { pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
let mut json = json!({ let mut json = json!({
"content": self.content, "content": self.content,
"type": self.kind, "type": self.kind,
"event_id": self.event_id, "event_id": self.event_id,
"sender": self.sender, "sender": self.sender,
"origin_server_ts": self.origin_server_ts, "origin_server_ts": self.origin_server_ts,
"redacts": self.redacts, "redacts": self.redacts,
"room_id": self.room_id, "room_id": self.room_id,
"state_key": self.state_key, "state_key": self.state_key,
}); });
if let Some(unsigned) = &self.unsigned { if let Some(unsigned) = &self.unsigned {
json["unsigned"] = json!(unsigned); json["unsigned"] = json!(unsigned);
} }
serde_json::from_value(json).expect("Raw::from_value always works") serde_json::from_value(json).expect("Raw::from_value always works")
} }
/// This does not return a full `Pdu`; it is only to satisfy ruma's types. /// This does not return a full `Pdu`; it is only to satisfy ruma's types.
#[tracing::instrument] #[tracing::instrument]
pub fn convert_to_outgoing_federation_event( pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
mut pdu_json: CanonicalJsonObject, if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) {
) -> Box<RawJsonValue> { unsigned.remove("transaction_id");
if let Some(unsigned) = pdu_json }
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
{
unsigned.remove("transaction_id");
}
pdu_json.remove("event_id"); pdu_json.remove("event_id");
// TODO: another option would be to convert it to a canonical string to validate size // TODO: another option would be to convert it to a canonical string to validate
// and return a Result<Raw<...>> // size and return a Result<Raw<...>>
// serde_json::from_str::<Raw<_>>( // serde_json::from_str::<Raw<_>>(
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is valid serde_json::Value"), // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
// ) // valid serde_json::Value"), )
// .expect("Raw::from_value always works") // .expect("Raw::from_value always works")
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
} }
pub fn from_id_val( pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> {
event_id: &EventId, json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
mut json: CanonicalJsonObject,
) -> Result<Self, serde_json::Error> {
json.insert(
"event_id".to_owned(),
CanonicalJsonValue::String(event_id.as_str().to_owned()),
);
serde_json::from_value(serde_json::to_value(json).expect("valid JSON")) serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
} }
} }
impl state_res::Event for PduEvent { impl state_res::Event for PduEvent {
type Id = Arc<EventId>; type Id = Arc<EventId>;
fn event_id(&self) -> &Self::Id { fn event_id(&self) -> &Self::Id { &self.event_id }
&self.event_id
}
fn room_id(&self) -> &RoomId { fn room_id(&self) -> &RoomId { &self.room_id }
&self.room_id
}
fn sender(&self) -> &UserId { fn sender(&self) -> &UserId { &self.sender }
&self.sender
}
fn event_type(&self) -> &TimelineEventType { fn event_type(&self) -> &TimelineEventType { &self.kind }
&self.kind
}
fn content(&self) -> &RawJsonValue { fn content(&self) -> &RawJsonValue { &self.content }
&self.content
}
fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) }
MilliSecondsSinceUnixEpoch(self.origin_server_ts)
}
fn state_key(&self) -> Option<&str> { fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }
self.state_key.as_deref()
}
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) }
Box::new(self.prev_events.iter())
}
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) }
Box::new(self.auth_events.iter())
}
fn redacts(&self) -> Option<&Self::Id> { fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
self.redacts.as_ref()
}
} }
// These impls allow us to dedup state snapshots when resolving state // These impls allow us to dedup state snapshots when resolving state
// for incoming events (federation/send/{txn}). // for incoming events (federation/send/{txn}).
impl Eq for PduEvent {} impl Eq for PduEvent {}
impl PartialEq for PduEvent { impl PartialEq for PduEvent {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id }
self.event_id == other.event_id
}
} }
impl PartialOrd for PduEvent { impl PartialOrd for PduEvent {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
Some(self.cmp(other))
}
} }
impl Ord for PduEvent { impl Ord for PduEvent {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) }
self.event_id.cmp(&other.event_id)
}
} }
/// Generates a correct eventId for the incoming pdu. /// Generates a correct eventId for the incoming pdu.
/// ///
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
/// CanonicalJsonValue>`.
pub(crate) fn gen_event_id_canonical_json( pub(crate) fn gen_event_id_canonical_json(
pdu: &RawJsonValue, pdu: &RawJsonValue, room_version_id: &RoomVersionId,
room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { ) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e); warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response") Error::BadServerResponse("Invalid PDU in server response")
})?; })?;
let event_id = format!( let event_id = format!(
"${}", "${}",
// Anything higher than version 3 behaves the same // Anything higher than version 3 behaves the same
ruma::signatures::reference_hash(&value, room_version_id) ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes")
.expect("ruma can calculate reference hashes") )
) .try_into()
.try_into() .expect("ruma's reference hashes are valid event ids");
.expect("ruma's reference hashes are valid event ids");
Ok((event_id, value)) Ok((event_id, value))
} }
/// Build the start of a PDU in order to add it to the Database. /// Build the start of a PDU in order to add it to the Database.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct PduBuilder { pub struct PduBuilder {
#[serde(rename = "type")] #[serde(rename = "type")]
pub event_type: TimelineEventType, pub event_type: TimelineEventType,
pub content: Box<RawJsonValue>, pub content: Box<RawJsonValue>,
pub unsigned: Option<BTreeMap<String, serde_json::Value>>, pub unsigned: Option<BTreeMap<String, serde_json::Value>>,
pub state_key: Option<String>, pub state_key: Option<String>,
pub redacts: Option<Arc<EventId>>, pub redacts: Option<Arc<EventId>>,
} }
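PduBuilder is only the caller-supplied half of an event; sender, prev_events, depth, hashes and the rest are filled in when the timeline service builds the full PDU. A hedged sketch of constructing one for a plain text message; it assumes ruma's TimelineEventType::RoomMessage variant and reuses the to_raw_value and serde_json helpers already imported in this file.

	// Sketch only: the input half of a PDU for an m.room.message event.
	fn example_message_builder() -> PduBuilder {
		PduBuilder {
			event_type: TimelineEventType::RoomMessage,
			content: to_raw_value(&serde_json::json!({
				"msgtype": "m.text",
				"body": "hello from a sketch",
			}))
			.expect("static JSON is always valid"),
			unsigned: None,
			state_key: None,
			redacts: None,
		}
	}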
@ -1,16 +1,16 @@
use crate::Result;
use ruma::{ use ruma::{
api::client::push::{set_pusher, Pusher}, api::client::push::{set_pusher, Pusher},
UserId, UserId,
}; };
use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>; fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>; fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
fn get_pushkeys<'a>(&'a self, sender: &UserId) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
-> Box<dyn Iterator<Item = Result<String>> + 'a>;
} }
@ -1,292 +1,236 @@
mod data; mod data;
pub use data::Data;
use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx};
use crate::{services, Error, PduEvent, Result};
use bytes::BytesMut;
use ruma::{
api::{
client::push::{set_pusher, Pusher, PusherKind},
push_gateway::send_event_notification::{
self,
v1::{Device, Notification, NotificationCounts, NotificationPriority},
},
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
},
events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType},
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
serde::Raw,
uint, RoomId, UInt, UserId,
};
use std::{fmt::Debug, mem}; use std::{fmt::Debug, mem};
use bytes::BytesMut;
pub use data::Data;
use ruma::{
api::{
client::push::{set_pusher, Pusher, PusherKind},
push_gateway::send_event_notification::{
self,
v1::{Device, Notification, NotificationCounts, NotificationPriority},
},
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
},
events::{
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType,
},
push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
serde::Raw,
uint, RoomId, UInt, UserId,
};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{services, Error, PduEvent, Result};
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
self.db.set_pusher(sender, pusher) self.db.set_pusher(sender, pusher)
} }
pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
self.db.get_pusher(sender, pushkey) self.db.get_pusher(sender, pushkey)
} }
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }
self.db.get_pushers(sender)
}
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> { pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
self.db.get_pushkeys(sender) self.db.get_pushkeys(sender)
} }
#[tracing::instrument(skip(self, destination, request))] #[tracing::instrument(skip(self, destination, request))]
pub async fn send_request<T>( pub async fn send_request<T>(&self, destination: &str, request: T) -> Result<T::IncomingResponse>
&self, where
destination: &str, T: OutgoingRequest + Debug,
request: T, {
) -> Result<T::IncomingResponse> let destination = destination.replace(services().globals.notification_push_path(), "");
where
T: OutgoingRequest + Debug,
{
let destination = destination.replace(services().globals.notification_push_path(), "");
let http_request = request let http_request = request
.try_into_http_request::<BytesMut>( .try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_0])
&destination, .map_err(|e| {
SendAccessToken::IfRequired(""), warn!("Failed to find destination {}: {}", destination, e);
&[MatrixVersion::V1_0], Error::BadServerResponse("Invalid destination")
) })?
.map_err(|e| { .map(bytes::BytesMut::freeze);
warn!("Failed to find destination {}: {}", destination, e);
Error::BadServerResponse("Invalid destination")
})?
.map(bytes::BytesMut::freeze);
let reqwest_request = reqwest::Request::try_from(http_request)?; let reqwest_request = reqwest::Request::try_from(http_request)?;
// TODO: we could keep this very short and let exponential backoff do its thing... // TODO: we could keep this very short and let exponential backoff do its thing...
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
let url = reqwest_request.url().clone(); let url = reqwest_request.url().clone();
let response = services() let response = services().globals.default_client().execute(reqwest_request).await;
.globals
.default_client()
.execute(reqwest_request)
.await;
match response { match response {
Ok(mut response) => { Ok(mut response) => {
// reqwest::Response -> http::Response conversion // reqwest::Response -> http::Response conversion
let status = response.status(); let status = response.status();
let mut http_response_builder = http::Response::builder() let mut http_response_builder = http::Response::builder().status(status).version(response.version());
.status(status) mem::swap(
.version(response.version()); response.headers_mut(),
mem::swap( http_response_builder.headers_mut().expect("http::response::Builder is usable"),
response.headers_mut(), );
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
let body = response.bytes().await.unwrap_or_else(|e| { let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error {}", e); warn!("server error {}", e);
Vec::new().into() Vec::new().into()
}); // TODO: handle timeout }); // TODO: handle timeout
if !status.is_success() { if !status.is_success() {
info!( info!(
"Push gateway returned bad response {} {}\n{}\n{:?}", "Push gateway returned bad response {} {}\n{}\n{:?}",
destination, destination,
status, status,
url, url,
crate::utils::string_from_bytes(&body) crate::utils::string_from_bytes(&body)
); );
} }
let response = T::IncomingResponse::try_from_http_response( let response = T::IncomingResponse::try_from_http_response(
http_response_builder http_response_builder.body(body).expect("reqwest body is valid http body"),
.body(body) );
.expect("reqwest body is valid http body"), response.map_err(|_| {
); info!("Push gateway returned invalid response bytes {}\n{}", destination, url);
response.map_err(|_| { Error::BadServerResponse("Push gateway returned bad response.")
info!( })
"Push gateway returned invalid response bytes {}\n{}", },
destination, url Err(e) => {
); warn!("Could not send request to pusher {}: {}", destination, e);
Error::BadServerResponse("Push gateway returned bad response.") Err(e.into())
}) },
} }
Err(e) => { }
warn!("Could not send request to pusher {}: {}", destination, e);
Err(e.into())
}
}
}
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
pub async fn send_push_notice( pub async fn send_push_notice(
&self, &self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
user: &UserId, ) -> Result<()> {
unread: UInt, let mut notify = None;
pusher: &Pusher, let mut tweaks = Vec::new();
ruleset: Ruleset,
pdu: &PduEvent,
) -> Result<()> {
let mut notify = None;
let mut tweaks = Vec::new();
let power_levels: RoomPowerLevelsEventContent = services() let power_levels: RoomPowerLevelsEventContent = services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| { .map(|ev| {
serde_json::from_str(ev.content.get()) serde_json::from_str(ev.content.get())
.map_err(|_| Error::bad_database("invalid m.room.power_levels event")) .map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
}) })
.transpose()? .transpose()?
.unwrap_or_default(); .unwrap_or_default();
for action in self.get_actions( for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? {
user, let n = match action {
&ruleset, Action::Notify => true,
&power_levels, Action::SetTweak(tweak) => {
&pdu.to_sync_room_event(), tweaks.push(tweak.clone());
&pdu.room_id, continue;
)? { },
let n = match action { _ => false,
Action::Notify => true, };
Action::SetTweak(tweak) => {
tweaks.push(tweak.clone());
continue;
}
_ => false,
};
if notify.is_some() { if notify.is_some() {
return Err(Error::bad_database( return Err(Error::bad_database(
r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#,
)); ));
} }
notify = Some(n); notify = Some(n);
} }
if notify == Some(true) { if notify == Some(true) {
self.send_notice(unread, pusher, tweaks, pdu).await?; self.send_notice(unread, pusher, tweaks, pdu).await?;
} }
// Else the event triggered no actions // Else the event triggered no actions
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, user, ruleset, pdu))] #[tracing::instrument(skip(self, user, ruleset, pdu))]
pub fn get_actions<'a>( pub fn get_actions<'a>(
&self, &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
user: &UserId, pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
ruleset: &'a Ruleset, ) -> Result<&'a [Action]> {
power_levels: &RoomPowerLevelsEventContent, let power_levels = PushConditionPowerLevelsCtx {
pdu: &Raw<AnySyncTimelineEvent>, users: power_levels.users.clone(),
room_id: &RoomId, users_default: power_levels.users_default,
) -> Result<&'a [Action]> { notifications: power_levels.notifications.clone(),
let power_levels = PushConditionPowerLevelsCtx { };
users: power_levels.users.clone(),
users_default: power_levels.users_default,
notifications: power_levels.notifications.clone(),
};
let ctx = PushConditionRoomCtx { let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
member_count: UInt::from( member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32),
services() user_id: user.to_owned(),
.rooms user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()),
.state_cache power_levels: Some(power_levels),
.room_joined_count(room_id)? };
.unwrap_or(1) as u32,
),
user_id: user.to_owned(),
user_display_name: services()
.users
.displayname(user)?
.unwrap_or_else(|| user.localpart().to_owned()),
power_levels: Some(power_levels),
};
Ok(ruleset.get_actions(pdu, &ctx)) Ok(ruleset.get_actions(pdu, &ctx))
} }
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))] #[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
async fn send_notice( async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, event: &PduEvent) -> Result<()> {
&self, // TODO: email
unread: UInt, match &pusher.kind {
pusher: &Pusher, PusherKind::Http(http) => {
tweaks: Vec<Tweak>, // TODO:
event: &PduEvent, // Two problems with this
) -> Result<()> { // 1. if "event_id_only" is the only format kind it seems we should never add
// TODO: email // more info
match &pusher.kind { // 2. can pusher/devices have conflicting formats
PusherKind::Http(http) => { let event_id_only = http.format == Some(PushFormat::EventIdOnly);
// TODO:
// Two problems with this
// 1. if "event_id_only" is the only format kind it seems we should never add more info
// 2. can pusher/devices have conflicting formats
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
device.data.default_payload = http.default_payload.clone(); device.data.default_payload = http.default_payload.clone();
device.data.format = http.format.clone(); device.data.format = http.format.clone();
// Tweaks are only added if the format is NOT event_id_only // Tweaks are only added if the format is NOT event_id_only
if !event_id_only { if !event_id_only {
device.tweaks = tweaks.clone(); device.tweaks = tweaks.clone();
} }
let d = vec![device]; let d = vec![device];
let mut notifi = Notification::new(d); let mut notifi = Notification::new(d);
notifi.prio = NotificationPriority::Low; notifi.prio = NotificationPriority::Low;
notifi.event_id = Some((*event.event_id).to_owned()); notifi.event_id = Some((*event.event_id).to_owned());
notifi.room_id = Some((*event.room_id).to_owned()); notifi.room_id = Some((*event.room_id).to_owned());
// TODO: missed calls // TODO: missed calls
notifi.counts = NotificationCounts::new(unread, uint!(0)); notifi.counts = NotificationCounts::new(unread, uint!(0));
if event.kind == TimelineEventType::RoomEncrypted if event.kind == TimelineEventType::RoomEncrypted
|| tweaks || tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
.iter() {
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) notifi.prio = NotificationPriority::High;
{ }
notifi.prio = NotificationPriority::High;
}
if event_id_only { if event_id_only {
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
.await?; } else {
} else { notifi.sender = Some(event.sender.clone());
notifi.sender = Some(event.sender.clone()); notifi.event_type = Some(event.kind.clone());
notifi.event_type = Some(event.kind.clone()); notifi.content = serde_json::value::to_raw_value(&event.content).ok();
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
if event.kind == TimelineEventType::RoomMember { if event.kind == TimelineEventType::RoomMember {
notifi.user_is_target = notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
event.state_key.as_deref() == Some(event.sender.as_str()); }
}
notifi.sender_display_name = services().users.displayname(&event.sender)?; notifi.sender_display_name = services().users.displayname(&event.sender)?;
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
.await?; }
}
Ok(()) Ok(())
} },
// TODO: Handle email // TODO: Handle email
//PusherKind::Email(_) => Ok(()), //PusherKind::Email(_) => Ok(()),
_ => Ok(()), _ => Ok(()),
} }
} }
} }
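A hedged usage sketch (assumed call site, not part of this diff): after a PDU lands in a user's timeline, the surrounding code would typically look up that user's pushers and push rules and hand each pusher to the service above. The bindings user_id, unread, ruleset, and pdu are assumed from that context, and the services().pusher path is an assumption:

// Assumed bindings from the caller: user_id: &UserId, unread: UInt,
// ruleset: Ruleset (the user's push rules), pdu: &PduEvent.
for pusher in services().pusher.get_pushers(user_id)? {
	services()
		.pusher
		.send_push_notice(user_id, unread, &pusher, ruleset.clone(), pdu)
		.await?;
}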


@ -1,24 +1,22 @@
use crate::Result;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Creates or updates the alias to the given room id. /// Creates or updates the alias to the given room id.
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
/// Forgets about an alias. Returns an error if the alias did not exist. /// Forgets about an alias. Returns an error if the alias did not exist.
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>;
/// Looks up the roomid for the given alias. /// Looks up the roomid for the given alias.
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>; fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>;
/// Returns all local aliases that point to the given room /// Returns all local aliases that point to the given room
fn local_aliases_for_room<'a>( fn local_aliases_for_room<'a>(
&'a self, &'a self, room_id: &RoomId,
room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
/// Returns all local aliases on the server /// Returns all local aliases on the server
fn all_local_aliases<'a>( fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
&'a self,
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
} }


@ -1,42 +1,35 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use crate::Result;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) }
self.db.set_alias(alias, room_id)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) }
self.db.remove_alias(alias)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.db.resolve_local_alias(alias) self.db.resolve_local_alias(alias)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn local_aliases_for_room<'a>( pub fn local_aliases_for_room<'a>(
&'a self, &'a self, room_id: &RoomId,
room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> { self.db.local_aliases_for_room(room_id)
self.db.local_aliases_for_room(room_id) }
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn all_local_aliases<'a>( pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
&'a self, self.db.all_local_aliases()
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> { }
self.db.all_local_aliases()
}
} }


@ -1,11 +1,8 @@
use crate::Result;
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn get_cached_eventid_authchain( fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<HashSet<u64>>>>;
&self, fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()>;
shorteventid: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>>;
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>)
-> Result<()>;
} }


@ -1,7 +1,7 @@
mod data; mod data;
use std::{ use std::{
collections::{BTreeSet, HashSet}, collections::{BTreeSet, HashSet},
sync::Arc, sync::Arc,
}; };
pub use data::Data; pub use data::Data;
@ -11,151 +11,130 @@ use tracing::{debug, error, warn};
use crate::{services, Error, Result}; use crate::{services, Error, Result};
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
self.db.get_cached_eventid_authchain(key) self.db.get_cached_eventid_authchain(key)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> { pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
self.db.cache_auth_chain(key, auth_chain) self.db.cache_auth_chain(key, auth_chain)
} }
#[tracing::instrument(skip(self, starting_events))] #[tracing::instrument(skip(self, starting_events))]
pub async fn get_auth_chain<'a>( pub async fn get_auth_chain<'a>(
&self, &self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>,
room_id: &RoomId, ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
starting_events: Vec<Arc<EventId>>, const NUM_BUCKETS: usize = 50;
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
const NUM_BUCKETS: usize = 50;
let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
let mut i = 0; let mut i = 0;
for id in starting_events { for id in starting_events {
let short = services().rooms.short.get_or_create_shorteventid(&id)?; let short = services().rooms.short.get_or_create_shorteventid(&id)?;
let bucket_id = (short % NUM_BUCKETS as u64) as usize; let bucket_id = (short % NUM_BUCKETS as u64) as usize;
buckets[bucket_id].insert((short, id.clone())); buckets[bucket_id].insert((short, id.clone()));
i += 1; i += 1;
if i % 100 == 0 { if i % 100 == 0 {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
} }
let mut full_auth_chain = HashSet::new(); let mut full_auth_chain = HashSet::new();
let mut hits = 0; let mut hits = 0;
let mut misses = 0; let mut misses = 0;
for chunk in buckets { for chunk in buckets {
if chunk.is_empty() { if chunk.is_empty() {
continue; continue;
} }
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services() if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
.rooms hits += 1;
.auth_chain full_auth_chain.extend(cached.iter().copied());
.get_cached_eventid_authchain(&chunk_key)? continue;
{ }
hits += 1; misses += 1;
full_auth_chain.extend(cached.iter().copied());
continue;
}
misses += 1;
let mut chunk_cache = HashSet::new(); let mut chunk_cache = HashSet::new();
let mut hits2 = 0; let mut hits2 = 0;
let mut misses2 = 0; let mut misses2 = 0;
let mut i = 0; let mut i = 0;
for (sevent_id, event_id) in chunk { for (sevent_id, event_id) in chunk {
if let Some(cached) = services() if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
.rooms hits2 += 1;
.auth_chain chunk_cache.extend(cached.iter().copied());
.get_cached_eventid_authchain(&[sevent_id])? } else {
{ misses2 += 1;
hits2 += 1; let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
chunk_cache.extend(cached.iter().copied()); services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
} else { debug!(
misses2 += 1; event_id = ?event_id,
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); chain_length = ?auth_chain.len(),
services() "Cache missed event"
.rooms );
.auth_chain chunk_cache.extend(auth_chain.iter());
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
debug!(
event_id = ?event_id,
chain_length = ?auth_chain.len(),
"Cache missed event"
);
chunk_cache.extend(auth_chain.iter());
i += 1; i += 1;
if i % 100 == 0 { if i % 100 == 0 {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
}; };
} }
debug!( debug!(
chunk_cache_length = ?chunk_cache.len(), chunk_cache_length = ?chunk_cache.len(),
hits = ?hits2, hits = ?hits2,
misses = ?misses2, misses = ?misses2,
"Chunk missed", "Chunk missed",
); );
let chunk_cache = Arc::new(chunk_cache); let chunk_cache = Arc::new(chunk_cache);
services() services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
.rooms full_auth_chain.extend(chunk_cache.iter());
.auth_chain }
.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
full_auth_chain.extend(chunk_cache.iter());
}
debug!( debug!(
chain_length = ?full_auth_chain.len(), chain_length = ?full_auth_chain.len(),
hits = ?hits, hits = ?hits,
misses = ?misses, misses = ?misses,
"Auth chain stats", "Auth chain stats",
); );
Ok(full_auth_chain Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
.into_iter() }
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
}
#[tracing::instrument(skip(self, event_id))] #[tracing::instrument(skip(self, event_id))]
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> { fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)]; let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new(); let mut found = HashSet::new();
while let Some(event_id) = todo.pop() { while let Some(event_id) = todo.pop() {
match services().rooms.timeline.get_pdu(&event_id) { match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => { Ok(Some(pdu)) => {
if pdu.room_id != room_id { if pdu.room_id != room_id {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
} }
for auth_event in &pdu.auth_events { for auth_event in &pdu.auth_events {
let sauthevent = services() let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?;
.rooms
.short
.get_or_create_shorteventid(auth_event)?;
if !found.contains(&sauthevent) { if !found.contains(&sauthevent) {
found.insert(sauthevent); found.insert(sauthevent);
todo.push(auth_event.clone()); todo.push(auth_event.clone());
} }
} }
} },
Ok(None) => { Ok(None) => {
warn!(?event_id, "Could not find pdu mentioned in auth events"); warn!(?event_id, "Could not find pdu mentioned in auth events");
} },
Err(error) => { Err(error) => {
error!(?event_id, ?error, "Could not load event in auth chain"); error!(?event_id, ?error, "Could not load event in auth chain");
} },
} }
} }
Ok(found) Ok(found)
} }
} }
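A small worked example of the bucketing used in get_auth_chain above (numbers are made up): short event ids are grouped by their value modulo NUM_BUCKETS, so each cache entry covers roughly one fiftieth of the starting events.

// Illustration of the bucket assignment; the shorteventid value is invented.
const NUM_BUCKETS: usize = 50;
let short: u64 = 123_456;
let bucket_id = (short % NUM_BUCKETS as u64) as usize;
assert_eq!(bucket_id, 6); // 123_456 % 50 == 6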


@ -1,16 +1,17 @@
use crate::Result;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Adds the room to the public room directory /// Adds the room to the public room directory
fn set_public(&self, room_id: &RoomId) -> Result<()>; fn set_public(&self, room_id: &RoomId) -> Result<()>;
/// Removes the room from the public room directory. /// Removes the room from the public room directory.
fn set_not_public(&self, room_id: &RoomId) -> Result<()>; fn set_not_public(&self, room_id: &RoomId) -> Result<()>;
/// Returns true if the room is in the public room directory. /// Returns true if the room is in the public room directory.
fn is_public_room(&self, room_id: &RoomId) -> Result<bool>; fn is_public_room(&self, room_id: &RoomId) -> Result<bool>;
/// Returns the unsorted public room directory /// Returns the unsorted public room directory
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
} }


@ -6,27 +6,19 @@ use ruma::{OwnedRoomId, RoomId};
use crate::Result; use crate::Result;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
} }
impl Service { impl Service {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
self.db.set_public(room_id)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
self.db.set_not_public(room_id)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
self.db.is_public_room(room_id)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
self.db.public_rooms()
}
} }


@ -5,7 +5,7 @@ pub mod typing;
pub trait Data: presence::Data + read_receipt::Data + typing::Data + 'static {} pub trait Data: presence::Data + read_receipt::Data + typing::Data + 'static {}
pub struct Service { pub struct Service {
pub presence: presence::Service, pub presence: presence::Service,
pub read_receipt: read_receipt::Service, pub read_receipt: read_receipt::Service,
pub typing: typing::Service, pub typing: typing::Service,
} }


@ -1,33 +1,27 @@
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
use crate::Result; use crate::Result;
use ruma::{
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
};
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// Returns the latest presence event for the given user in the given room. /// Returns the latest presence event for the given user in the given room.
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>; fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>;
/// Pings the presence of the given user in the given room, setting the specified state. /// Pings the presence of the given user in the given room, setting the
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>; /// specified state.
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>;
/// Adds a presence event which will be saved until a new event replaces it. /// Adds a presence event which will be saved until a new event replaces it.
fn set_presence( fn set_presence(
&self, &self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
room_id: &RoomId, last_active_ago: Option<UInt>, status_msg: Option<String>,
user_id: &UserId, ) -> Result<()>;
presence_state: PresenceState,
currently_active: Option<bool>,
last_active_ago: Option<UInt>,
status_msg: Option<String>,
) -> Result<()>;
/// Removes the presence record for the given user from the database. /// Removes the presence record for the given user from the database.
fn remove_presence(&self, user_id: &UserId) -> Result<()>; fn remove_presence(&self, user_id: &UserId) -> Result<()>;
/// Returns the most recent presence updates that happened after the event with id `since`. /// Returns the most recent presence updates that happened after the event
fn presence_since<'a>( /// with id `since`.
&'a self, fn presence_since<'a>(
room_id: &RoomId, &'a self, room_id: &RoomId, since: u64,
since: u64, ) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
} }

Some files were not shown because too many files have changed in this diff.