Hot-Reloading Refactor

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-09 15:59:08 -07:00 committed by June 🍓🦴
commit 6c1434c165
212 changed files with 5679 additions and 4206 deletions

115
src/service/Cargo.toml Normal file
View file

@ -0,0 +1,115 @@
[package]
name = "conduit_service"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"rocksdb",
"io_uring",
"jemalloc",
"gzip_compression",
"zstd_compression",
"brotli_compression",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sqlite = [
"dep:rusqlite",
"dep:parking_lot",
"dep:thread_local",
]
rocksdb = [
"dep:rust-rocksdb",
]
jemalloc = [
"dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator",
"rust-rocksdb/jemalloc",
]
io_uring = [
"rust-rocksdb/io-uring",
]
zstd_compression = [
"rust-rocksdb/zstd",
]
gzip_compression = [
"reqwest/gzip",
]
brotli_compression = [
"reqwest/brotli",
]
sha256_media = [
"dep:sha2",
]
[dependencies]
argon2.workspace = true
async-trait.workspace = true
base64.workspace = true
bytes.workspace = true
clap.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
cyborgtime.workspace = true
futures-util.workspace = true
hickory-resolver.workspace = true
hmac.workspace = true
http.workspace = true
image.workspace = true
ipaddress.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
log.workspace = true
loole.workspace = true
lru-cache.workspace = true
parking_lot.optional = true
parking_lot.workspace = true
rand.workspace = true
regex.workspace = true
reqwest.workspace = true
ruma-identifiers-validation.workspace = true
ruma.workspace = true
rusqlite.optional = true
rusqlite.workspace = true
rust-rocksdb.optional = true
rust-rocksdb.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_yaml.workspace = true
sha-1.workspace = true
sha2.optional = true
sha2.workspace = true
thread_local.optional = true
thread_local.workspace = true
tikv-jemallocator.optional = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl.optional = true
tikv-jemalloc-ctl.workspace = true
tikv-jemalloc-sys.optional = true
tikv-jemalloc-sys.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
url.workspace = true
webpage.workspace = true
zstd.optional = true
zstd.workspace = true
[lints]
workspace = true

View file

@ -8,7 +8,7 @@ use ruma::{
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
/// Places one event in the account data of the user and removes the
/// previous entry.
fn update(

View file

@ -1,8 +1,8 @@
mod data;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
pub(crate) use data::Data;
pub use data::Data;
use ruma::{
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
@ -11,15 +11,15 @@ use ruma::{
use crate::Result;
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
/// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
pub(crate) fn update(
pub fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
@ -28,7 +28,7 @@ impl Service {
/// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
pub(crate) fn get(
pub fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
self.db.get(room_id, user_id, event_type)
@ -36,7 +36,7 @@ impl Service {
/// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))]
pub(crate) fn changes_since(
pub fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
self.db.changes_since(room_id, user_id, since)

View file

@ -1,11 +1,9 @@
use std::{collections::BTreeMap, sync::Arc};
use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc};
use clap::Parser;
use regex::Regex;
use conduit::{Error, Result};
use ruma::{
api::client::error::ErrorKind,
events::{
relation::InReplyTo,
room::{
canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent,
@ -13,133 +11,62 @@ use ruma::{
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent},
message::{Relation::Reply, RoomMessageEventContent},
message::RoomMessageEventContent,
name::RoomNameEventContent,
power_levels::RoomPowerLevelsEventContent,
topic::RoomTopicEventContent,
},
TimelineEventType,
},
EventId, MxcUri, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
EventId, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomVersionId, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::{Mutex, MutexGuard};
use tokio::{sync::Mutex, task::JoinHandle};
use tracing::{error, warn};
use self::{fsck::FsckCommand, tester::TesterCommands};
use super::pdu::PduBuilder;
use crate::{
service::admin::{
appservice::AppserviceCommand, debug::DebugCommand, federation::FederationCommand, media::MediaCommand,
query::QueryCommand, room::RoomCommand, server::ServerCommand, user::UserCommand,
},
services, Error, Result,
};
use crate::{pdu::PduBuilder, services};
pub(crate) mod appservice;
pub(crate) mod debug;
pub(crate) mod federation;
pub(crate) mod fsck;
pub(crate) mod media;
pub(crate) mod query;
pub(crate) mod room;
pub(crate) mod server;
pub(crate) mod tester;
pub(crate) mod user;
pub type HandlerResult = Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;
pub type Handler = fn(AdminRoomEvent, OwnedRoomId, OwnedUserId) -> HandlerResult;
const PAGE_SIZE: usize = 100;
#[cfg_attr(test, derive(Debug))]
#[derive(Parser)]
#[command(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))]
enum AdminCommand {
#[command(subcommand)]
/// - Commands for managing appservices
Appservices(AppserviceCommand),
#[command(subcommand)]
/// - Commands for managing local users
Users(UserCommand),
#[command(subcommand)]
/// - Commands for managing rooms
Rooms(RoomCommand),
#[command(subcommand)]
/// - Commands for managing federation
Federation(FederationCommand),
#[command(subcommand)]
/// - Commands for managing the server
Server(ServerCommand),
#[command(subcommand)]
/// - Commands for managing media
Media(MediaCommand),
#[command(subcommand)]
/// - Commands for debugging things
Debug(DebugCommand),
#[command(subcommand)]
/// - Query all the database getters and iterators
Query(QueryCommand),
#[command(subcommand)]
/// - Query all the database getters and iterators
Fsck(FsckCommand),
#[command(subcommand)]
Tester(TesterCommands),
pub struct Service {
sender: loole::Sender<AdminRoomEvent>,
receiver: Mutex<loole::Receiver<AdminRoomEvent>>,
handler_join: Mutex<Option<JoinHandle<()>>>,
pub handle: Mutex<Option<Handler>>,
}
#[derive(Debug)]
pub(crate) enum AdminRoomEvent {
pub enum AdminRoomEvent {
ProcessMessage(String, Arc<EventId>),
SendMessage(RoomMessageEventContent),
}
pub(crate) struct Service {
pub(crate) sender: loole::Sender<AdminRoomEvent>,
receiver: Mutex<loole::Receiver<AdminRoomEvent>>,
}
impl Service {
pub(crate) fn build() -> Arc<Self> {
#[must_use]
pub fn build() -> Arc<Self> {
let (sender, receiver) = loole::unbounded();
Arc::new(Self {
sender,
receiver: Mutex::new(receiver),
handler_join: Mutex::new(None),
handle: Mutex::new(None),
})
}
pub(crate) fn start_handler(self: &Arc<Self>) {
let self2 = Arc::clone(self);
tokio::spawn(async move {
self2
pub async fn start_handler(self: &Arc<Self>) {
let self_ = Arc::clone(self);
let handle = services().server.runtime().spawn(async move {
self_
.handler()
.await
.expect("Failed to initialize admin room handler");
});
_ = self.handler_join.lock().await.insert(handle);
}
pub(crate) async fn process_message(&self, room_message: String, event_id: Arc<EventId>) {
self.send(AdminRoomEvent::ProcessMessage(room_message, event_id))
.await;
}
pub(crate) async fn send_message(&self, message_content: RoomMessageEventContent) {
self.send(AdminRoomEvent::SendMessage(message_content))
.await;
}
async fn send(&self, message: AdminRoomEvent) {
debug_assert!(!self.sender.is_full(), "channel full");
debug_assert!(!self.sender.is_closed(), "channel closed");
self.sender.send(message).expect("message sent");
}
async fn handler(&self) -> Result<()> {
async fn handler(self: &Arc<Self>) -> Result<()> {
let receiver = self.receiver.lock().await;
let Ok(Some(admin_room)) = Self::get_admin_room().await else {
return Ok(());
@ -151,246 +78,72 @@ impl Service {
debug_assert!(!receiver.is_closed(), "channel closed");
tokio::select! {
event = receiver.recv_async() => match event {
Ok(event) => self.handle_event(event, &admin_room, &server_user).await?,
Err(e) => error!("Failed to receive admin room event from channel: {e}"),
Ok(event) => self.receive(event, &admin_room, &server_user).await?,
Err(_e) => return Ok(()),
}
}
}
}
async fn handle_event(&self, event: AdminRoomEvent, admin_room: &OwnedRoomId, server_user: &UserId) -> Result<()> {
let (mut message_content, reply) = match event {
AdminRoomEvent::SendMessage(content) => (content, None),
AdminRoomEvent::ProcessMessage(room_message, reply_id) => {
(self.process_admin_message(room_message).await, Some(reply_id))
},
};
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(admin_room.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
if let Some(reply) = reply {
message_content.relates_to = Some(Reply {
in_reply_to: InReplyTo {
event_id: reply.into(),
},
});
pub async fn close(&self) {
self.interrupt();
if let Some(handler_join) = self.handler_join.lock().await.take() {
if let Err(e) = handler_join.await {
error!("Failed to shutdown: {e:?}");
}
}
let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&message_content).expect("event is valid, we just created it"),
unsigned: None,
state_key: None,
redacts: None,
};
if let Err(e) = services()
.rooms
.timeline
.build_and_append_pdu(response_pdu, server_user, admin_room, &state_lock)
.await
{
self.handle_response_error(&e, admin_room, server_user, &state_lock)
.await?;
}
Ok(())
}
async fn handle_response_error(
&self, e: &Error, admin_room: &OwnedRoomId, server_user: &UserId, state_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
error!("Failed to build and append admin room response PDU: \"{e}\"");
let error_room_message = RoomMessageEventContent::text_plain(format!(
"Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \
successfully, but we could not return the output."
));
pub fn interrupt(&self) {
if !self.sender.is_closed() {
self.sender.close();
}
}
let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&error_room_message).expect("event is valid, we just created it"),
unsigned: None,
state_key: None,
redacts: None,
};
pub async fn send_message(&self, message_content: RoomMessageEventContent) {
self.send(AdminRoomEvent::SendMessage(message_content))
.await;
}
pub async fn process_message(&self, room_message: String, event_id: Arc<EventId>) {
self.send(AdminRoomEvent::ProcessMessage(room_message, event_id))
.await;
}
async fn receive(&self, event: AdminRoomEvent, room: &OwnedRoomId, user: &UserId) -> Result<(), Error> {
if let Some(handle) = self.handle.lock().await.as_ref() {
handle(event, room.clone(), user.into()).await
} else {
Err(Error::Err("Admin module is not loaded.".into()))
}
}
async fn send(&self, message: AdminRoomEvent) {
debug_assert!(!self.sender.is_full(), "channel full");
debug_assert!(!self.sender.is_closed(), "channel closed");
self.sender.send(message).expect("message sent");
}
/// Gets the room ID of the admin room
///
/// Errors are propagated from the database, and will have None if there is
/// no admin room
pub async fn get_admin_room() -> Result<Option<OwnedRoomId>> {
let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
services()
.rooms
.timeline
.build_and_append_pdu(response_pdu, server_user, admin_room, state_lock)
.await?;
Ok(())
}
// Parse and process a message from the admin room
async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent {
let mut lines = room_message.lines().filter(|l| !l.trim().is_empty());
let command_line = lines.next().expect("each string has at least one line");
let body = lines.collect::<Vec<_>>();
let admin_command = match self.parse_admin_command(command_line) {
Ok(command) => command,
Err(error) => {
let server_name = services().globals.server_name();
let message = error.replace("server.name", server_name.as_str());
let html_message = self.usage_to_html(&message, server_name);
return RoomMessageEventContent::text_html(message, html_message);
},
};
match self.process_admin_command(admin_command, body).await {
Ok(reply_message) => reply_message,
Err(error) => {
let markdown_message = format!("Encountered an error while handling the command:\n```\n{error}\n```",);
let html_message = format!("Encountered an error while handling the command:\n<pre>\n{error}\n</pre>",);
RoomMessageEventContent::text_html(markdown_message, html_message)
},
}
}
// Parse chat messages from the admin room into an AdminCommand object
fn parse_admin_command(&self, command_line: &str) -> Result<AdminCommand, String> {
// Note: argv[0] is `@conduit:servername:`, which is treated as the main command
let mut argv = command_line.split_whitespace().collect::<Vec<_>>();
// Replace `help command` with `command --help`
// Clap has a help subcommand, but it omits the long help description.
if argv.len() > 1 && argv[1] == "help" {
argv.remove(1);
argv.push("--help");
}
// Backwards compatibility with `register_appservice`-style commands
let command_with_dashes_argv1;
if argv.len() > 1 && argv[1].contains('_') {
command_with_dashes_argv1 = argv[1].replace('_', "-");
argv[1] = &command_with_dashes_argv1;
}
// Backwards compatibility with `register_appservice`-style commands
let command_with_dashes_argv2;
if argv.len() > 2 && argv[2].contains('_') {
command_with_dashes_argv2 = argv[2].replace('_', "-");
argv[2] = &command_with_dashes_argv2;
}
// if the user is using the `query` command (argv[1]), replace the database
// function/table calls with underscores to match the codebase
let command_with_dashes_argv3;
if argv.len() > 3 && argv[1].eq("query") {
command_with_dashes_argv3 = argv[3].replace('_', "-");
argv[3] = &command_with_dashes_argv3;
}
AdminCommand::try_parse_from(argv).map_err(|error| error.to_string())
}
async fn process_admin_command(&self, command: AdminCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
let reply_message_content = match command {
AdminCommand::Appservices(command) => appservice::process(command, body).await?,
AdminCommand::Media(command) => media::process(command, body).await?,
AdminCommand::Users(command) => user::process(command, body).await?,
AdminCommand::Rooms(command) => room::process(command, body).await?,
AdminCommand::Federation(command) => federation::process(command, body).await?,
AdminCommand::Server(command) => server::process(command, body).await?,
AdminCommand::Debug(command) => debug::process(command, body).await?,
AdminCommand::Query(command) => query::process(command, body).await?,
AdminCommand::Fsck(command) => fsck::process(command, body).await?,
AdminCommand::Tester(command) => tester::process(command, body).await?,
};
Ok(reply_message_content)
}
// Utility to turn clap's `--help` text to HTML.
fn usage_to_html(&self, text: &str, server_name: &ServerName) -> String {
// Replace `@conduit:servername:-subcmdname` with `@conduit:servername:
// subcmdname`
let text = text.replace(&format!("@conduit:{server_name}:-"), &format!("@conduit:{server_name}: "));
// For the conduit admin room, subcommands become main commands
let text = text.replace("SUBCOMMAND", "COMMAND");
let text = text.replace("subcommand", "command");
// Escape option names (e.g. `<element-id>`) since they look like HTML tags
let text = escape_html(&text);
// Italicize the first line (command name and version text)
let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<em>$1</em>\n");
// Unmerge wrapped lines
let text = text.replace("\n ", " ");
// Wrap option names in backticks. The lines look like:
// -V, --version Prints version information
// And are converted to:
// <code>-V, --version</code>: Prints version information
// (?m) enables multi-line mode for ^ and $
let re = Regex::new("(?m)^ {4}(([a-zA-Z_&;-]+(, )?)+) +(.*)$").expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<code>$1</code>: $4");
// Look for a `[commandbody]` tag. If it exists, use all lines below it that
// start with a `#` in the USAGE section.
let mut text_lines = text.lines().collect::<Vec<&str>>();
let mut command_body = String::new();
if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") {
text_lines.remove(line_index);
while text_lines
.get(line_index)
.is_some_and(|line| line.starts_with('#'))
{
command_body += if text_lines[line_index].starts_with("# ") {
&text_lines[line_index][2..]
} else {
&text_lines[line_index][1..]
};
command_body += "[nobr]\n";
text_lines.remove(line_index);
}
}
let text = text_lines.join("\n");
// Improve the usage section
let text = if command_body.is_empty() {
// Wrap the usage line in code tags
let re = Regex::new("(?m)^USAGE:\n {4}(@conduit:.*)$").expect("Regex compilation should not fail");
re.replace_all(&text, "USAGE:\n<code>$1</code>").to_string()
} else {
// Wrap the usage line in a code block, and add a yaml block example
// This makes the usage of e.g. `register-appservice` more accurate
let re = Regex::new("(?m)^USAGE:\n {4}(.*?)\n\n").expect("Regex compilation should not fail");
re.replace_all(&text, "USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>")
.replace("[commandbodyblock]", &command_body)
};
// Add HTML line-breaks
text.replace("\n\n\n", "\n\n")
.replace('\n', "<br>\n")
.replace("[nobr]<br>", "")
.alias
.resolve_local_alias(&admin_room_alias)
}
/// Create the admin room.
///
/// Users in this room are considered admins by conduit, and the room can be
/// used to issue admin commands by talking to the server user inside it.
pub(crate) async fn create_admin_room(&self) -> Result<()> {
pub async fn create_admin_room(&self) -> Result<()> {
let room_id = RoomId::new(services().globals.server_name());
services().rooms.short.get_or_create_shortroomid(&room_id)?;
@ -637,25 +390,10 @@ impl Service {
Ok(())
}
/// Gets the room ID of the admin room
///
/// Errors are propagated from the database, and will have None if there is
/// no admin room
pub(crate) async fn get_admin_room() -> Result<Option<OwnedRoomId>> {
let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
services()
.rooms
.alias
.resolve_local_alias(&admin_room_alias)
}
/// Invite the user to the conduit admin room.
///
/// In conduit, this is equivalent to granting admin privileges.
pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> {
pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> {
if let Some(room_id) = Self::get_admin_room().await? {
let mutex_state = Arc::clone(
services()
@ -754,7 +492,7 @@ impl Service {
// Send welcome message
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&RoomMessageEventContent::text_html(
format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", services().globals.server_name()),
@ -776,54 +514,3 @@ impl Service {
}
}
}
fn escape_html(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
}
fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) {
(
id.clone(),
services()
.rooms
.state_cache
.room_joined_count(id)
.ok()
.flatten()
.unwrap_or(0),
services()
.rooms
.state_accessor
.get_name(id)
.ok()
.flatten()
.unwrap_or_else(|| id.to_string()),
)
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn get_help_short() { get_help_inner("-h"); }
#[test]
fn get_help_long() { get_help_inner("--help"); }
#[test]
fn get_help_subcommand() { get_help_inner("help"); }
fn get_help_inner(input: &str) {
let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input])
.unwrap_err()
.to_string();
// Search for a handful of keywords that suggest the help printed properly
assert!(error.contains("Usage:"));
assert!(error.contains("Commands:"));
assert!(error.contains("Options:"));
}
}

View file

@ -1,66 +0,0 @@
use ruma::{api::appservice::Registration, events::room::message::RoomMessageEventContent};
use crate::{service::admin::escape_html, services, Result};
pub(crate) async fn register(body: Vec<&str>) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let appservice_config = body[1..body.len().checked_sub(1).unwrap()].join("\n");
let parsed_config = serde_yaml::from_str::<Registration>(&appservice_config);
match parsed_config {
Ok(yaml) => match services().appservice.register_appservice(yaml).await {
Ok(id) => Ok(RoomMessageEventContent::text_plain(format!(
"Appservice registered with ID: {id}."
))),
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Failed to register appservice: {e}"
))),
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Could not parse appservice config: {e}"
))),
}
} else {
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
}
}
pub(crate) async fn unregister(_body: Vec<&str>, appservice_identifier: String) -> Result<RoomMessageEventContent> {
match services()
.appservice
.unregister_appservice(&appservice_identifier)
.await
{
Ok(()) => Ok(RoomMessageEventContent::text_plain("Appservice unregistered.")),
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Failed to unregister appservice: {e}"
))),
}
}
pub(crate) async fn show(_body: Vec<&str>, appservice_identifier: String) -> Result<RoomMessageEventContent> {
match services()
.appservice
.get_registration(&appservice_identifier)
.await
{
Some(config) => {
let config_str = serde_yaml::to_string(&config).expect("config should've been validated on register");
let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,);
let output_html = format!(
"Config for {}:\n\n<pre><code class=\"language-yaml\">{}</code></pre>",
escape_html(&appservice_identifier),
escape_html(&config_str),
);
Ok(RoomMessageEventContent::text_html(output, output_html))
},
None => Ok(RoomMessageEventContent::text_plain("Appservice does not exist.")),
}
}
pub(crate) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let appservices = services().appservice.iter_ids().await;
let output = format!("Appservices ({}): {}", appservices.len(), appservices.join(", "));
Ok(RoomMessageEventContent::text_plain(output))
}

View file

@ -1,52 +0,0 @@
use clap::Subcommand;
use ruma::events::room::message::RoomMessageEventContent;
use self::appservice_command::{list, register, show, unregister};
use crate::Result;
pub(crate) mod appservice_command;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum AppserviceCommand {
/// - Register an appservice using its registration YAML
///
/// This command needs a YAML generated by an appservice (such as a bridge),
/// which must be provided in a Markdown code block below the command.
///
/// Registering a new bridge using the ID of an existing bridge will replace
/// the old one.
Register,
/// - Unregister an appservice using its ID
///
/// You can find the ID using the `list-appservices` command.
Unregister {
/// The appservice to unregister
appservice_identifier: String,
},
/// - Show an appservice's config using its ID
///
/// You can find the ID using the `list-appservices` command.
Show {
/// The appservice to show
appservice_identifier: String,
},
/// - List all the currently registered appservices
List,
}
pub(crate) async fn process(command: AppserviceCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
AppserviceCommand::Register => register(body).await?,
AppserviceCommand::Unregister {
appservice_identifier,
} => unregister(body, appservice_identifier).await?,
AppserviceCommand::Show {
appservice_identifier,
} => show(body, appservice_identifier).await?,
AppserviceCommand::List => list(body).await?,
})
}

View file

@ -1,468 +0,0 @@
use std::{collections::BTreeMap, sync::Arc, time::Instant};
use ruma::{
api::client::error::ErrorKind, events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId,
RoomId, RoomVersionId, ServerName,
};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use tracing_subscriber::EnvFilter;
use crate::{
api::server_server::parse_incoming_pdu, service::sending::send::resolve_actual_dest, services, utils::HtmlEscape,
Error, PduEvent, Result,
};
pub(crate) async fn get_auth_chain(_body: Vec<&str>, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? {
let room_id_str = event
.get("room_id")
.and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database"))?;
let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let start = Instant::now();
let count = services()
.rooms
.auth_chain
.event_ids_iter(room_id, vec![event_id])
.await?
.count();
let elapsed = start.elapsed();
Ok(RoomMessageEventContent::text_plain(format!(
"Loaded auth chain with length {count} in {elapsed:?}"
)))
} else {
Ok(RoomMessageEventContent::text_plain("Event not found."))
}
}
pub(crate) async fn parse_pdu(body: Vec<&str>) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let string = body[1..body.len() - 1].join("\n");
match serde_json::from_str(&string) {
Ok(value) => match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) {
Ok(hash) => {
let event_id = EventId::parse(format!("${hash}"));
match serde_json::from_value::<PduEvent>(serde_json::to_value(value).expect("value is json")) {
Ok(pdu) => Ok(RoomMessageEventContent::text_plain(format!("EventId: {event_id:?}\n{pdu:#?}"))),
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"EventId: {event_id:?}\nCould not parse event: {e}"
))),
}
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Could not parse PDU JSON: {e:?}"))),
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Invalid json in command body: {e}"
))),
}
} else {
Ok(RoomMessageEventContent::text_plain("Expected code block in command body."))
}
}
pub(crate) async fn get_pdu(_body: Vec<&str>, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let mut outlier = false;
let mut pdu_json = services()
.rooms
.timeline
.get_non_outlier_pdu_json(&event_id)?;
if pdu_json.is_none() {
outlier = true;
pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?;
}
match pdu_json {
Some(json) => {
let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json");
Ok(RoomMessageEventContent::text_html(
format!(
"{}\n```json\n{}\n```",
if outlier {
"Outlier PDU found in our database"
} else {
"PDU found in our database"
},
json_text
),
format!(
"<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n",
if outlier {
"Outlier PDU found in our database"
} else {
"PDU found in our database"
},
HtmlEscape(&json_text)
),
))
},
None => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")),
}
}
pub(crate) async fn get_remote_pdu_list(
body: Vec<&str>, server: Box<ServerName>, force: bool,
) -> Result<RoomMessageEventContent> {
if !services().globals.config.allow_federation {
return Ok(RoomMessageEventContent::text_plain(
"Federation is disabled on this homeserver.",
));
}
if server == services().globals.server_name() {
return Ok(RoomMessageEventContent::text_plain(
"Not allowed to send federation requests to ourselves. Please use `get-pdu` for fetching local PDUs.",
));
}
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let list = body
.clone()
.drain(1..body.len().checked_sub(1).unwrap())
.filter_map(|pdu| EventId::parse(pdu).ok())
.collect::<Vec<_>>();
for pdu in list {
if force {
_ = get_remote_pdu(Vec::new(), Box::from(pdu), server.clone()).await;
} else {
get_remote_pdu(Vec::new(), Box::from(pdu), server.clone()).await?;
}
}
return Ok(RoomMessageEventContent::text_plain("Fetched list of remote PDUs."));
}
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
}
pub(crate) async fn get_remote_pdu(
_body: Vec<&str>, event_id: Box<EventId>, server: Box<ServerName>,
) -> Result<RoomMessageEventContent> {
if !services().globals.config.allow_federation {
return Ok(RoomMessageEventContent::text_plain(
"Federation is disabled on this homeserver.",
));
}
if server == services().globals.server_name() {
return Ok(RoomMessageEventContent::text_plain(
"Not allowed to send federation requests to ourselves. Please use `get-pdu` for fetching local PDUs.",
));
}
match services()
.sending
.send_federation_request(
&server,
ruma::api::federation::event::get_event::v1::Request {
event_id: event_id.clone().into(),
},
)
.await
{
Ok(response) => {
let json: CanonicalJsonObject = serde_json::from_str(response.pdu.get()).map_err(|e| {
warn!(
"Requested event ID {event_id} from server but failed to convert from RawValue to \
CanonicalJsonObject (malformed event/response?): {e}"
);
Error::BadRequest(ErrorKind::Unknown, "Received response from server but failed to parse PDU")
})?;
debug!("Attempting to parse PDU: {:?}", &response.pdu);
let parsed_pdu = {
let parsed_result = parse_incoming_pdu(&response.pdu);
let (event_id, value, room_id) = match parsed_result {
Ok(t) => t,
Err(e) => {
warn!("Failed to parse PDU: {e}");
info!("Full PDU: {:?}", &response.pdu);
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to parse PDU remote server {server} sent us: {e}"
)));
},
};
vec![(event_id, value, room_id)]
};
let pub_key_map = RwLock::new(BTreeMap::new());
debug!("Attempting to fetch homeserver signing keys for {server}");
services()
.rooms
.event_handler
.fetch_required_signing_keys(parsed_pdu.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map)
.await
.unwrap_or_else(|e| {
warn!("Could not fetch all signatures for PDUs from {server}: {e:?}");
});
info!("Attempting to handle event ID {event_id} as backfilled PDU");
services()
.rooms
.timeline
.backfill_pdu(&server, response.pdu, &pub_key_map)
.await?;
let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json");
Ok(RoomMessageEventContent::text_html(
format!(
"{}\n```json\n{}\n```",
"Got PDU from specified server and handled as backfilled PDU successfully. Event body:", json_text
),
format!(
"<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n",
"Got PDU from specified server and handled as backfilled PDU successfully. Event body:",
HtmlEscape(&json_text)
),
))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Remote server did not have PDU or failed sending request to remote server: {e}"
))),
}
}
pub(crate) async fn get_room_state(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
let room_state = services()
.rooms
.state_accessor
.room_state_full(&room_id)
.await?
.values()
.map(|pdu| pdu.to_state_event())
.collect::<Vec<_>>();
if room_state.is_empty() {
return Ok(RoomMessageEventContent::text_plain(
"Unable to find room state in our database (vector is empty)",
));
}
let json_text = serde_json::to_string_pretty(&room_state).map_err(|e| {
warn!("Failed converting room state vector in our database to pretty JSON: {e}");
Error::bad_database(
"Failed to convert room state events to pretty JSON, possible invalid room state events in our database",
)
})?;
Ok(RoomMessageEventContent::text_html(
format!("{}\n```json\n{}\n```", "Found full room state", json_text),
format!(
"<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n",
"Found full room state",
HtmlEscape(&json_text)
),
))
}
pub(crate) async fn ping(_body: Vec<&str>, server: Box<ServerName>) -> Result<RoomMessageEventContent> {
if server == services().globals.server_name() {
return Ok(RoomMessageEventContent::text_plain(
"Not allowed to send federation requests to ourselves.",
));
}
let timer = tokio::time::Instant::now();
match services()
.sending
.send_federation_request(&server, ruma::api::federation::discovery::get_server_version::v1::Request {})
.await
{
Ok(response) => {
let ping_time = timer.elapsed();
let json_text_res = serde_json::to_string_pretty(&response.server);
if let Ok(json) = json_text_res {
return Ok(RoomMessageEventContent::text_html(
format!("Got response which took {ping_time:?} time:\n```json\n{json}\n```"),
format!(
"<p>Got response which took {ping_time:?} time:</p>\n<pre><code \
class=\"language-json\">{}\n</code></pre>\n",
HtmlEscape(&json)
),
));
}
Ok(RoomMessageEventContent::text_plain(format!(
"Got non-JSON response which took {ping_time:?} time:\n{0:?}",
response
)))
},
Err(e) => {
warn!("Failed sending federation request to specified server from ping debug command: {e}");
Ok(RoomMessageEventContent::text_plain(format!(
"Failed sending federation request to specified server:\n\n{e}",
)))
},
}
}
pub(crate) async fn force_device_list_updates(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
// Force E2EE device list updates for all users
for user_id in services().users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?;
}
Ok(RoomMessageEventContent::text_plain(
"Marked all devices for all users as having new keys to update",
))
}
pub(crate) async fn change_log_level(
_body: Vec<&str>, filter: Option<String>, reset: bool,
) -> Result<RoomMessageEventContent> {
if reset {
let old_filter_layer = match EnvFilter::try_new(&services().globals.config.log) {
Ok(s) => s,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Log level from config appears to be invalid now: {e}"
)));
},
};
match services()
.globals
.tracing_reload_handle
.reload(&old_filter_layer)
{
Ok(()) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Successfully changed log level back to config value {}",
services().globals.config.log
)));
},
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to modify and reload the global tracing log level: {e}"
)));
},
}
}
if let Some(filter) = filter {
let new_filter_layer = match EnvFilter::try_new(filter) {
Ok(s) => s,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Invalid log level filter specified: {e}"
)));
},
};
match services()
.globals
.tracing_reload_handle
.reload(&new_filter_layer)
{
Ok(()) => {
return Ok(RoomMessageEventContent::text_plain("Successfully changed log level"));
},
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to modify and reload the global tracing log level: {e}"
)));
},
}
}
Ok(RoomMessageEventContent::text_plain("No log level was specified."))
}
pub(crate) async fn sign_json(body: Vec<&str>) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let string = body[1..body.len().checked_sub(1).unwrap()].join("\n");
match serde_json::from_str(&string) {
Ok(mut value) => {
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut value,
)
.expect("our request json is what ruma expects");
let json_text = serde_json::to_string_pretty(&value).expect("canonical json is valid json");
Ok(RoomMessageEventContent::text_plain(json_text))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Invalid json: {e}"))),
}
} else {
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
}
}
pub(crate) async fn verify_json(body: Vec<&str>) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let string = body[1..body.len().checked_sub(1).unwrap()].join("\n");
match serde_json::from_str(&string) {
Ok(value) => {
let pub_key_map = RwLock::new(BTreeMap::new());
services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let pub_key_map = pub_key_map.read().await;
match ruma::signatures::verify_json(&pub_key_map, &value) {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Signature correct")),
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Signature verification failed: {e}"
))),
}
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Invalid json: {e}"))),
}
} else {
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
}
}
pub(crate) async fn resolve_true_destination(
_body: Vec<&str>, server_name: Box<ServerName>, no_cache: bool,
) -> Result<RoomMessageEventContent> {
if !services().globals.config.allow_federation {
return Ok(RoomMessageEventContent::text_plain(
"Federation is disabled on this homeserver.",
));
}
if server_name == services().globals.config.server_name {
return Ok(RoomMessageEventContent::text_plain(
"Not allowed to send federation requests to ourselves. Please use `get-pdu` for fetching local PDUs.",
));
}
let (actual_dest, hostname_uri) = resolve_actual_dest(&server_name, no_cache, true).await?;
Ok(RoomMessageEventContent::text_plain(format!(
"Actual destination: {actual_dest:?} | Hostname URI: {hostname_uri}"
)))
}
pub(crate) fn memory_stats() -> RoomMessageEventContent {
let html_body = crate::alloc::memory_stats();
if html_body.is_empty() {
return RoomMessageEventContent::text_plain("malloc stats are not supported on your compiled malloc.");
}
RoomMessageEventContent::text_html(
"This command's output can only be viewed by clients that render HTML.".to_owned(),
html_body,
)
}

View file

@ -1,161 +0,0 @@
use clap::Subcommand;
use ruma::{events::room::message::RoomMessageEventContent, EventId, RoomId, ServerName};
use self::debug_commands::{
change_log_level, force_device_list_updates, get_auth_chain, get_pdu, get_remote_pdu, get_remote_pdu_list,
get_room_state, memory_stats, parse_pdu, ping, resolve_true_destination, sign_json, verify_json,
};
use crate::Result;
pub(crate) mod debug_commands;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum DebugCommand {
/// - Get the auth_chain of a PDU
GetAuthChain {
/// An event ID (the $ character followed by the base64 reference hash)
event_id: Box<EventId>,
},
/// - Parse and print a PDU from a JSON
///
/// The PDU event is only checked for validity and is not added to the
/// database.
///
/// This command needs a JSON blob provided in a Markdown code block below
/// the command.
ParsePdu,
/// - Retrieve and print a PDU by ID from the conduwuit database
GetPdu {
/// An event ID (a $ followed by the base64 reference hash)
event_id: Box<EventId>,
},
/// - Attempts to retrieve a PDU from a remote server. Inserts it into our
/// database/timeline if found and we do not have this PDU already
/// (following normal event auth rules, handles it as an incoming PDU).
GetRemotePdu {
/// An event ID (a $ followed by the base64 reference hash)
event_id: Box<EventId>,
/// Argument for us to attempt to fetch the event from the
/// specified remote server.
server: Box<ServerName>,
},
/// Same as `get-remote-pdu` but accepts a codeblock newline delimited list
/// of PDUs and a single server to fetch from
GetRemotePduList {
/// Argument for us to attempt to fetch all the events from the
/// specified remote server.
server: Box<ServerName>,
/// If set, ignores errors, else stops at the first error/failure.
#[arg(short, long)]
force: bool,
},
/// - Gets all the room state events for the specified room.
///
/// This is functionally equivalent to `GET
/// /_matrix/client/v3/rooms/{roomid}/state`, except the admin command does
/// *not* check if the sender user is allowed to see state events. This is
/// done because it's implied that server admins here have database access
/// and can see/get room info themselves anyways if they were malicious
/// admins.
///
/// Of course the check is still done on the actual client API.
GetRoomState {
/// Room ID
room_id: Box<RoomId>,
},
/// - Sends a federation request to the remote server's
/// `/_matrix/federation/v1/version` endpoint and measures the latency it
/// took for the server to respond
Ping {
server: Box<ServerName>,
},
/// - Forces device lists for all local and remote users to be updated (as
/// having new keys available)
ForceDeviceListUpdates,
/// - Change tracing log level/filter on the fly
///
/// This accepts the same format as the `log` config option.
ChangeLogLevel {
/// Log level/filter
filter: Option<String>,
/// Resets the log level/filter to the one in your config
#[arg(short, long)]
reset: bool,
},
/// - Verify json signatures
///
/// This command needs a JSON blob provided in a Markdown code block below
/// the command.
SignJson,
/// - Verify json signatures
///
/// This command needs a JSON blob provided in a Markdown code block below
/// the command.
VerifyJson,
/// - Runs a server name through conduwuit's true destination resolution
/// process
///
/// Useful for debugging well-known issues
ResolveTrueDestination {
server_name: Box<ServerName>,
#[arg(short, long)]
no_cache: bool,
},
/// - Print extended memory usage
MemoryStats,
}
pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
DebugCommand::GetAuthChain {
event_id,
} => get_auth_chain(body, event_id).await?,
DebugCommand::ParsePdu => parse_pdu(body).await?,
DebugCommand::GetPdu {
event_id,
} => get_pdu(body, event_id).await?,
DebugCommand::GetRemotePdu {
event_id,
server,
} => get_remote_pdu(body, event_id, server).await?,
DebugCommand::GetRoomState {
room_id,
} => get_room_state(body, room_id).await?,
DebugCommand::Ping {
server,
} => ping(body, server).await?,
DebugCommand::ForceDeviceListUpdates => force_device_list_updates(body).await?,
DebugCommand::ChangeLogLevel {
filter,
reset,
} => change_log_level(body, filter, reset).await?,
DebugCommand::SignJson => sign_json(body).await?,
DebugCommand::VerifyJson => verify_json(body).await?,
DebugCommand::GetRemotePduList {
server,
force,
} => get_remote_pdu_list(body, server, force).await?,
DebugCommand::ResolveTrueDestination {
server_name,
no_cache,
} => resolve_true_destination(body, server_name, no_cache).await?,
DebugCommand::MemoryStats => memory_stats(),
})
}

View file

@ -1,134 +0,0 @@
use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId};
use crate::{
service::admin::{escape_html, get_room_info},
services,
utils::HtmlEscape,
Result,
};
pub(crate) async fn disable_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
services().rooms.metadata.disable_room(&room_id, true)?;
Ok(RoomMessageEventContent::text_plain("Room disabled."))
}
pub(crate) async fn enable_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
services().rooms.metadata.disable_room(&room_id, false)?;
Ok(RoomMessageEventContent::text_plain("Room enabled."))
}
pub(crate) async fn incoming_federeation(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let map = services().globals.roomid_federationhandletime.read().await;
let mut msg = format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() {
let elapsed = i.elapsed();
let _ = writeln!(msg, "{} {}: {}m{}s", r, e, elapsed.as_secs() / 60, elapsed.as_secs() % 60);
}
Ok(RoomMessageEventContent::text_plain(&msg))
}
pub(crate) async fn fetch_support_well_known(
_body: Vec<&str>, server_name: Box<ServerName>,
) -> Result<RoomMessageEventContent> {
let response = services()
.globals
.client
.default
.get(format!("https://{server_name}/.well-known/matrix/support"))
.send()
.await?;
let text = response.text().await?;
if text.is_empty() {
return Ok(RoomMessageEventContent::text_plain("Response text/body is empty."));
}
if text.len() > 1500 {
return Ok(RoomMessageEventContent::text_plain(
"Response text/body is over 1500 characters, assuming no support well-known.",
));
}
let json: serde_json::Value = match serde_json::from_str(&text) {
Ok(json) => json,
Err(_) => {
return Ok(RoomMessageEventContent::text_plain("Response text/body is not valid JSON."));
},
};
let pretty_json: String = match serde_json::to_string_pretty(&json) {
Ok(json) => json,
Err(_) => {
return Ok(RoomMessageEventContent::text_plain("Response text/body is not valid JSON."));
},
};
Ok(RoomMessageEventContent::text_html(
format!("Got JSON response:\n\n```json\n{pretty_json}\n```"),
format!(
"<p>Got JSON response:</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n",
HtmlEscape(&pretty_json)
),
))
}
pub(crate) async fn remote_user_in_rooms(_body: Vec<&str>, user_id: Box<UserId>) -> Result<RoomMessageEventContent> {
if user_id.server_name() == services().globals.config.server_name {
return Ok(RoomMessageEventContent::text_plain(
"User belongs to our server, please use `list-joined-rooms` user admin command instead.",
));
}
if !services().users.exists(&user_id)? {
return Ok(RoomMessageEventContent::text_plain(
"Remote user does not exist in our database.",
));
}
let mut rooms: Vec<(OwnedRoomId, u64, String)> = services()
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.map(|room_id| get_room_info(&room_id))
.collect();
if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("User is not in any rooms."));
}
rooms.sort_by_key(|r| r.1);
rooms.reverse();
let output_plain = format!(
"Rooms {user_id} shares with us:\n{}",
rooms
.iter()
.map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}"))
.collect::<Vec<_>>()
.join("\n")
);
let output_html = format!(
"<table><caption>Rooms {user_id} shares with \
us</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!(
output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",
escape_html(id.as_ref()),
members,
escape_html(name)
)
.unwrap();
output
})
);
Ok(RoomMessageEventContent::text_html(output_plain, output_html))
}

View file

@ -1,62 +0,0 @@
use clap::Subcommand;
use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId};
use self::federation_commands::{
disable_room, enable_room, fetch_support_well_known, incoming_federeation, remote_user_in_rooms,
};
use crate::Result;
pub(crate) mod federation_commands;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum FederationCommand {
/// - List all rooms we are currently handling an incoming pdu from
IncomingFederation,
/// - Disables incoming federation handling for a room.
DisableRoom {
room_id: Box<RoomId>,
},
/// - Enables incoming federation handling for a room again.
EnableRoom {
room_id: Box<RoomId>,
},
/// - Fetch `/.well-known/matrix/support` from the specified server
///
/// Despite the name, this is not a federation endpoint and does not go
/// through the federation / server resolution process as per-spec this is
/// supposed to be served at the server_name.
///
/// Respecting homeservers put this file here for listing administration,
/// moderation, and security inquiries. This command provides a way to
/// easily fetch that information.
FetchSupportWellKnown {
server_name: Box<ServerName>,
},
/// - Lists all the rooms we share/track with the specified *remote* user
RemoteUserInRooms {
user_id: Box<UserId>,
},
}
pub(crate) async fn process(command: FederationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
FederationCommand::DisableRoom {
room_id,
} => disable_room(body, room_id).await?,
FederationCommand::EnableRoom {
room_id,
} => enable_room(body, room_id).await?,
FederationCommand::IncomingFederation => incoming_federeation(body).await?,
FederationCommand::FetchSupportWellKnown {
server_name,
} => fetch_support_well_known(body, server_name).await?,
FederationCommand::RemoteUserInRooms {
user_id,
} => remote_user_in_rooms(body, user_id).await?,
})
}

View file

@ -1,26 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use crate::{services, Result};
/// Uses the iterator in `src/database/key_value/users.rs` to iterator over
/// every user in our database (remote and local). Reports total count, any
/// errors if there were any, etc
pub(crate) async fn check_all_users(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let results = services().users.db.iter();
let query_time = timer.elapsed();
let users = results.collect::<Vec<_>>();
let total = users.len();
let err_count = users.iter().filter(|user| user.is_err()).count();
let ok_count = users.iter().filter(|user| user.is_ok()).count();
let message = format!(
"Database query completed in {query_time:?}:\n\n```\nTotal entries: {:?}\nFailure/Invalid user count: \
{:?}\nSuccess/Valid user count: {:?}```",
total, err_count, ok_count
);
Ok(RoomMessageEventContent::notice_html(message, String::new()))
}

View file

@ -1,19 +0,0 @@
use clap::Subcommand;
use ruma::events::room::message::RoomMessageEventContent;
use self::fsck_commands::check_all_users;
use crate::Result;
pub(crate) mod fsck_commands;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum FsckCommand {
CheckAllUsers,
}
pub(crate) async fn process(command: FsckCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
FsckCommand::CheckAllUsers => check_all_users(body).await?,
})
}

View file

@ -1,176 +0,0 @@
use ruma::{events::room::message::RoomMessageEventContent, EventId};
use tracing::{debug, info};
use crate::{service::admin::MxcUri, services, Result};
pub(crate) async fn delete(
_body: Vec<&str>, mxc: Option<Box<MxcUri>>, event_id: Option<Box<EventId>>,
) -> Result<RoomMessageEventContent> {
if event_id.is_some() && mxc.is_some() {
return Ok(RoomMessageEventContent::text_plain(
"Please specify either an MXC or an event ID, not both.",
));
}
if let Some(mxc) = mxc {
debug!("Got MXC URL: {mxc}");
services().media.delete(mxc.to_string()).await?;
return Ok(RoomMessageEventContent::text_plain(
"Deleted the MXC from our database and on our filesystem.",
));
} else if let Some(event_id) = event_id {
debug!("Got event ID to delete media from: {event_id}");
let mut mxc_urls = vec![];
let mut mxc_deletion_count = 0;
// parsing the PDU for any MXC URLs begins here
if let Some(event_json) = services().rooms.timeline.get_pdu_json(&event_id)? {
if let Some(content_key) = event_json.get("content") {
debug!("Event ID has \"content\".");
let content_obj = content_key.as_object();
if let Some(content) = content_obj {
// 1. attempts to parse the "url" key
debug!("Attempting to go into \"url\" key for main media file");
if let Some(url) = content.get("url") {
debug!("Got a URL in the event ID {event_id}: {url}");
if url.to_string().starts_with("\"mxc://") {
debug!("Pushing URL {url} to list of MXCs to delete");
let final_url = url.to_string().replace('"', "");
mxc_urls.push(final_url);
} else {
info!("Found a URL in the event ID {event_id} but did not start with mxc://, ignoring");
}
}
// 2. attempts to parse the "info" key
debug!("Attempting to go into \"info\" key for thumbnails");
if let Some(info_key) = content.get("info") {
debug!("Event ID has \"info\".");
let info_obj = info_key.as_object();
if let Some(info) = info_obj {
if let Some(thumbnail_url) = info.get("thumbnail_url") {
debug!("Found a thumbnail_url in info key: {thumbnail_url}");
if thumbnail_url.to_string().starts_with("\"mxc://") {
debug!("Pushing thumbnail URL {thumbnail_url} to list of MXCs to delete");
let final_thumbnail_url = thumbnail_url.to_string().replace('"', "");
mxc_urls.push(final_thumbnail_url);
} else {
info!(
"Found a thumbnail URL in the event ID {event_id} but did not start with \
mxc://, ignoring"
);
}
} else {
info!("No \"thumbnail_url\" key in \"info\" key, assuming no thumbnails.");
}
}
}
// 3. attempts to parse the "file" key
debug!("Attempting to go into \"file\" key");
if let Some(file_key) = content.get("file") {
debug!("Event ID has \"file\".");
let file_obj = file_key.as_object();
if let Some(file) = file_obj {
if let Some(url) = file.get("url") {
debug!("Found url in file key: {url}");
if url.to_string().starts_with("\"mxc://") {
debug!("Pushing URL {url} to list of MXCs to delete");
let final_url = url.to_string().replace('"', "");
mxc_urls.push(final_url);
} else {
info!(
"Found a URL in the event ID {event_id} but did not start with mxc://, \
ignoring"
);
}
} else {
info!("No \"url\" key in \"file\" key.");
}
}
}
} else {
return Ok(RoomMessageEventContent::text_plain(
"Event ID does not have a \"content\" key or failed parsing the event ID JSON.",
));
}
} else {
return Ok(RoomMessageEventContent::text_plain(
"Event ID does not have a \"content\" key, this is not a message or an event type that contains \
media.",
));
}
} else {
return Ok(RoomMessageEventContent::text_plain(
"Event ID does not exist or is not known to us.",
));
}
if mxc_urls.is_empty() {
// we shouldn't get here (should have errored earlier) but just in case for
// whatever reason we do...
info!("Parsed event ID {event_id} but did not contain any MXC URLs.");
return Ok(RoomMessageEventContent::text_plain("Parsed event ID but found no MXC URLs."));
}
for mxc_url in mxc_urls {
services().media.delete(mxc_url).await?;
mxc_deletion_count += 1;
}
return Ok(RoomMessageEventContent::text_plain(format!(
"Deleted {mxc_deletion_count} total MXCs from our database and the filesystem from event ID {event_id}."
)));
}
Ok(RoomMessageEventContent::text_plain(
"Please specify either an MXC using --mxc or an event ID using --event-id of the message containing an image. \
See --help for details.",
))
}
pub(crate) async fn delete_list(body: Vec<&str>) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let mxc_list = body
.clone()
.drain(1..body.len().checked_sub(1).unwrap())
.collect::<Vec<_>>();
let mut mxc_deletion_count: usize = 0;
for mxc in mxc_list {
debug!("Deleting MXC {mxc} in bulk");
services().media.delete(mxc.to_owned()).await?;
mxc_deletion_count = mxc_deletion_count
.checked_add(1)
.expect("mxc_deletion_count should not get this high");
}
return Ok(RoomMessageEventContent::text_plain(format!(
"Finished bulk MXC deletion, deleted {mxc_deletion_count} total MXCs from our database and the filesystem.",
)));
}
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
}
pub(crate) async fn delete_past_remote_media(_body: Vec<&str>, duration: String) -> Result<RoomMessageEventContent> {
let deleted_count = services()
.media
.delete_all_remote_media_at_after_time(duration)
.await?;
Ok(RoomMessageEventContent::text_plain(format!(
"Deleted {deleted_count} total files.",
)))
}

View file

@ -1,49 +0,0 @@
use clap::Subcommand;
use ruma::{events::room::message::RoomMessageEventContent, EventId};
use self::media_commands::{delete, delete_list, delete_past_remote_media};
use crate::{service::admin::MxcUri, Result};
pub(crate) mod media_commands;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum MediaCommand {
/// - Deletes a single media file from our database and on the filesystem
/// via a single MXC URL
Delete {
/// The MXC URL to delete
#[arg(long)]
mxc: Option<Box<MxcUri>>,
/// - The message event ID which contains the media and thumbnail MXC
/// URLs
#[arg(long)]
event_id: Option<Box<EventId>>,
},
/// - Deletes a codeblock list of MXC URLs from our database and on the
/// filesystem
DeleteList,
/// - Deletes all remote media in the last X amount of time using filesystem
/// metadata first created at date.
DeletePastRemoteMedia {
/// - The duration (at or after), e.g. "5m" to delete all media in the
/// past 5 minutes
duration: String,
},
}
pub(crate) async fn process(command: MediaCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
MediaCommand::Delete {
mxc,
event_id,
} => delete(body, mxc, event_id).await?,
MediaCommand::DeleteList => delete_list(body).await?,
MediaCommand::DeletePastRemoteMedia {
duration,
} => delete_past_remote_media(body, duration).await?,
})
}

View file

@ -1,50 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::AccountData;
use crate::{services, Result};
/// All the getters and iterators from src/database/key_value/account_data.rs
pub(crate) async fn account_data(subcommand: AccountData) -> Result<RoomMessageEventContent> {
match subcommand {
AccountData::ChangesSince {
user_id,
since,
room_id,
} => {
let timer = tokio::time::Instant::now();
let results = services()
.account_data
.db
.changes_since(room_id.as_deref(), &user_id, since)?;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
AccountData::Get {
user_id,
kind,
room_id,
} => {
let timer = tokio::time::Instant::now();
let results = services()
.account_data
.db
.get(room_id.as_deref(), &user_id, kind)?;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
}
}

View file

@ -1,41 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::Appservice;
use crate::{services, Result};
/// All the getters and iterators from src/database/key_value/appservice.rs
pub(crate) async fn appservice(subcommand: Appservice) -> Result<RoomMessageEventContent> {
match subcommand {
Appservice::GetRegistration {
appservice_id,
} => {
let timer = tokio::time::Instant::now();
let results = services()
.appservice
.db
.get_registration(appservice_id.as_ref());
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
Appservice::All => {
let timer = tokio::time::Instant::now();
let results = services().appservice.db.all();
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
}
}

View file

@ -1,77 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::Globals;
use crate::{services, Result};
/// All the getters and iterators from src/database/key_value/globals.rs
pub(crate) async fn globals(subcommand: Globals) -> Result<RoomMessageEventContent> {
match subcommand {
Globals::DatabaseVersion => {
let timer = tokio::time::Instant::now();
let results = services().globals.db.database_version();
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
Globals::CurrentCount => {
let timer = tokio::time::Instant::now();
let results = services().globals.db.current_count();
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
Globals::LastCheckForUpdatesId => {
let timer = tokio::time::Instant::now();
let results = services().globals.db.last_check_for_updates_id();
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
Globals::LoadKeypair => {
let timer = tokio::time::Instant::now();
let results = services().globals.db.load_keypair();
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
Globals::SigningKeysFor {
origin,
} => {
let timer = tokio::time::Instant::now();
let results = services().globals.db.signing_keys_for(&origin);
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
}
}

View file

@ -1,223 +0,0 @@
pub(crate) mod account_data;
pub(crate) mod appservice;
pub(crate) mod globals;
pub(crate) mod presence;
pub(crate) mod room_alias;
pub(crate) mod sending;
pub(crate) mod users;
use clap::Subcommand;
use ruma::{
events::{room::message::RoomMessageEventContent, RoomAccountDataEventType},
RoomAliasId, RoomId, ServerName, UserId,
};
use self::{
account_data::account_data, appservice::appservice, globals::globals, presence::presence, room_alias::room_alias,
sending::sending, users::users,
};
use crate::Result;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// Query tables from database
pub(crate) enum QueryCommand {
/// - account_data.rs iterators and getters
#[command(subcommand)]
AccountData(AccountData),
/// - appservice.rs iterators and getters
#[command(subcommand)]
Appservice(Appservice),
/// - presence.rs iterators and getters
#[command(subcommand)]
Presence(Presence),
/// - rooms/alias.rs iterators and getters
#[command(subcommand)]
RoomAlias(RoomAlias),
/// - globals.rs iterators and getters
#[command(subcommand)]
Globals(Globals),
/// - sending.rs iterators and getters
#[command(subcommand)]
Sending(Sending),
/// - users.rs iterators and getters
#[command(subcommand)]
Users(Users),
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/account_data.rs
pub(crate) enum AccountData {
/// - Returns all changes to the account data that happened after `since`.
ChangesSince {
/// Full user ID
user_id: Box<UserId>,
/// UNIX timestamp since (u64)
since: u64,
/// Optional room ID of the account data
room_id: Option<Box<RoomId>>,
},
/// - Searches the account data for a specific kind.
Get {
/// Full user ID
user_id: Box<UserId>,
/// Account data event type
kind: RoomAccountDataEventType,
/// Optional room ID of the account data
room_id: Option<Box<RoomId>>,
},
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/appservice.rs
pub(crate) enum Appservice {
/// - Gets the appservice registration info/details from the ID as a string
GetRegistration {
/// Appservice registration ID
appservice_id: Box<str>,
},
/// - Gets all appservice registrations with their ID and registration info
All,
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/presence.rs
pub(crate) enum Presence {
/// - Returns the latest presence event for the given user.
GetPresence {
/// Full user ID
user_id: Box<UserId>,
},
/// - Iterator of the most recent presence updates that happened after the
/// event with id `since`.
PresenceSince {
/// UNIX timestamp since (u64)
since: u64,
},
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/rooms/alias.rs
pub(crate) enum RoomAlias {
ResolveLocalAlias {
/// Full room alias
alias: Box<RoomAliasId>,
},
/// - Iterator of all our local room aliases for the room ID
LocalAliasesForRoom {
/// Full room ID
room_id: Box<RoomId>,
},
/// - Iterator of all our local aliases in our database with their room IDs
AllLocalAliases,
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/globals.rs
pub(crate) enum Globals {
DatabaseVersion,
CurrentCount,
LastCheckForUpdatesId,
LoadKeypair,
/// - This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
SigningKeysFor {
origin: Box<ServerName>,
},
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/sending.rs
pub(crate) enum Sending {
/// - Queries database for all `servercurrentevent_data`
ActiveRequests,
/// - Queries database for `servercurrentevent_data` but for a specific
/// destination
///
/// This command takes only *one* format of these arguments:
///
/// appservice_id
/// server_name
/// user_id AND push_key
///
/// See src/service/sending/mod.rs for the definition of the `Destination`
/// enum
ActiveRequestsFor {
#[arg(short, long)]
appservice_id: Option<String>,
#[arg(short, long)]
server_name: Option<Box<ServerName>>,
#[arg(short, long)]
user_id: Option<Box<UserId>>,
#[arg(short, long)]
push_key: Option<String>,
},
/// - Queries database for `servernameevent_data` which are the queued up
/// requests that will eventually be sent
///
/// This command takes only *one* format of these arguments:
///
/// appservice_id
/// server_name
/// user_id AND push_key
///
/// See src/service/sending/mod.rs for the definition of the `Destination`
/// enum
QueuedRequests {
#[arg(short, long)]
appservice_id: Option<String>,
#[arg(short, long)]
server_name: Option<Box<ServerName>>,
#[arg(short, long)]
user_id: Option<Box<UserId>>,
#[arg(short, long)]
push_key: Option<String>,
},
GetLatestEduCount {
server_name: Box<ServerName>,
},
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
/// All the getters and iterators from src/database/key_value/users.rs
pub(crate) enum Users {
Iter,
}
/// Processes admin query commands
pub(crate) async fn process(command: QueryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
QueryCommand::AccountData(command) => account_data(command).await?,
QueryCommand::Appservice(command) => appservice(command).await?,
QueryCommand::Presence(command) => presence(command).await?,
QueryCommand::RoomAlias(command) => room_alias(command).await?,
QueryCommand::Globals(command) => globals(command).await?,
QueryCommand::Sending(command) => sending(command).await?,
QueryCommand::Users(command) => users(command).await?,
})
}

View file

@ -1,42 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::Presence;
use crate::{services, Result};
/// All the getters and iterators in key_value/presence.rs
pub(crate) async fn presence(subcommand: Presence) -> Result<RoomMessageEventContent> {
match subcommand {
Presence::GetPresence {
user_id,
} => {
let timer = tokio::time::Instant::now();
let results = services().presence.db.get_presence(&user_id)?;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
Presence::PresenceSince {
since,
} => {
let timer = tokio::time::Instant::now();
let results = services().presence.db.presence_since(since);
let query_time = timer.elapsed();
let presence_since: Vec<(_, _, _)> = results.collect();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", presence_since),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
presence_since
),
))
},
}
}

View file

@ -1,57 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::RoomAlias;
use crate::{services, Result};
/// All the getters and iterators in src/database/key_value/rooms/alias.rs
pub(crate) async fn room_alias(subcommand: RoomAlias) -> Result<RoomMessageEventContent> {
match subcommand {
RoomAlias::ResolveLocalAlias {
alias,
} => {
let timer = tokio::time::Instant::now();
let results = services().rooms.alias.db.resolve_local_alias(&alias);
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
RoomAlias::LocalAliasesForRoom {
room_id,
} => {
let timer = tokio::time::Instant::now();
let results = services().rooms.alias.db.local_aliases_for_room(&room_id);
let query_time = timer.elapsed();
let aliases: Vec<_> = results.collect();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", aliases),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
aliases
),
))
},
RoomAlias::AllLocalAliases => {
let timer = tokio::time::Instant::now();
let results = services().rooms.alias.db.all_local_aliases();
let query_time = timer.elapsed();
let aliases: Vec<_> = results.collect();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", aliases),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
aliases
),
))
},
}
}

View file

@ -1,204 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::Sending;
use crate::{service::sending::Destination, services, Result};
/// All the getters and iterators in key_value/sending.rs
pub(crate) async fn sending(subcommand: Sending) -> Result<RoomMessageEventContent> {
match subcommand {
Sending::ActiveRequests => {
let timer = tokio::time::Instant::now();
let results = services().sending.db.active_requests();
let query_time = timer.elapsed();
let active_requests: Result<Vec<(_, _, _)>> = results.collect();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", active_requests),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
active_requests
),
))
},
Sending::QueuedRequests {
appservice_id,
server_name,
user_id,
push_key,
} => {
if appservice_id.is_none() && server_name.is_none() && user_id.is_none() && push_key.is_none() {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via arguments. See \
--help for more details.",
));
}
let (results, query_time) = match (appservice_id, server_name, user_id, push_key) {
(Some(appservice_id), None, None, None) => {
if appservice_id.is_empty() {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via \
arguments. See --help for more details.",
));
}
let timer = tokio::time::Instant::now();
let results = services()
.sending
.db
.queued_requests(&Destination::Appservice(appservice_id));
let query_time = timer.elapsed();
(results, query_time)
},
(None, Some(server_name), None, None) => {
let timer = tokio::time::Instant::now();
let results = services()
.sending
.db
.queued_requests(&Destination::Normal(server_name.into()));
let query_time = timer.elapsed();
(results, query_time)
},
(None, None, Some(user_id), Some(push_key)) => {
if push_key.is_empty() {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via \
arguments. See --help for more details.",
));
}
let timer = tokio::time::Instant::now();
let results = services()
.sending
.db
.queued_requests(&Destination::Push(user_id.into(), push_key));
let query_time = timer.elapsed();
(results, query_time)
},
(Some(_), Some(_), Some(_), Some(_)) => {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via arguments. \
Not all of them See --help for more details.",
));
},
_ => {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via arguments. \
See --help for more details.",
));
},
};
let queued_requests = results.collect::<Result<Vec<(_, _)>>>();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", queued_requests),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
queued_requests
),
))
},
Sending::ActiveRequestsFor {
appservice_id,
server_name,
user_id,
push_key,
} => {
if appservice_id.is_none() && server_name.is_none() && user_id.is_none() && push_key.is_none() {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via arguments. See \
--help for more details.",
));
}
let (results, query_time) = match (appservice_id, server_name, user_id, push_key) {
(Some(appservice_id), None, None, None) => {
if appservice_id.is_empty() {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via \
arguments. See --help for more details.",
));
}
let timer = tokio::time::Instant::now();
let results = services()
.sending
.db
.active_requests_for(&Destination::Appservice(appservice_id));
let query_time = timer.elapsed();
(results, query_time)
},
(None, Some(server_name), None, None) => {
let timer = tokio::time::Instant::now();
let results = services()
.sending
.db
.active_requests_for(&Destination::Normal(server_name.into()));
let query_time = timer.elapsed();
(results, query_time)
},
(None, None, Some(user_id), Some(push_key)) => {
if push_key.is_empty() {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via \
arguments. See --help for more details.",
));
}
let timer = tokio::time::Instant::now();
let results = services()
.sending
.db
.active_requests_for(&Destination::Push(user_id.into(), push_key));
let query_time = timer.elapsed();
(results, query_time)
},
(Some(_), Some(_), Some(_), Some(_)) => {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via arguments. \
Not all of them See --help for more details.",
));
},
_ => {
return Ok(RoomMessageEventContent::text_plain(
"An appservice ID, server name, or a user ID with push key must be specified via arguments. \
See --help for more details.",
));
},
};
let active_requests = results.collect::<Result<Vec<(_, _)>>>();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", active_requests),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
active_requests
),
))
},
Sending::GetLatestEduCount {
server_name,
} => {
let timer = tokio::time::Instant::now();
let results = services().sending.db.get_latest_educount(&server_name);
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", results),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
results
),
))
},
}
}

View file

@ -1,25 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use super::Users;
use crate::{services, Result};
/// All the getters and iterators in key_value/users.rs
pub(crate) async fn users(subcommand: Users) -> Result<RoomMessageEventContent> {
match subcommand {
Users::Iter => {
let timer = tokio::time::Instant::now();
let results = services().users.db.iter();
let query_time = timer.elapsed();
let users = results.collect::<Vec<_>>();
Ok(RoomMessageEventContent::text_html(
format!("Query completed in {query_time:?}:\n\n```\n{:?}```", users),
format!(
"<p>Query completed in {query_time:?}:</p>\n<pre><code>{:?}\n</code></pre>",
users
),
))
},
}
}

View file

@ -1,160 +0,0 @@
use clap::Subcommand;
use ruma::{events::room::message::RoomMessageEventContent, RoomId, RoomOrAliasId};
use self::room_commands::list;
use crate::Result;
pub(crate) mod room_alias_commands;
pub(crate) mod room_commands;
pub(crate) mod room_directory_commands;
pub(crate) mod room_moderation_commands;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum RoomCommand {
/// - List all rooms the server knows about
List {
page: Option<usize>,
},
#[command(subcommand)]
/// - Manage moderation of remote or local rooms
Moderation(RoomModerationCommand),
#[command(subcommand)]
/// - Manage rooms' aliases
Alias(RoomAliasCommand),
#[command(subcommand)]
/// - Manage the room directory
Directory(RoomDirectoryCommand),
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum RoomAliasCommand {
/// - Make an alias point to a room.
Set {
#[arg(short, long)]
/// Set the alias even if a room is already using it
force: bool,
/// The room id to set the alias on
room_id: Box<RoomId>,
/// The alias localpart to use (`alias`, not `#alias:servername.tld`)
room_alias_localpart: String,
},
/// - Remove an alias
Remove {
/// The alias localpart to remove (`alias`, not `#alias:servername.tld`)
room_alias_localpart: String,
},
/// - Show which room is using an alias
Which {
/// The alias localpart to look up (`alias`, not
/// `#alias:servername.tld`)
room_alias_localpart: String,
},
/// - List aliases currently being used
List {
/// If set, only list the aliases for this room
room_id: Option<Box<RoomId>>,
},
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum RoomDirectoryCommand {
/// - Publish a room to the room directory
Publish {
/// The room id of the room to publish
room_id: Box<RoomId>,
},
/// - Unpublish a room to the room directory
Unpublish {
/// The room id of the room to unpublish
room_id: Box<RoomId>,
},
/// - List rooms that are published
List {
page: Option<usize>,
},
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum RoomModerationCommand {
/// - Bans a room from local users joining and evicts all our local users
/// from the room. Also blocks any invites (local and remote) for the
/// banned room.
///
/// Server admins (users in the conduwuit admin room) will not be evicted
/// and server admins can still join the room. To evict admins too, use
/// --force (also ignores errors) To disable incoming federation of the
/// room, use --disable-federation
BanRoom {
#[arg(short, long)]
/// Evicts admins out of the room and ignores any potential errors when
/// making our local users leave the room
force: bool,
#[arg(long)]
/// Disables incoming federation of the room after banning and evicting
/// users
disable_federation: bool,
/// The room in the format of `!roomid:example.com` or a room alias in
/// the format of `#roomalias:example.com`
room: Box<RoomOrAliasId>,
},
/// - Bans a list of rooms (room IDs and room aliases) from a newline
/// delimited codeblock similar to `user deactivate-all`
BanListOfRooms {
#[arg(short, long)]
/// Evicts admins out of the room and ignores any potential errors when
/// making our local users leave the room
force: bool,
#[arg(long)]
/// Disables incoming federation of the room after banning and evicting
/// users
disable_federation: bool,
},
/// - Unbans a room to allow local users to join again
///
/// To re-enable incoming federation of the room, use --enable-federation
UnbanRoom {
#[arg(long)]
/// Enables incoming federation of the room after unbanning
enable_federation: bool,
/// The room in the format of `!roomid:example.com` or a room alias in
/// the format of `#roomalias:example.com`
room: Box<RoomOrAliasId>,
},
/// - List of all rooms we have banned
ListBannedRooms,
}
pub(crate) async fn process(command: RoomCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
RoomCommand::Alias(command) => room_alias_commands::process(command, body).await?,
RoomCommand::Directory(command) => room_directory_commands::process(command, body).await?,
RoomCommand::Moderation(command) => room_moderation_commands::process(command, body).await?,
RoomCommand::List {
page,
} => list(body, page).await?,
})
}

View file

@ -1,136 +0,0 @@
use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId};
use super::RoomAliasCommand;
use crate::{service::admin::escape_html, services, Result};
pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
match command {
RoomAliasCommand::Set {
ref room_alias_localpart,
..
}
| RoomAliasCommand::Remove {
ref room_alias_localpart,
}
| RoomAliasCommand::Which {
ref room_alias_localpart,
} => {
let room_alias_str = format!("#{}:{}", room_alias_localpart, services().globals.server_name());
let room_alias = match RoomAliasId::parse_box(room_alias_str) {
Ok(alias) => alias,
Err(err) => return Ok(RoomMessageEventContent::text_plain(format!("Failed to parse alias: {}", err))),
};
match command {
RoomAliasCommand::Set {
force,
room_id,
..
} => match (force, services().rooms.alias.resolve_local_alias(&room_alias)) {
(true, Ok(Some(id))) => match services().rooms.alias.set_alias(&room_alias, &room_id) {
Ok(()) => Ok(RoomMessageEventContent::text_plain(format!(
"Successfully overwrote alias (formerly {})",
id
))),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err))),
},
(false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!(
"Refusing to overwrite in use alias for {}, use -f or --force to overwrite",
id
))),
(_, Ok(None)) => match services().rooms.alias.set_alias(&room_alias, &room_id) {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
},
(_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
},
RoomAliasCommand::Remove {
..
} => match services().rooms.alias.resolve_local_alias(&room_alias) {
Ok(Some(id)) => match services().rooms.alias.remove_alias(&room_alias) {
Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {}", id))),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err))),
},
Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {}", err))),
},
RoomAliasCommand::Which {
..
} => match services().rooms.alias.resolve_local_alias(&room_alias) {
Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {}", id))),
Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {}", err))),
},
RoomAliasCommand::List {
..
} => unreachable!(),
}
},
RoomAliasCommand::List {
room_id,
} => {
if let Some(room_id) = room_id {
let aliases = services()
.rooms
.alias
.local_aliases_for_room(&room_id)
.collect::<Result<Vec<_>, _>>();
match aliases {
Ok(aliases) => {
let plain_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "- {alias}").unwrap();
output
});
let html_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "<li>{}</li>", escape_html(alias.as_ref())).unwrap();
output
});
let plain = format!("Aliases for {room_id}:\n{plain_list}");
let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {}", err))),
}
} else {
let aliases = services()
.rooms
.alias
.all_local_aliases()
.collect::<Result<Vec<_>, _>>();
match aliases {
Ok(aliases) => {
let server_name = services().globals.server_name();
let plain_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(output, "- `{alias}` -> #{id}:{server_name}").unwrap();
output
});
let html_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(
output,
"<li><code>{}</code> -> #{}:{}</li>",
escape_html(alias.as_ref()),
escape_html(id.as_ref()),
server_name
)
.unwrap();
output
});
let plain = format!("Aliases:\n{plain_list}");
let html = format!("Aliases:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))),
}
}
},
}
}

View file

@ -1,59 +0,0 @@
use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId};
use crate::{
service::admin::{escape_html, get_room_info, PAGE_SIZE},
services, Result,
};
pub(crate) async fn list(_body: Vec<&str>, page: Option<usize>) -> Result<RoomMessageEventContent> {
// TODO: i know there's a way to do this with clap, but i can't seem to find it
let page = page.unwrap_or(1);
let mut rooms = services()
.rooms
.metadata
.iter_ids()
.filter_map(Result::ok)
.map(|id: OwnedRoomId| get_room_info(&id))
.collect::<Vec<_>>();
rooms.sort_by_key(|r| r.1);
rooms.reverse();
let rooms = rooms
.into_iter()
.skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE))
.take(PAGE_SIZE)
.collect::<Vec<_>>();
if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No more rooms."));
};
let output_plain = format!(
"Rooms:\n{}",
rooms
.iter()
.map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}"))
.collect::<Vec<_>>()
.join("\n")
);
let output_html = format!(
"<table><caption>Room list - page \
{page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!(
output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",
escape_html(id.as_ref()),
members,
escape_html(name)
)
.unwrap();
output
})
);
Ok(RoomMessageEventContent::text_html(output_plain, output_html))
}

View file

@ -1,78 +0,0 @@
use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId};
use super::RoomDirectoryCommand;
use crate::{
service::admin::{escape_html, get_room_info, PAGE_SIZE},
services, Result,
};
pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
match command {
RoomDirectoryCommand::Publish {
room_id,
} => match services().rooms.directory.set_public(&room_id) {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))),
},
RoomDirectoryCommand::Unpublish {
room_id,
} => match services().rooms.directory.set_not_public(&room_id) {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))),
},
RoomDirectoryCommand::List {
page,
} => {
// TODO: i know there's a way to do this with clap, but i can't seem to find it
let page = page.unwrap_or(1);
let mut rooms = services()
.rooms
.directory
.public_rooms()
.filter_map(Result::ok)
.map(|id: OwnedRoomId| get_room_info(&id))
.collect::<Vec<_>>();
rooms.sort_by_key(|r| r.1);
rooms.reverse();
let rooms = rooms
.into_iter()
.skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE))
.take(PAGE_SIZE)
.collect::<Vec<_>>();
if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No more rooms."));
};
let output_plain = format!(
"Rooms:\n{}",
rooms
.iter()
.map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}"))
.collect::<Vec<_>>()
.join("\n")
);
let output_html = format!(
"<table><caption>Room directory - page \
{page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!(
output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",
escape_html(id.as_ref()),
members,
escape_html(name.as_ref())
)
.unwrap();
output
})
);
Ok(RoomMessageEventContent::text_html(output_plain, output_html))
},
}
}

View file

@ -1,505 +0,0 @@
use std::fmt::Write as _;
use ruma::{
events::room::message::RoomMessageEventContent, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId,
};
use tracing::{debug, error, info, warn};
use super::RoomModerationCommand;
use crate::{
api::client_server::{get_alias_helper, leave_room},
service::admin::{escape_html, Service},
services,
utils::user_id::user_is_local,
Result,
};
pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
match command {
RoomModerationCommand::BanRoom {
force,
room,
disable_federation,
} => {
debug!("Got room alias or ID: {}", room);
let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
if let Some(admin_room_id) = Service::get_admin_room().await? {
if room.to_string().eq(&admin_room_id) || room.to_string().eq(&admin_room_alias) {
return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room."));
}
}
let room_id = if room.is_room_id() {
let room_id = match RoomId::parse(&room) {
Ok(room_id) => room_id,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to parse room ID {room}. Please note that this requires a full room ID \
(`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"
)))
},
};
debug!("Room specified is a room ID, banning room ID");
services().rooms.metadata.ban_room(&room_id, true)?;
room_id
} else if room.is_room_alias_id() {
let room_alias = match RoomAliasId::parse(&room) {
Ok(room_alias) => room_alias,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to parse room ID {room}. Please note that this requires a full room ID \
(`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"
)))
},
};
debug!(
"Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not \
using get_alias_helper to fetch room ID remotely"
);
let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? {
room_id
} else {
debug!(
"We don't have this room alias to a room ID locally, attempting to fetch room ID over \
federation"
);
match get_alias_helper(room_alias, None).await {
Ok(response) => {
debug!("Got federation response fetching room ID for room {room}: {:?}", response);
response.room_id
},
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to resolve room alias {room} to a room ID: {e}"
)));
},
}
};
services().rooms.metadata.ban_room(&room_id, true)?;
room_id
} else {
return Ok(RoomMessageEventContent::text_plain(
"Room specified is not a room ID or room alias. Please note that this requires a full room ID \
(`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)",
));
};
debug!("Making all users leave the room {}", &room);
if force {
for local_user in services()
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
user_is_local(local_user)
// additional wrapped check here is to avoid adding remote users
// who are in the admin room to the list of local users (would fail auth check)
&& (user_is_local(local_user)
&& services()
.users
.is_admin(local_user)
.unwrap_or(true)) // since this is a force
// operation, assume user
// is an admin if somehow
// this fails
})
})
.collect::<Vec<OwnedUserId>>()
{
debug!(
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)",
&local_user, &room_id
);
_ = leave_room(&local_user, &room_id, None).await;
}
} else {
for local_user in services()
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
local_user.server_name() == services().globals.server_name()
// additional wrapped check here is to avoid adding remote users
// who are in the admin room to the list of local users (would fail auth check)
&& (local_user.server_name()
== services().globals.server_name()
&& !services()
.users
.is_admin(local_user)
.unwrap_or(false))
})
})
.collect::<Vec<OwnedUserId>>()
{
debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(&local_user, &room_id, None).await {
error!(
"Error attempting to make local user {} leave room {} during room banning: {}",
&local_user, &room_id, e
);
return Ok(RoomMessageEventContent::text_plain(format!(
"Error attempting to make local user {} leave room {} during room banning (room is still \
banned but not removing any more users): {}\nIf you would like to ignore errors, use \
--force",
&local_user, &room_id, e
)));
}
}
}
if disable_federation {
services().rooms.metadata.disable_room(&room_id, true)?;
return Ok(RoomMessageEventContent::text_plain(
"Room banned, removed all our local users, and disabled incoming federation with room.",
));
}
Ok(RoomMessageEventContent::text_plain(
"Room banned and removed all our local users, use disable-room to stop receiving new inbound \
federation events as well if needed.",
))
},
RoomModerationCommand::BanListOfRooms {
force,
disable_federation,
} => {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let rooms_s = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>();
let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
let mut room_ban_count = 0;
let mut room_ids: Vec<OwnedRoomId> = Vec::new();
for &room in &rooms_s {
match <&RoomOrAliasId>::try_from(room) {
Ok(room_alias_or_id) => {
if let Some(admin_room_id) = Service::get_admin_room().await? {
if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(&admin_room_alias) {
info!("User specified admin room in bulk ban list, ignoring");
continue;
}
}
if room_alias_or_id.is_room_id() {
let room_id = match RoomId::parse(room_alias_or_id) {
Ok(room_id) => room_id,
Err(e) => {
if force {
// ignore rooms we failed to parse if we're force banning
warn!(
"Error parsing room \"{room}\" during bulk room banning, ignoring \
error and logging here: {e}"
);
continue;
}
return Ok(RoomMessageEventContent::text_plain(format!(
"{room} is not a valid room ID or room alias, please fix the list and try \
again: {e}"
)));
},
};
room_ids.push(room_id);
}
if room_alias_or_id.is_room_alias_id() {
match RoomAliasId::parse(room_alias_or_id) {
Ok(room_alias) => {
let room_id = if let Some(room_id) =
services().rooms.alias.resolve_local_alias(&room_alias)?
{
room_id
} else {
debug!(
"We don't have this room alias to a room ID locally, attempting to \
fetch room ID over federation"
);
match get_alias_helper(room_alias, None).await {
Ok(response) => {
debug!(
"Got federation response fetching room ID for room {room}: \
{:?}",
response
);
response.room_id
},
Err(e) => {
// don't fail if force blocking
if force {
warn!("Failed to resolve room alias {room} to a room ID: {e}");
continue;
}
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to resolve room alias {room} to a room ID: {e}"
)));
},
}
};
room_ids.push(room_id);
},
Err(e) => {
if force {
// ignore rooms we failed to parse if we're force deleting
error!(
"Error parsing room \"{room}\" during bulk room banning, ignoring \
error and logging here: {e}"
);
continue;
}
return Ok(RoomMessageEventContent::text_plain(format!(
"{room} is not a valid room ID or room alias, please fix the list and try \
again: {e}"
)));
},
}
}
},
Err(e) => {
if force {
// ignore rooms we failed to parse if we're force deleting
error!(
"Error parsing room \"{room}\" during bulk room banning, ignoring error and \
logging here: {e}"
);
continue;
}
return Ok(RoomMessageEventContent::text_plain(format!(
"{room} is not a valid room ID or room alias, please fix the list and try again: {e}"
)));
},
}
}
for room_id in room_ids {
if services().rooms.metadata.ban_room(&room_id, true).is_ok() {
debug!("Banned {room_id} successfully");
room_ban_count += 1;
}
debug!("Making all users leave the room {}", &room_id);
if force {
for local_user in services()
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
local_user.server_name() == services().globals.server_name()
// additional wrapped check here is to avoid adding remote users
// who are in the admin room to the list of local users (would fail auth check)
&& (local_user.server_name()
== services().globals.server_name()
&& services()
.users
.is_admin(local_user)
.unwrap_or(true)) // since this is a
// force operation,
// assume user is
// an admin if
// somehow this
// fails
})
})
.collect::<Vec<OwnedUserId>>()
{
debug!(
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting \
admins too)",
&local_user, room_id
);
_ = leave_room(&local_user, &room_id, None).await;
}
} else {
for local_user in services()
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
local_user.server_name() == services().globals.server_name()
// additional wrapped check here is to avoid adding remote users
// who are in the admin room to the list of local users (would fail auth check)
&& (local_user.server_name()
== services().globals.server_name()
&& !services()
.users
.is_admin(local_user)
.unwrap_or(false))
})
})
.collect::<Vec<OwnedUserId>>()
{
debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(&local_user, &room_id, None).await {
error!(
"Error attempting to make local user {} leave room {} during bulk room banning: {}",
&local_user, &room_id, e
);
return Ok(RoomMessageEventContent::text_plain(format!(
"Error attempting to make local user {} leave room {} during room banning (room \
is still banned but not removing any more users and not banning any more rooms): \
{}\nIf you would like to ignore errors, use --force",
&local_user, &room_id, e
)));
}
}
}
if disable_federation {
services().rooms.metadata.disable_room(&room_id, true)?;
}
}
if disable_federation {
return Ok(RoomMessageEventContent::text_plain(format!(
"Finished bulk room ban, banned {room_ban_count} total rooms, evicted all users, and disabled \
incoming federation with the room."
)));
}
return Ok(RoomMessageEventContent::text_plain(format!(
"Finished bulk room ban, banned {room_ban_count} total rooms and evicted all users."
)));
}
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
},
RoomModerationCommand::UnbanRoom {
room,
enable_federation,
} => {
let room_id = if room.is_room_id() {
let room_id = match RoomId::parse(&room) {
Ok(room_id) => room_id,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to parse room ID {room}. Please note that this requires a full room ID \
(`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"
)))
},
};
debug!("Room specified is a room ID, unbanning room ID");
services().rooms.metadata.ban_room(&room_id, false)?;
room_id
} else if room.is_room_alias_id() {
let room_alias = match RoomAliasId::parse(&room) {
Ok(room_alias) => room_alias,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to parse room ID {room}. Please note that this requires a full room ID \
(`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"
)))
},
};
debug!(
"Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not \
using get_alias_helper to fetch room ID remotely"
);
let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? {
room_id
} else {
debug!(
"We don't have this room alias to a room ID locally, attempting to fetch room ID over \
federation"
);
match get_alias_helper(room_alias, None).await {
Ok(response) => {
debug!("Got federation response fetching room ID for room {room}: {:?}", response);
response.room_id
},
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to resolve room alias {room} to a room ID: {e}"
)));
},
}
};
services().rooms.metadata.ban_room(&room_id, false)?;
room_id
} else {
return Ok(RoomMessageEventContent::text_plain(
"Room specified is not a room ID or room alias. Please note that this requires a full room ID \
(`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)",
));
};
if enable_federation {
services().rooms.metadata.disable_room(&room_id, false)?;
return Ok(RoomMessageEventContent::text_plain("Room unbanned."));
}
Ok(RoomMessageEventContent::text_plain(
"Room unbanned, you may need to re-enable federation with the room using enable-room if this is a \
remote room to make it fully functional.",
))
},
RoomModerationCommand::ListBannedRooms => {
let rooms = services()
.rooms
.metadata
.list_banned_rooms()
.collect::<Result<Vec<_>, _>>();
match rooms {
Ok(room_ids) => {
// TODO: add room name from our state cache if available, default to the room ID
// as the room name if we dont have it TODO: do same if we have a room alias for
// this
let plain_list = room_ids.iter().fold(String::new(), |mut output, room_id| {
writeln!(output, "- `{}`", room_id).unwrap();
output
});
let html_list = room_ids.iter().fold(String::new(), |mut output, room_id| {
writeln!(output, "<li><code>{}</code></li>", escape_html(room_id.as_ref())).unwrap();
output
});
let plain = format!("Rooms:\n{}", plain_list);
let html = format!("Rooms:\n<ul>{}</ul>", html_list);
Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(e) => {
error!("Failed to list banned rooms: {}", e);
Ok(RoomMessageEventContent::text_plain(format!(
"Unable to list room aliases: {}",
e
)))
},
}
},
}
}

View file

@ -1,62 +0,0 @@
pub(crate) mod server_commands;
use clap::Subcommand;
use ruma::events::room::message::RoomMessageEventContent;
use self::server_commands::{
backup_database, clear_database_caches, clear_service_caches, list_backups, list_database_files, memory_usage,
show_config, uptime,
};
use crate::Result;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum ServerCommand {
/// - Time elapsed since startup
Uptime,
/// - Show configuration values
ShowConfig,
/// - Print database memory usage statistics
MemoryUsage,
/// - Clears all of Conduit's database caches with index smaller than the
/// amount
ClearDatabaseCaches {
amount: u32,
},
/// - Clears all of Conduit's service caches with index smaller than the
/// amount
ClearServiceCaches {
amount: u32,
},
/// - Performs an online backup of the database (only available for RocksDB
/// at the moment)
BackupDatabase,
/// - List database backups
ListBackups,
/// - List database files
ListDatabaseFiles,
}
pub(crate) async fn process(command: ServerCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
ServerCommand::Uptime => uptime(body).await?,
ServerCommand::ShowConfig => show_config(body).await?,
ServerCommand::MemoryUsage => memory_usage(body).await?,
ServerCommand::ClearDatabaseCaches {
amount,
} => clear_database_caches(body, amount).await?,
ServerCommand::ClearServiceCaches {
amount,
} => clear_service_caches(body, amount).await?,
ServerCommand::ListBackups => list_backups(body).await?,
ServerCommand::BackupDatabase => backup_database(body).await?,
ServerCommand::ListDatabaseFiles => list_database_files(body).await?,
})
}

View file

@ -1,95 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use crate::{services, Result};
pub(crate) async fn uptime(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let seconds = services()
.globals
.started
.elapsed()
.expect("standard duration")
.as_secs();
let result = format!(
"up {} days, {} hours, {} minutes, {} seconds.",
seconds / 86400,
(seconds % 86400) / 60 / 60,
(seconds % 3600) / 60,
seconds % 60,
);
Ok(RoomMessageEventContent::notice_html(String::new(), result))
}
pub(crate) async fn show_config(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
// Construct and send the response
Ok(RoomMessageEventContent::text_plain(format!("{}", services().globals.config)))
}
pub(crate) async fn memory_usage(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let response0 = services().memory_usage().await;
let response1 = services().globals.db.memory_usage();
let response2 = crate::alloc::memory_usage();
Ok(RoomMessageEventContent::text_plain(format!(
"Services:\n{response0}\n\nDatabase:\n{response1}\n{}",
if !response2.is_empty() {
format!("Allocator:\n {response2}")
} else {
String::new()
}
)))
}
pub(crate) async fn clear_database_caches(_body: Vec<&str>, amount: u32) -> Result<RoomMessageEventContent> {
services().globals.db.clear_caches(amount);
Ok(RoomMessageEventContent::text_plain("Done."))
}
pub(crate) async fn clear_service_caches(_body: Vec<&str>, amount: u32) -> Result<RoomMessageEventContent> {
services().clear_caches(amount).await;
Ok(RoomMessageEventContent::text_plain("Done."))
}
pub(crate) async fn list_backups(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let result = services().globals.db.backup_list()?;
if result.is_empty() {
Ok(RoomMessageEventContent::text_plain("No backups found."))
} else {
Ok(RoomMessageEventContent::text_plain(result))
}
}
pub(crate) async fn backup_database(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
if !cfg!(feature = "rocksdb") {
return Ok(RoomMessageEventContent::text_plain(
"Only RocksDB supports online backups in conduwuit.",
));
}
let mut result = tokio::task::spawn_blocking(move || match services().globals.db.backup() {
Ok(()) => String::new(),
Err(e) => (*e).to_string(),
})
.await
.unwrap();
if result.is_empty() {
result = services().globals.db.backup_list()?;
}
Ok(RoomMessageEventContent::text_plain(&result))
}
pub(crate) async fn list_database_files(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
if !cfg!(feature = "rocksdb") {
return Ok(RoomMessageEventContent::text_plain(
"Only RocksDB supports listing files in conduwuit.",
));
}
let result = services().globals.db.file_list()?;
Ok(RoomMessageEventContent::notice_html(String::new(), result))
}

View file

@ -1,41 +0,0 @@
//! test commands generally used for hot lib reloadable functions.
//! see <https://github.com/rksm/hot-lib-reloader-rs?tab=readme-ov-file#usage> for more details if you are a dev
//#[cfg(not(feature = "hot_reload"))]
//#[allow(unused_imports)]
//#[allow(clippy::wildcard_imports)]
// non hot reloadable functions (?)
//use hot_lib::*;
#[cfg(feature = "hot_reload")]
#[allow(unused_imports)]
#[allow(clippy::wildcard_imports)]
use hot_lib_funcs::*;
use ruma::events::room::message::RoomMessageEventContent;
use crate::{debug_error, Result};
#[cfg(feature = "hot_reload")]
#[hot_lib_reloader::hot_module(dylib = "lib")]
mod hot_lib_funcs {
// these will be functions from lib.rs, so `use hot_lib_funcs::test_command;`
hot_functions_from_file!("hot_lib/src/lib.rs");
}
#[cfg_attr(test, derive(Debug))]
#[derive(clap::Subcommand)]
pub(crate) enum TestCommands {
// !admin test test1
Test1,
}
pub(crate) async fn process(command: TestCommands, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
TestCommands::Test1 => {
debug_error!("before calling test_command");
test_command();
debug_error!("after calling test_command");
RoomMessageEventContent::notice_plain(String::from("loaded"))
},
})
}

View file

@ -1,14 +0,0 @@
use ruma::events::room::message::RoomMessageEventContent;
use crate::Result;
#[cfg_attr(test, derive(Debug))]
#[derive(clap::Subcommand)]
pub(crate) enum TesterCommands {
Tester,
}
pub(crate) async fn process(command: TesterCommands, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
TesterCommands::Tester => RoomMessageEventContent::notice_plain(String::from("complete")),
})
}

View file

@ -1,89 +0,0 @@
pub(crate) mod user_commands;
use clap::Subcommand;
use ruma::events::room::message::RoomMessageEventContent;
use self::user_commands::{create, deactivate, deactivate_all, list, list_joined_rooms, reset_password};
use crate::Result;
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
pub(crate) enum UserCommand {
/// - Create a new user
Create {
/// Username of the new user
username: String,
/// Password of the new user, if unspecified one is generated
password: Option<String>,
},
/// - Reset user password
ResetPassword {
/// Username of the user for whom the password should be reset
username: String,
},
/// - Deactivate a user
///
/// User will not be removed from all rooms by default.
/// Use --leave-rooms to force the user to leave all rooms
Deactivate {
#[arg(short, long)]
leave_rooms: bool,
user_id: String,
},
/// - Deactivate a list of users
///
/// Recommended to use in conjunction with list-local-users.
///
/// Users will not be removed from joined rooms by default.
/// Can be overridden with --leave-rooms flag.
/// Removing a mass amount of users from a room may cause a significant
/// amount of leave events. The time to leave rooms may depend significantly
/// on joined rooms and servers.
///
/// This command needs a newline separated list of users provided in a
/// Markdown code block below the command.
DeactivateAll {
#[arg(short, long)]
/// Remove users from their joined rooms
leave_rooms: bool,
#[arg(short, long)]
/// Also deactivate admin accounts
force: bool,
},
/// - List local users in the database
List,
/// - Lists all the rooms (local and remote) that the specified user is
/// joined in
ListJoinedRooms {
user_id: String,
},
}
pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
UserCommand::List => list(body).await?,
UserCommand::Create {
username,
password,
} => create(body, username, password).await?,
UserCommand::Deactivate {
leave_rooms,
user_id,
} => deactivate(body, leave_rooms, user_id).await?,
UserCommand::ResetPassword {
username,
} => reset_password(body, username).await?,
UserCommand::DeactivateAll {
leave_rooms,
force,
} => deactivate_all(body, leave_rooms, force).await?,
UserCommand::ListJoinedRooms {
user_id,
} => list_joined_rooms(body, user_id).await?,
})
}

View file

@ -1,370 +0,0 @@
use std::{fmt::Write as _, sync::Arc};
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, UserId};
use tracing::{error, info, warn};
use crate::{
api::client_server::{join_room_by_id_helper, leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
service::admin::{escape_html, get_room_info},
services,
utils::{self, user_id::user_is_local},
Result,
};
pub(crate) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
match services().users.list_local_users() {
Ok(users) => {
let mut msg = format!("Found {} local user account(s):\n", users.len());
msg += &users.join("\n");
Ok(RoomMessageEventContent::text_plain(&msg))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())),
}
}
pub(crate) async fn create(
_body: Vec<&str>, username: String, password: Option<String>,
) -> Result<RoomMessageEventContent> {
let password = password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH));
// Validate user id
let user_id =
match UserId::parse_with_server_name(username.as_str().to_lowercase(), services().globals.server_name()) {
Ok(id) => id,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"The supplied username is not a valid username: {e}"
)))
},
};
if !user_is_local(&user_id) {
return Ok(RoomMessageEventContent::text_plain(format!(
"User {user_id} does not belong to our server."
)));
}
if user_id.is_historical() {
return Ok(RoomMessageEventContent::text_plain(format!(
"Userid {user_id} is not allowed due to historical"
)));
}
if services().users.exists(&user_id)? {
return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists")));
}
// Create user
services().users.create(&user_id, Some(password.as_str()))?;
// Default to pretty displayname
let mut displayname = user_id.localpart().to_owned();
// If `new_user_displayname_suffix` is set, registration will push whatever
// content is set to the user's display name with a space before it
if !services()
.globals
.config
.new_user_displayname_suffix
.is_empty()
{
_ = write!(displayname, " {}", services().globals.config.new_user_displayname_suffix);
}
services()
.users
.set_displayname(&user_id, Some(displayname))
.await?;
// Initial account data
services().account_data.update(
None,
&user_id,
ruma::events::GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id),
},
})
.expect("to json value always works"),
)?;
if !services().globals.config.auto_join_rooms.is_empty() {
for room in &services().globals.config.auto_join_rooms {
if !services()
.rooms
.state_cache
.server_in_room(services().globals.server_name(), room)?
{
warn!("Skipping room {room} to automatically join as we have never joined before.");
continue;
}
if let Some(room_id_server_name) = room.server_name() {
match join_room_by_id_helper(
Some(&user_id),
room,
Some("Automatically joining this room upon registration".to_owned()),
&[room_id_server_name.to_owned(), services().globals.server_name().to_owned()],
None,
)
.await
{
Ok(_) => {
info!("Automatically joined room {room} for user {user_id}");
},
Err(e) => {
// don't return this error so we don't fail registrations
error!("Failed to automatically join room {room} for user {user_id}: {e}");
},
};
}
}
}
// we dont add a device since we're not the user, just the creator
// Inhibit login does not work for guests
Ok(RoomMessageEventContent::text_plain(format!(
"Created user with user_id: {user_id} and password: `{password}`"
)))
}
pub(crate) async fn deactivate(
_body: Vec<&str>, leave_rooms: bool, user_id: String,
) -> Result<RoomMessageEventContent> {
// Validate user id
let user_id =
match UserId::parse_with_server_name(user_id.as_str().to_lowercase(), services().globals.server_name()) {
Ok(id) => Arc::<UserId>::from(id),
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"The supplied username is not a valid username: {e}"
)))
},
};
// check if user belongs to our server
if user_id.server_name() != services().globals.server_name() {
return Ok(RoomMessageEventContent::text_plain(format!(
"User {user_id} does not belong to our server."
)));
}
// don't deactivate the conduit service account
if user_id
== UserId::parse_with_server_name("conduit", services().globals.server_name()).expect("conduit user exists")
{
return Ok(RoomMessageEventContent::text_plain(
"Not allowed to deactivate the Conduit service account.",
));
}
if services().users.exists(&user_id)? {
RoomMessageEventContent::text_plain(format!("Making {user_id} leave all rooms before deactivation..."));
services().users.deactivate_account(&user_id)?;
if leave_rooms {
leave_all_rooms(&user_id).await;
}
Ok(RoomMessageEventContent::text_plain(format!(
"User {user_id} has been deactivated"
)))
} else {
Ok(RoomMessageEventContent::text_plain(format!(
"User {user_id} doesn't exist on this server"
)))
}
}
pub(crate) async fn reset_password(_body: Vec<&str>, username: String) -> Result<RoomMessageEventContent> {
// Validate user id
let user_id =
match UserId::parse_with_server_name(username.as_str().to_lowercase(), services().globals.server_name()) {
Ok(id) => Arc::<UserId>::from(id),
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"The supplied username is not a valid username: {e}"
)))
},
};
// check if user belongs to our server
if user_id.server_name() != services().globals.server_name() {
return Ok(RoomMessageEventContent::text_plain(format!(
"User {user_id} does not belong to our server."
)));
}
// Check if the specified user is valid
if !services().users.exists(&user_id)?
|| user_id
== UserId::parse_with_server_name("conduit", services().globals.server_name()).expect("conduit user exists")
{
return Ok(RoomMessageEventContent::text_plain("The specified user does not exist!"));
}
let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH);
match services()
.users
.set_password(&user_id, Some(new_password.as_str()))
{
Ok(()) => Ok(RoomMessageEventContent::text_plain(format!(
"Successfully reset the password for user {user_id}: `{new_password}`"
))),
Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Couldn't reset the password for user {user_id}: {e}"
))),
}
}
pub(crate) async fn deactivate_all(body: Vec<&str>, leave_rooms: bool, force: bool) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {
let usernames = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>();
let mut user_ids: Vec<&UserId> = Vec::new();
for &username in &usernames {
match <&UserId>::try_from(username) {
Ok(user_id) => user_ids.push(user_id),
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"{username} is not a valid username: {e}"
)))
},
}
}
let mut deactivation_count: u16 = 0;
let mut admins = Vec::new();
if !force {
user_ids.retain(|&user_id| match services().users.is_admin(user_id) {
Ok(is_admin) => {
if is_admin {
admins.push(user_id.localpart());
false
} else {
true
}
},
Err(_) => false,
});
}
for &user_id in &user_ids {
// check if user belongs to our server and skips over non-local users
if user_id.server_name() != services().globals.server_name() {
continue;
}
// don't deactivate the conduit service account
if user_id
== UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("conduit user exists")
{
continue;
}
// user does not exist on our server
if !services().users.exists(user_id)? {
continue;
}
if services().users.deactivate_account(user_id).is_ok() {
deactivation_count = deactivation_count.saturating_add(1);
}
}
if leave_rooms {
for &user_id in &user_ids {
leave_all_rooms(user_id).await;
}
}
if admins.is_empty() {
Ok(RoomMessageEventContent::text_plain(format!(
"Deactivated {deactivation_count} accounts."
)))
} else {
Ok(RoomMessageEventContent::text_plain(format!(
"Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts",
deactivation_count,
admins.join(", ")
)))
}
} else {
Ok(RoomMessageEventContent::text_plain(
"Expected code block in command body. Add --help for details.",
))
}
}
pub(crate) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Result<RoomMessageEventContent> {
// Validate user id
let user_id =
match UserId::parse_with_server_name(user_id.as_str().to_lowercase(), services().globals.server_name()) {
Ok(id) => Arc::<UserId>::from(id),
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"The supplied username is not a valid username: {e}"
)))
},
};
if !user_is_local(&user_id) {
return Ok(RoomMessageEventContent::text_plain("User does not belong to our server."));
}
if !services().users.exists(&user_id)? {
return Ok(RoomMessageEventContent::text_plain("User does not exist on this server."));
}
let mut rooms: Vec<(OwnedRoomId, u64, String)> = services()
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.map(|room_id| get_room_info(&room_id))
.collect();
if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("User is not in any rooms."));
}
rooms.sort_by_key(|r| r.1);
rooms.reverse();
let output_plain = format!(
"Rooms {user_id} Joined ({}):\n{}",
rooms.len(),
rooms
.iter()
.map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}"))
.collect::<Vec<_>>()
.join("\n")
);
let output_html = format!(
"<table><caption>Rooms {user_id} Joined \
({})</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms.len(),
rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!(
output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",
escape_html(id.as_ref()),
members,
escape_html(name)
)
.unwrap();
output
})
);
Ok(RoomMessageEventContent::text_html(output_plain, output_html))
}

View file

@ -2,7 +2,7 @@ use ruma::api::appservice::Registration;
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
/// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String>;

View file

@ -1,8 +1,8 @@
mod data;
use std::collections::BTreeMap;
use std::{collections::BTreeMap, sync::Arc};
pub(crate) use data::Data;
pub use data::Data;
use futures_util::Future;
use regex::RegexSet;
use ruma::{
@ -15,14 +15,15 @@ use crate::{services, Result};
/// Compiled regular expressions for a namespace
#[derive(Clone, Debug)]
pub(crate) struct NamespaceRegex {
pub(crate) exclusive: Option<RegexSet>,
pub(crate) non_exclusive: Option<RegexSet>,
pub struct NamespaceRegex {
pub exclusive: Option<RegexSet>,
pub non_exclusive: Option<RegexSet>,
}
impl NamespaceRegex {
/// Checks if this namespace has rights to a namespace
pub(crate) fn is_match(&self, heystack: &str) -> bool {
#[must_use]
pub fn is_match(&self, heystack: &str) -> bool {
if self.is_exclusive_match(heystack) {
return true;
}
@ -36,7 +37,8 @@ impl NamespaceRegex {
}
/// Checks if this namespace has exlusive rights to a namespace
pub(crate) fn is_exclusive_match(&self, heystack: &str) -> bool {
#[must_use]
pub fn is_exclusive_match(&self, heystack: &str) -> bool {
if let Some(exclusive) = &self.exclusive {
if exclusive.is_match(heystack) {
return true;
@ -47,11 +49,13 @@ impl NamespaceRegex {
}
impl RegistrationInfo {
pub(crate) fn is_user_match(&self, user_id: &UserId) -> bool {
#[must_use]
pub fn is_user_match(&self, user_id: &UserId) -> bool {
self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart()
}
pub(crate) fn is_exclusive_user_match(&self, user_id: &UserId) -> bool {
#[must_use]
pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool {
self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart()
}
}
@ -88,11 +92,11 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
/// Appservice registration combined with its compiled regular expressions.
#[derive(Clone, Debug)]
pub(crate) struct RegistrationInfo {
pub(crate) registration: Registration,
pub(crate) users: NamespaceRegex,
pub(crate) aliases: NamespaceRegex,
pub(crate) rooms: NamespaceRegex,
pub struct RegistrationInfo {
pub registration: Registration,
pub users: NamespaceRegex,
pub aliases: NamespaceRegex,
pub rooms: NamespaceRegex,
}
impl TryFrom<Registration> for RegistrationInfo {
@ -108,13 +112,13 @@ impl TryFrom<Registration> for RegistrationInfo {
}
}
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
}
impl Service {
pub(crate) fn build(db: &'static dyn Data) -> Result<Self> {
pub fn build(db: Arc<dyn Data>) -> Result<Self> {
let mut registration_info = BTreeMap::new();
// Inserting registrations into cache
for appservice in db.all()? {
@ -134,7 +138,7 @@ impl Service {
}
/// Registers an appservice and returns the ID to the caller
pub(crate) async fn register_appservice(&self, yaml: Registration) -> Result<String> {
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
//TODO: Check for collisions between exclusive appservice namespaces
services()
.appservice
@ -151,7 +155,7 @@ impl Service {
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
pub(crate) async fn unregister_appservice(&self, service_name: &str) -> Result<()> {
pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> {
// removes the appservice registration info
services()
.appservice
@ -171,7 +175,7 @@ impl Service {
Ok(())
}
pub(crate) async fn get_registration(&self, id: &str) -> Option<Registration> {
pub async fn get_registration(&self, id: &str) -> Option<Registration> {
self.registration_info
.read()
.await
@ -180,7 +184,7 @@ impl Service {
.map(|info| info.registration)
}
pub(crate) async fn iter_ids(&self) -> Vec<String> {
pub async fn iter_ids(&self) -> Vec<String> {
self.registration_info
.read()
.await
@ -189,7 +193,7 @@ impl Service {
.collect()
}
pub(crate) async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
pub async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
self.read()
.await
.values()
@ -198,7 +202,7 @@ impl Service {
}
/// Checks if a given user id matches any exclusive appservice regex
pub(crate) async fn is_exclusive_user_id(&self, user_id: &UserId) -> bool {
pub async fn is_exclusive_user_id(&self, user_id: &UserId) -> bool {
self.read()
.await
.values()
@ -206,7 +210,7 @@ impl Service {
}
/// Checks if a given room alias matches any exclusive appservice regex
pub(crate) async fn is_exclusive_alias(&self, alias: &RoomAliasId) -> bool {
pub async fn is_exclusive_alias(&self, alias: &RoomAliasId) -> bool {
self.read()
.await
.values()
@ -217,16 +221,14 @@ impl Service {
///
/// TODO: use this?
#[allow(dead_code)]
pub(crate) async fn is_exclusive_room_id(&self, room_id: &RoomId) -> bool {
pub async fn is_exclusive_room_id(&self, room_id: &RoomId) -> bool {
self.read()
.await
.values()
.any(|info| info.rooms.is_exclusive_match(room_id.as_str()))
}
pub(crate) fn read(
&self,
) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
self.registration_info.read()
}
}

View file

@ -4,18 +4,18 @@ use reqwest::redirect;
use crate::{service::globals::resolver, utils::conduwuit_version, Config, Result};
pub(crate) struct Client {
pub(crate) default: reqwest::Client,
pub(crate) url_preview: reqwest::Client,
pub(crate) well_known: reqwest::Client,
pub(crate) federation: reqwest::Client,
pub(crate) sender: reqwest::Client,
pub(crate) appservice: reqwest::Client,
pub(crate) pusher: reqwest::Client,
pub struct Client {
pub default: reqwest::Client,
pub url_preview: reqwest::Client,
pub well_known: reqwest::Client,
pub federation: reqwest::Client,
pub sender: reqwest::Client,
pub appservice: reqwest::Client,
pub pusher: reqwest::Client,
}
impl Client {
pub(crate) fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Client {
pub fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Client {
Client {
default: Self::base(config)
.unwrap()

View file

@ -10,7 +10,7 @@ use ruma::{
use crate::{database::Cork, Result};
#[async_trait]
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn next_count(&self) -> Result<u64>;
fn current_count(&self) -> Result<u64>;
fn last_check_for_updates_id(&self) -> Result<u64>;

View file

@ -0,0 +1,66 @@
use conduit::Result;
use ruma::{
events::{
push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, GlobalAccountDataEvent,
GlobalAccountDataEventType,
},
push::Ruleset,
UserId,
};
use tracing::{error, warn};
use crate::services;
pub(crate) async fn init_emergency_access() {
// Set emergency access for the conduit user
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!(
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!"
);
services()
.admin
.send_message(RoomMessageEventContent::text_plain(
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!",
))
.await;
}
},
Err(e) => {
error!("Could not set the configured emergency password for the conduit user: {}", e);
},
};
}
/// Sets the emergency password and push rules for the @conduit account in case
/// emergency password is set
fn set_emergency_access() -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is a valid UserId");
services()
.users
.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
&conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
res
}

View file

@ -0,0 +1,638 @@
use std::{
collections::{HashMap, HashSet},
fs::{self},
io::Write,
mem::size_of,
sync::Arc,
};
use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier};
use database::KeyValueDatabase;
use itertools::Itertools;
use rand::thread_rng;
use ruma::{
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
push::Ruleset,
EventId, OwnedRoomId, RoomId, UserId,
};
use tracing::{debug, error, info, warn};
use crate::{services, utils, Config, Error, Result};
pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result<()> {
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if services().users.count()? > 0 {
let conduit_user =
UserId::parse_with_server_name("conduit", &config.server_name).expect("@conduit:server_name is valid");
if !services().users.exists(&conduit_user)? {
error!("The {} server user does not exist, and the database is not new.", conduit_user);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first.",
));
}
}
// If the database has any data, perform data migrations before starting
// do not increment the db version if the user is not using sha256_media
let latest_database_version = if cfg!(feature = "sha256_media") {
14
} else {
13
};
if services().users.count()? > 0 {
// MIGRATIONS
if services().globals.database_version()? < 1 {
for (roomserverid, _) in db.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id = parts.next().expect("split always returns one element");
let Some(servername) = parts.next() else {
error!("Migration: Invalid roomserverid in db.");
continue;
};
let mut serverroomid = servername.to_vec();
serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id);
db.serverroomids.insert(&serverroomid, &[])?;
}
services().globals.bump_database_version(1)?;
warn!("Migration: 0 -> 1 finished");
}
if services().globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just ""
for (userid, password) in db.userid_password.iter() {
let salt = SaltString::generate(thread_rng());
let empty_pass = services()
.globals
.argon
.hash_password(b"", &salt)
.expect("our own password to be properly hashed");
let empty_hashed_password = services()
.globals
.argon
.verify_password(&password, &empty_pass)
.is_ok();
if empty_hashed_password {
db.userid_password.insert(&userid, b"")?;
}
}
services().globals.bump_database_version(2)?;
warn!("Migration: 1 -> 2 finished");
}
if services().globals.database_version()? < 3 {
// Move media to filesystem
for (key, content) in db.mediaid_file.iter() {
if content.is_empty() {
continue;
}
#[allow(deprecated)]
let path = services().globals.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
db.mediaid_file.insert(&key, &[])?;
}
services().globals.bump_database_version(3)?;
warn!("Migration: 2 -> 3 finished");
}
if services().globals.database_version()? < 4 {
// Add federated users to services() as deactivated
for our_user in services().users.iter() {
let our_user = our_user?;
if services().users.is_deactivated(&our_user)? {
continue;
}
for room in services().rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) {
let user = user?;
if user.server_name() != config.server_name {
info!(?user, "Migration: creating user");
services().users.create(&user, None)?;
}
}
}
}
services().globals.bump_database_version(4)?;
warn!("Migration: 3 -> 4 finished");
}
if services().globals.database_version()? < 5 {
// Upgrade user data store
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap();
let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
let mut key = room_id.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id);
key.push(0xFF);
key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
}
services().globals.bump_database_version(5)?;
warn!("Migration: 4 -> 5 finished");
}
if services().globals.database_version()? < 6 {
// Set room member count
for (roomid, _) in db.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?;
}
services().globals.bump_database_version(6)?;
warn!("Migration: 5 -> 6 finished");
}
if services().globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let handle_state = |current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
services()
.rooms
.state_compressor
.load_shortstatehash_info(last_roomsstatehash)
},
)?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.copied()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.copied()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
services().rooms.state_compressor.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
}
current_state = HashSet::new();
current_sstatehash = Some(sstatehash);
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services()
.rooms
.timeline
.get_pdu(event_id)
.unwrap()
.unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
}
}
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
}
services().globals.bump_database_version(7)?;
warn!("Migration: 6 -> 7 finished");
}
if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
// Update pduids db layout
let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(count);
Some((new_key, v))
});
db.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
db.eventid_pduid.insert_batch(&mut batch2)?;
services().globals.bump_database_version(8)?;
warn!("Migration: 7 -> 8 finished");
}
if services().globals.database_version()? < 9 {
// Update tokenids db layout
let mut iter = db
.tokenids
.iter()
.filter_map(|(key, _)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(4, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(word);
new_key.push(0xFF);
new_key.extend_from_slice(pdu_id_count);
Some((new_key, Vec::new()))
})
.peekable();
while iter.peek().is_some() {
db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?;
debug!("Inserted smaller batch");
}
info!("Deleting starts");
let batch2: Vec<_> = db
.tokenids
.iter()
.filter_map(|(key, _)| {
if key.starts_with(b"!") {
Some(key)
} else {
None
}
})
.collect();
for key in batch2 {
db.tokenids.remove(&key)?;
}
services().globals.bump_database_version(9)?;
warn!("Migration: 8 -> 9 finished");
}
if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
db.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?;
}
// Force E2EE device list updates so we can send them over federation
for user_id in services().users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?;
}
services().globals.bump_database_version(10)?;
warn!("Migration: 9 -> 10 finished");
}
if services().globals.database_version()? < 11 {
db.db
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished");
}
if services().globals.database_version()? < 12 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
};
let raw_rules_list = services()
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let rules_list = &mut account_data.content.global;
//content rule
{
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let rule = rules_list.content.get(content_rule_transformation[0]);
if rule.is_some() {
let mut rule = rule.unwrap().clone();
content_rule_transformation[1].clone_into(&mut rule.rule_id);
rules_list
.content
.shift_remove(content_rule_transformation[0]);
rules_list.content.insert(rule);
}
}
//underride rules
{
let underride_rule_transformation = [
[".m.rules.call", ".m.rule.call"],
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
[".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
[".m.rules.message", ".m.rule.message"],
[".m.rules.encrypted", ".m.rule.encrypted"],
];
for transformation in underride_rule_transformation {
let rule = rules_list.underride.get(transformation[0]);
if let Some(rule) = rule {
let mut rule = rule.clone();
transformation[1].clone_into(&mut rule.rule_id);
rules_list.underride.shift_remove(transformation[0]);
rules_list.underride.insert(rule);
}
}
}
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
}
services().globals.bump_database_version(12)?;
warn!("Migration: 11 -> 12 finished");
}
// This migration can be reused as-is anytime the server-default rules are
// updated.
if services().globals.database_version()? < 13 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
};
let raw_rules_list = services()
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let user_default_rules = Ruleset::server_default(&user);
account_data
.content
.global
.update_with_server_default(user_default_rules);
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
}
services().globals.bump_database_version(13)?;
warn!("Migration: 12 -> 13 finished");
}
#[cfg(feature = "sha256_media")]
{
if services().globals.database_version()? < 14 && cfg!(feature = "sha256_media") {
warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
// Move old media files to new names
for (key, _) in db.mediaid_file.iter() {
let old_path = services().globals.get_media_file(&key);
debug!("Old file path: {old_path:?}");
let path = services().globals.get_media_file_new(&key);
debug!("New file path: {path:?}");
// move the file to the new location
if old_path.exists() {
tokio::fs::rename(&old_path, &path).await?;
}
}
services().globals.bump_database_version(14)?;
warn!("Migration: 13 -> 14 finished");
}
}
assert_eq!(
services().globals.database_version().unwrap(),
latest_database_version,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
services().globals.database_version().unwrap(),
latest_database_version
);
{
let patterns = services().globals.forbidden_usernames();
if !patterns.is_empty() {
for user_id in services()
.users
.iter()
.filter_map(Result::ok)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true))
.filter(|user| user.server_name() == config.server_name)
{
let matches = patterns.matches(user_id.localpart());
if matches.matched_any() {
warn!(
"User {} matches the following forbidden username patterns: {}",
user_id.to_string(),
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
}
}
}
{
let patterns = services().globals.forbidden_alias_names();
if !patterns.is_empty() {
for address in services().rooms.metadata.iter_ids() {
let room_id = address?;
let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id);
for room_alias_result in room_aliases {
let room_alias = room_alias_result?;
let matches = patterns.matches(room_alias.alias());
if matches.matched_any() {
warn!(
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
room_alias,
&room_id,
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
}
}
}
}
info!(
"Loaded {} database with schema version {}",
config.database_backend, latest_database_version
);
} else {
services()
.globals
.bump_database_version(latest_database_version)?;
// Create the admin room and server user on first run
services().admin.create_admin_room().await?;
warn!(
"Created new {} database with version {}",
config.database_backend, latest_database_version
);
}
Ok(())
}

View file

@ -3,16 +3,13 @@ use std::{
fs,
future::Future,
path::PathBuf,
sync::{
atomic::{self, AtomicBool},
Arc,
},
time::{Instant, SystemTime},
sync::Arc,
time::Instant,
};
use argon2::Argon2;
use base64::{engine::general_purpose, Engine as _};
pub(crate) use data::Data;
pub use data::Data;
use hickory_resolver::TokioAsyncResolver;
use ipaddress::IPAddress;
use regex::RegexSet;
@ -25,42 +22,46 @@ use ruma::{
DeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomVersionId,
ServerName, UserId,
};
use tokio::sync::{broadcast, Mutex, RwLock};
use tracing::{error, info, trace};
use tokio::{
sync::{broadcast, Mutex, RwLock},
task::JoinHandle,
};
use tracing::{error, trace};
use url::Url;
use crate::{services, Config, LogLevelReloadHandles, Result};
use crate::{services, Config, Result};
mod client;
mod data;
pub mod data;
pub(crate) mod emerg_access;
pub(crate) mod migrations;
mod resolver;
pub(crate) mod updates;
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
pub(crate) struct Service<'a> {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
pub(crate) tracing_reload_handle: LogLevelReloadHandles,
pub(crate) config: Config,
pub(crate) cidr_range_denylist: Vec<IPAddress>,
pub config: Config,
pub cidr_range_denylist: Vec<IPAddress>,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey>,
pub(crate) resolver: Arc<resolver::Resolver>,
pub(crate) client: client::Client,
pub(crate) stable_room_versions: Vec<RoomVersionId>,
pub(crate) unstable_room_versions: Vec<RoomVersionId>,
pub(crate) bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub(crate) bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub(crate) bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub(crate) roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, // this lock will be held longer
pub(crate) roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub(crate) stateres_mutex: Arc<Mutex<()>>,
pub(crate) rotate: RotationHandler,
pub(crate) started: SystemTime,
pub(crate) shutdown: AtomicBool,
pub(crate) argon: Argon2<'a>,
pub resolver: Arc<resolver::Resolver>,
pub client: client::Client,
pub stable_room_versions: Vec<RoomVersionId>,
pub unstable_room_versions: Vec<RoomVersionId>,
pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, // this lock will be held longer
pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub updates_handle: Mutex<Option<JoinHandle<()>>>,
pub stateres_mutex: Arc<Mutex<()>>,
pub rotate: RotationHandler,
pub argon: Argon2<'static>,
}
/// Handles "rotation" of long-polling requests. "Rotation" in this context is
@ -68,7 +69,7 @@ pub(crate) struct Service<'a> {
///
/// This is utilized to have sync workers return early and release read locks on
/// the database.
pub(crate) struct RotationHandler(broadcast::Sender<()>, ());
pub struct RotationHandler(broadcast::Sender<()>, ());
impl RotationHandler {
fn new() -> Self {
@ -76,25 +77,22 @@ impl RotationHandler {
Self(s, ())
}
pub(crate) fn watch(&self) -> impl Future<Output = ()> {
pub fn watch(&self) -> impl Future<Output = ()> {
let mut r = self.0.subscribe();
async move {
_ = r.recv().await;
}
}
fn fire(&self) { _ = self.0.send(()); }
pub fn fire(&self) { _ = self.0.send(()); }
}
impl Default for RotationHandler {
fn default() -> Self { Self::new() }
}
impl Service<'_> {
pub(crate) fn load(
db: &'static dyn Data, config: &Config, tracing_reload_handle: LogLevelReloadHandles,
) -> Result<Self> {
impl Service {
pub fn load(db: Arc<dyn Data>, config: &Config) -> Result<Self> {
let keypair = db.load_keypair();
let keypair = match keypair {
@ -140,7 +138,6 @@ impl Service<'_> {
}
let mut s = Self {
tracing_reload_handle,
db,
config: config.clone(),
cidr_range_denylist,
@ -157,10 +154,9 @@ impl Service<'_> {
roomid_mutex_insert: RwLock::new(HashMap::new()),
roomid_mutex_federation: RwLock::new(HashMap::new()),
roomid_federationhandletime: RwLock::new(HashMap::new()),
updates_handle: Mutex::new(None),
stateres_mutex: Arc::new(Mutex::new(())),
rotate: RotationHandler::new(),
started: SystemTime::now(),
shutdown: AtomicBool::new(false),
argon,
};
@ -178,145 +174,141 @@ impl Service<'_> {
}
/// Returns this server's keypair.
pub(crate) fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
#[tracing::instrument(skip(self))]
pub(crate) fn next_count(&self) -> Result<u64> { self.db.next_count() }
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
#[tracing::instrument(skip(self))]
pub(crate) fn current_count(&self) -> Result<u64> { self.db.current_count() }
pub fn current_count(&self) -> Result<u64> { self.db.current_count() }
#[tracing::instrument(skip(self))]
pub(crate) fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
#[tracing::instrument(skip(self))]
pub(crate) fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) }
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) }
pub(crate) async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
self.db.watch(user_id, device_id).await
}
pub(crate) fn cleanup(&self) -> Result<()> { self.db.cleanup() }
pub fn cleanup(&self) -> Result<()> { self.db.cleanup() }
/// TODO: use this?
#[allow(dead_code)]
pub(crate) fn flush(&self) -> Result<()> { self.db.flush() }
pub fn flush(&self) -> Result<()> { self.db.flush() }
pub(crate) fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() }
pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() }
pub(crate) fn max_request_size(&self) -> u32 { self.config.max_request_size }
pub fn max_request_size(&self) -> u32 { self.config.max_request_size }
pub(crate) fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events }
pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events }
pub(crate) fn allow_registration(&self) -> bool { self.config.allow_registration }
pub fn allow_registration(&self) -> bool { self.config.allow_registration }
pub(crate) fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration }
pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration }
pub(crate) fn allow_guests_auto_join_rooms(&self) -> bool { self.config.allow_guests_auto_join_rooms }
pub fn allow_guests_auto_join_rooms(&self) -> bool { self.config.allow_guests_auto_join_rooms }
pub(crate) fn log_guest_registrations(&self) -> bool { self.config.log_guest_registrations }
pub fn log_guest_registrations(&self) -> bool { self.config.log_guest_registrations }
pub(crate) fn allow_encryption(&self) -> bool { self.config.allow_encryption }
pub fn allow_encryption(&self) -> bool { self.config.allow_encryption }
pub(crate) fn allow_federation(&self) -> bool { self.config.allow_federation }
pub fn allow_federation(&self) -> bool { self.config.allow_federation }
pub(crate) fn allow_public_room_directory_over_federation(&self) -> bool {
pub fn allow_public_room_directory_over_federation(&self) -> bool {
self.config.allow_public_room_directory_over_federation
}
pub(crate) fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation }
pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation }
pub(crate) fn allow_room_creation(&self) -> bool { self.config.allow_room_creation }
pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation }
pub(crate) fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions }
pub fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions }
pub(crate) fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() }
pub fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() }
pub(crate) fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix }
pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix }
pub(crate) fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates }
pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates }
pub(crate) fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
pub(crate) fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first }
pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first }
pub(crate) fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver }
pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver }
pub(crate) fn actual_destinations(&self) -> &Arc<RwLock<resolver::WellKnownMap>> { &self.resolver.destinations }
pub fn actual_destinations(&self) -> &Arc<RwLock<resolver::WellKnownMap>> { &self.resolver.destinations }
pub(crate) fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
pub(crate) fn turn_password(&self) -> &String { &self.config.turn_password }
pub fn turn_password(&self) -> &String { &self.config.turn_password }
pub(crate) fn turn_ttl(&self) -> u64 { self.config.turn_ttl }
pub fn turn_ttl(&self) -> u64 { self.config.turn_ttl }
pub(crate) fn turn_uris(&self) -> &[String] { &self.config.turn_uris }
pub fn turn_uris(&self) -> &[String] { &self.config.turn_uris }
pub(crate) fn turn_username(&self) -> &String { &self.config.turn_username }
pub fn turn_username(&self) -> &String { &self.config.turn_username }
pub(crate) fn turn_secret(&self) -> &String { &self.config.turn_secret }
pub fn turn_secret(&self) -> &String { &self.config.turn_secret }
pub(crate) fn allow_profile_lookup_federation_requests(&self) -> bool {
pub fn allow_profile_lookup_federation_requests(&self) -> bool {
self.config.allow_profile_lookup_federation_requests
}
pub(crate) fn notification_push_path(&self) -> &String { &self.config.notification_push_path }
pub fn notification_push_path(&self) -> &String { &self.config.notification_push_path }
pub(crate) fn emergency_password(&self) -> &Option<String> { &self.config.emergency_password }
pub fn emergency_password(&self) -> &Option<String> { &self.config.emergency_password }
pub(crate) fn url_preview_domain_contains_allowlist(&self) -> &Vec<String> {
pub fn url_preview_domain_contains_allowlist(&self) -> &Vec<String> {
&self.config.url_preview_domain_contains_allowlist
}
pub(crate) fn url_preview_domain_explicit_allowlist(&self) -> &Vec<String> {
pub fn url_preview_domain_explicit_allowlist(&self) -> &Vec<String> {
&self.config.url_preview_domain_explicit_allowlist
}
pub(crate) fn url_preview_domain_explicit_denylist(&self) -> &Vec<String> {
pub fn url_preview_domain_explicit_denylist(&self) -> &Vec<String> {
&self.config.url_preview_domain_explicit_denylist
}
pub(crate) fn url_preview_url_contains_allowlist(&self) -> &Vec<String> {
&self.config.url_preview_url_contains_allowlist
}
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> { &self.config.url_preview_url_contains_allowlist }
pub(crate) fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size }
pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size }
pub(crate) fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain }
pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain }
pub(crate) fn forbidden_alias_names(&self) -> &RegexSet { &self.config.forbidden_alias_names }
pub fn forbidden_alias_names(&self) -> &RegexSet { &self.config.forbidden_alias_names }
pub(crate) fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames }
pub fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames }
pub(crate) fn allow_local_presence(&self) -> bool { self.config.allow_local_presence }
pub fn allow_local_presence(&self) -> bool { self.config.allow_local_presence }
pub(crate) fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence }
pub fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence }
pub(crate) fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence }
pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence }
pub(crate) fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts }
pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts }
pub(crate) fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts }
pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts }
pub(crate) fn prevent_media_downloads_from(&self) -> &[OwnedServerName] {
&self.config.prevent_media_downloads_from
}
pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { &self.config.prevent_media_downloads_from }
pub(crate) fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] {
pub fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] {
&self.config.forbidden_remote_room_directory_server_names
}
pub(crate) fn well_known_support_page(&self) -> &Option<Url> { &self.config.well_known.support_page }
pub fn well_known_support_page(&self) -> &Option<Url> { &self.config.well_known.support_page }
pub(crate) fn well_known_support_role(&self) -> &Option<ContactRole> { &self.config.well_known.support_role }
pub fn well_known_support_role(&self) -> &Option<ContactRole> { &self.config.well_known.support_role }
pub(crate) fn well_known_support_email(&self) -> &Option<String> { &self.config.well_known.support_email }
pub fn well_known_support_email(&self) -> &Option<String> { &self.config.well_known.support_email }
pub(crate) fn well_known_support_mxid(&self) -> &Option<OwnedUserId> { &self.config.well_known.support_mxid }
pub fn well_known_support_mxid(&self) -> &Option<OwnedUserId> { &self.config.well_known.support_mxid }
pub(crate) fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites }
pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites }
pub(crate) fn supported_room_versions(&self) -> Vec<RoomVersionId> {
pub fn supported_room_versions(&self) -> Vec<RoomVersionId> {
let mut room_versions: Vec<RoomVersionId> = vec![];
room_versions.extend(self.stable_room_versions.clone());
if self.allow_unstable_room_versions() {
@ -332,7 +324,7 @@ impl Service<'_> {
///
/// This doesn't actually check that the keys provided are newer than the
/// old set.
pub(crate) fn add_signing_key(
pub fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
self.db.add_signing_key(origin, new_keys)
@ -340,7 +332,7 @@ impl Service<'_> {
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub(crate) fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let mut keys = self.db.signing_keys_for(origin)?;
if origin == self.server_name() {
keys.insert(
@ -356,13 +348,11 @@ impl Service<'_> {
Ok(keys)
}
pub(crate) fn database_version(&self) -> Result<u64> { self.db.database_version() }
pub fn database_version(&self) -> Result<u64> { self.db.database_version() }
pub(crate) fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.db.bump_database_version(new_version)
}
pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.db.bump_database_version(new_version) }
pub(crate) fn get_media_folder(&self) -> PathBuf {
pub fn get_media_folder(&self) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
@ -373,7 +363,7 @@ impl Service<'_> {
/// flag enabled and database migrated uses SHA256 hash of the base64 key as
/// the file name
#[cfg(feature = "sha256_media")]
pub(crate) fn get_media_file_new(&self, key: &[u8]) -> PathBuf {
pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
@ -387,7 +377,7 @@ impl Service<'_> {
/// old base64 file name media function
/// This is the old version of `get_media_file` that uses the full base64
/// key as the filename.
pub(crate) fn get_media_file(&self, key: &[u8]) -> PathBuf {
pub fn get_media_file(&self, key: &[u8]) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
@ -395,13 +385,13 @@ impl Service<'_> {
r
}
pub(crate) fn well_known_client(&self) -> &Option<Url> { &self.config.well_known.client }
pub fn well_known_client(&self) -> &Option<Url> { &self.config.well_known.client }
pub(crate) fn well_known_server(&self) -> &Option<OwnedServerName> { &self.config.well_known.server }
pub fn well_known_server(&self) -> &Option<OwnedServerName> { &self.config.well_known.server }
pub(crate) fn unix_socket_path(&self) -> &Option<PathBuf> { &self.config.unix_socket_path }
pub fn unix_socket_path(&self) -> &Option<PathBuf> { &self.config.unix_socket_path }
pub(crate) fn valid_cidr_range(&self, ip: &IPAddress) -> bool {
pub fn valid_cidr_range(&self, ip: &IPAddress) -> bool {
for cidr in &self.cidr_range_denylist {
if cidr.includes(ip) {
return false;
@ -410,24 +400,13 @@ impl Service<'_> {
true
}
pub(crate) fn shutdown(&self) {
self.shutdown.store(true, atomic::Ordering::Relaxed);
// On shutdown
if self.unix_socket_path().is_some() {
match &self.unix_socket_path() {
Some(path) => {
fs::remove_file(path).unwrap();
},
None => error!(
"Unable to remove socket file at {:?} during shutdown.",
&self.unix_socket_path()
),
};
};
info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers...");
services().globals.rotate.fire();
}
}
#[inline]
#[must_use]
pub fn server_is_ours(server_name: &ServerName) -> bool { server_name == services().globals.config.server_name }
/// checks if `user_id` is local to us via server_name comparison
#[inline]
#[must_use]
pub fn user_is_local(user_id: &UserId) -> bool { server_is_ours(user_id.server_name()) }

View file

@ -17,21 +17,21 @@ use crate::{service::sending::FedDest, Config, Error};
pub(crate) type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
pub(crate) struct Resolver {
pub(crate) destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub(crate) overrides: Arc<StdRwLock<TlsNameMap>>,
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>,
pub struct Resolver {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<StdRwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>,
pub hooked: Arc<Hooked>,
}
pub(crate) struct Hooked {
pub(crate) overrides: Arc<StdRwLock<TlsNameMap>>,
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub struct Hooked {
pub overrides: Arc<StdRwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>,
}
impl Resolver {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
pub(crate) fn new(config: &Config) -> Self {
pub fn new(config: &Config) -> Self {
let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf()
.map_err(|e| {
error!("Failed to set up hickory dns resolver with system config: {}", e);

View file

@ -0,0 +1,76 @@
use std::time::Duration;
use ruma::events::room::message::RoomMessageEventContent;
use serde::Deserialize;
use tokio::{task::JoinHandle, time::interval};
use tracing::{debug, error};
use crate::{
conduit::{Error, Result},
services,
};
#[derive(Deserialize)]
struct CheckForUpdatesResponseEntry {
id: u64,
date: String,
message: String,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponse {
updates: Vec<CheckForUpdatesResponseEntry>,
}
#[tracing::instrument]
pub async fn start_check_for_updates_task() -> Result<JoinHandle<()>> {
let timer_interval = Duration::from_secs(7200); // 2 hours
Ok(services().server.runtime().spawn(async move {
let mut i = interval(timer_interval);
loop {
tokio::select! {
_ = i.tick() => {
debug!(target: "start_check_for_updates_task", "Timer ticked");
},
}
_ = try_handle_updates().await;
}
}))
}
async fn try_handle_updates() -> Result<()> {
let response = services()
.globals
.client
.default
.get("https://pupbrain.dev/check-for-updates/stable")
.send()
.await?;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
error!("Bad check for updates response: {e}");
Error::BadServerResponse("Bad version check response")
})?;
let mut last_update_id = services().globals.last_check_for_updates_id()?;
for update in response.updates {
last_update_id = last_update_id.max(update.id);
if update.id > services().globals.last_check_for_updates_id()? {
error!("{}", update.message);
services()
.admin
.send_message(RoomMessageEventContent::text_plain(format!(
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
update.date, update.message
)))
.await;
}
}
services()
.globals
.update_check_for_updates_id(last_update_id)?;
Ok(())
}

View file

@ -8,7 +8,7 @@ use ruma::{
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;

View file

@ -1,7 +1,7 @@
mod data;
use std::collections::BTreeMap;
use std::{collections::BTreeMap, sync::Arc};
pub(crate) use data::Data;
pub use data::Data;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
@ -10,79 +10,73 @@ use ruma::{
use crate::Result;
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
pub(crate) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
self.db.create_backup(user_id, backup_metadata)
}
pub(crate) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_backup(user_id, version)
}
pub(crate) fn update_backup(
pub fn update_backup(
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
self.db.update_backup(user_id, version, backup_metadata)
}
pub(crate) fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
self.db.get_latest_backup_version(user_id)
}
pub(crate) fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
self.db.get_latest_backup(user_id)
}
pub(crate) fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
self.db.get_backup(user_id, version)
}
pub(crate) fn add_key(
pub fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
self.db
.add_key(user_id, version, room_id, session_id, key_data)
}
pub(crate) fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
self.db.count_keys(user_id, version)
}
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
pub(crate) fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
self.db.get_etag(user_id, version)
}
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) }
pub(crate) fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
self.db.get_all(user_id, version)
}
pub(crate) fn get_room(
pub fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
self.db.get_room(user_id, version, room_id)
}
pub(crate) fn get_session(
pub fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
self.db.get_session(user_id, version, room_id, session_id)
}
pub(crate) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_all_keys(user_id, version)
}
pub(crate) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
self.db.delete_room_keys(user_id, version, room_id)
}
pub(crate) fn delete_room_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<()> {
pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
self.db
.delete_room_key(user_id, version, room_id, session_id)
}

View file

@ -0,0 +1,137 @@
use std::collections::HashMap;
use ruma::{
api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use tracing::warn;
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix;
key.extend_from_slice(event_type.to_string().as_bytes());
if data.get("type").is_none() || data.get("content").is_none() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Account data doesn't have all required fields.",
));
}
self.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
)?;
let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
// Remove old entry
if let Some(prev) = prev {
self.roomuserdataid_accountdata.remove(&prev)?;
}
Ok(())
}
/// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, kind))]
fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid
.get(&key)?
.and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
})
.transpose()?
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose()
}
/// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))]
fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());
for r in self
.roomuserdataid_accountdata
.iter_from(&first_possible, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
Ok::<_, Error>((
RoomAccountDataEventType::from(
utils::string_from_bytes(
k.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
)
.map_err(|e| {
warn!("RoomUserData ID in database is invalid: {}", e);
Error::bad_database("RoomUserData ID in db is invalid.")
})?,
),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
))
}) {
let (kind, data) = r?;
userdata.insert(kind, data);
}
Ok(userdata)
}
}

View file

@ -0,0 +1,55 @@
use ruma::api::appservice::Registration;
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::appservice::Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str();
self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
Ok(id.to_owned())
}
/// Remove an appservice registration
///
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
Ok(())
}
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.id_appserviceregistrations
.get(id.as_bytes())?
.map(|bytes| {
serde_yaml::from_slice(&bytes)
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
})
.transpose()
}
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id)
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
})))
}
fn all(&self) -> Result<Vec<(String, Registration)>> {
self.iter_ids()?
.filter_map(Result::ok)
.map(move |id| {
Ok((
id.clone(),
self.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"),
))
})
.collect()
}
}

View file

@ -0,0 +1,299 @@
use std::collections::{BTreeMap, HashMap};
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
};
use tracing::trace;
use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result};
const COUNTER: &[u8] = b"c";
const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait]
impl crate::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
}
fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
})
}
fn last_check_for_updates_id(&self) -> Result<u64> {
self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
})
}
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(())
}
#[allow(unused_qualifications)] // async traits
#[tracing::instrument(skip(self))]
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new();
// Return when *any* user changed their key
// TODO: only send for user they share a room with
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = services()
.rooms
.short
.get_shortroomid(&room_id)
.ok()
.flatten()
.expect("room exists")
.to_be_bytes()
.to_vec();
let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xFF);
// PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(Box::pin(async move {
let _result = services().rooms.typing.wait_for_update(&room_id).await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
// Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
// Room account data
let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
}
let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms)
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch()));
// Wait until one of them finds something
trace!(futures = futures.len(), "watch started");
futures.next().await;
trace!(futures = futures.len(), "watch finished");
Ok(())
}
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
fn flush(&self) -> Result<()> { self.db.flush() }
fn cork(&self) -> Result<Cork> { Ok(Cork::new(&self.db, false, false)) }
fn cork_and_flush(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, false)) }
fn cork_and_sync(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, true)) }
fn memory_usage(&self) -> String {
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity();
let max_our_real_users_cache = self.our_real_users_cache.read().unwrap().capacity();
let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity();
format!(
"\
auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
our_real_users_cache: {our_real_users_cache} / {max_our_real_users_cache}
appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
{}",
self.db.memory_usage().unwrap_or_default()
)
}
fn clear_caches(&self, amount: u32) {
if amount > 1 {
let c = &mut *self.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity());
}
if amount > 2 {
let c = &mut *self.our_real_users_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 3 {
let c = &mut *self.appservice_in_room_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 4 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new();
}
}
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
self.global.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
Ok,
)?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes(
// 1. version
parts
.next()
.expect("splitn always returns at least one element"),
)
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| {
// 2. key
parts
.next()
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
.map(|key| (version, key))
})
.and_then(|(version, key)| {
Ed25519KeyPair::from_der(key, version)
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
})
}
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
let ServerSigningKeys {
verify_keys,
old_verify_keys,
..
} = new_keys;
keys.verify_keys.extend(verify_keys);
keys.old_verify_keys.extend(old_verify_keys);
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree)
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map_or_else(BTreeMap::new, |keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree
});
Ok(signingkeys)
}
fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() }
fn backup_list(&self) -> Result<String> { self.db.backup_list() }
fn file_list(&self) -> Result<String> { self.db.file_list() }
}

View file

@ -0,0 +1,317 @@
use std::collections::BTreeMap;
use ruma::{
api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
error::ErrorKind,
},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::key_backups::Data for KeyValueDatabase {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version)
}
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?;
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned())
}
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, _)| {
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
})
.transpose()
}
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, value)| {
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
Ok((
version,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
))
})
.transpose()
}
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
})
}
fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes())?;
Ok(())
}
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
}
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes(
&self
.backupid_etag
.get(&key)?
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
)
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
.to_string())
}
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let room_id = RoomId::parse(
utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((room_id, session_id, key_data))
}) {
let (room_id, session_id, key_data) = result?;
rooms
.entry(room_id)
.or_insert_with(|| RoomKeyBackup {
sessions: BTreeMap::new(),
})
.sessions
.insert(session_id, key_data);
}
Ok(rooms)
}
fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
Ok(self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((session_id, key_data))
})
.filter_map(Result::ok)
.collect())
}
fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.get(&key)?
.map(|value| {
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
})
.transpose()
}
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
}

View file

@ -0,0 +1,242 @@
use ruma::api::client::error::ErrorKind;
use tracing::debug;
use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result};
impl crate::media::Data for KeyValueDatabase {
fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?;
if let Some(user) = sender_user {
let key = mxc.as_bytes().to_vec();
let user = user.as_bytes().to_vec();
self.mediaid_user.insert(&key, &user)?;
}
Ok(key)
}
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
debug!("MXC db prefix: {prefix:?}");
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
debug!("Deleting key: {:?}", key);
self.mediaid_file.remove(&key)?;
}
for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) {
if key == mxc.as_bytes().to_vec() {
let user = string_from_bytes(&value).unwrap_or_default();
debug!("Deleting key \"{key:?}\" which was uploaded by user {user}");
self.mediaid_user.remove(&key)?;
}
}
Ok(())
}
/// Searches for all files with the given MXC
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
let mut keys: Vec<Vec<u8>> = vec![];
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
keys.push(key);
}
if keys.is_empty() {
return Err(Error::bad_database(
"Failed to find any keys in database with the provided MXC.",
));
}
debug!("Got the following keys: {:?}", keys);
Ok(keys)
}
fn search_file_metadata(
&self, mxc: String, width: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes());
prefix.extend_from_slice(&height.to_be_bytes());
prefix.push(0xFF);
let (key, _) = self
.mediaid_file
.scan_prefix(prefix)
.next()
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts
.next()
.map(|bytes| {
string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
})
.transpose()?;
let content_disposition_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let content_disposition = if content_disposition_bytes.is_empty() {
None
} else {
Some(
string_from_bytes(content_disposition_bytes)
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
)
};
Ok((content_disposition, content_type, key))
}
/// Gets all the media keys in our database (this includes all the metadata
/// associated with it such as width, height, content-type, etc)
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
let mut keys: Vec<Vec<u8>> = vec![];
for (key, _) in self.mediaid_file.iter() {
keys.push(key);
}
Ok(keys)
}
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF);
value.extend_from_slice(
data.title
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.description
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.image
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
self.url_previews.insert(url.as_bytes(), &value)
}
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
let values = self.url_previews.get(url.as_bytes()).ok()??;
let mut values = values.split(|&b| b == 0xFF);
let _ts = values.next();
/* if we ever decide to use timestamp, this is here.
match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
Some(0) => None,
x => x,
};*/
let title = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let description = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image_size = match values
.next()
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_width = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_height = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
Some(UrlPreviewData {
title,
description,
image,
image_size,
image_width,
image_height,
})
}
}

View file

@ -0,0 +1,14 @@
mod account_data;
//mod admin;
mod appservice;
mod globals;
mod key_backups;
mod media;
//mod pdu;
mod presence;
mod pusher;
mod rooms;
mod sending;
mod transaction_ids;
mod uiaa;
mod users;

View file

@ -0,0 +1,127 @@
use conduit::debug_info;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::{
presence::Presence,
services,
utils::{self, user_id_from_bytes},
Error, KeyValueDatabase, Result,
};
impl crate::presence::Data for KeyValueDatabase {
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?))
})
.transpose()
} else {
Ok(None)
}
}
fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> {
let last_presence = self.get_presence(user_id)?;
let state_changed = match last_presence {
None => true,
Some(ref presence) => presence.1.content.presence != *presence_state,
};
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
None => 0,
Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
};
let last_active_ts = match last_active_ago {
None => now,
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
};
// tighten for state flicker?
if !state_changed && last_active_ts <= last_last_active_ts {
debug_info!(
"presence spam {:?} last_active_ts:{:?} <= {:?}",
user_id,
last_active_ts,
last_last_active_ts
);
return Ok(());
}
let status_msg = if status_msg.as_ref().is_some_and(String::is_empty) {
None
} else {
status_msg
};
let presence = Presence::new(
presence_state.to_owned(),
currently_active.unwrap_or(false),
last_active_ts,
status_msg,
);
let count = services().globals.next_count()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.insert(&key, &presence.to_json_bytes()?)?;
self.userid_presenceid
.insert(user_id.as_bytes(), &count.to_be_bytes())?;
if let Some((last_count, _)) = last_presence {
let key = presenceid_key(last_count, user_id);
self.presenceid_presence.remove(&key)?;
}
Ok(())
}
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence.remove(&key)?;
self.userid_presenceid.remove(user_id.as_bytes())?;
}
Ok(())
}
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> {
Box::new(
self.presenceid_presence
.iter()
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec<u8>)> {
let (count, user_id) = presenceid_parse(&key)?;
Ok((user_id, count, presence_bytes))
})
.filter(move |(_, count, _)| *count > since),
)
}
}
#[inline]
fn presenceid_key(count: u64, user_id: &UserId) -> Vec<u8> {
[count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat()
}
#[inline]
fn presenceid_parse(key: &[u8]) -> Result<(u64, OwnedUserId)> {
let (count, user_id) = key.split_at(8);
let user_id = user_id_from_bytes(user_id)?;
let count = utils::u64_from_bytes(count).unwrap();
Ok((count, user_id))
}

View file

@ -0,0 +1,65 @@
use ruma::{
api::client::push::{set_pusher, Pusher},
UserId,
};
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
match &pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
Ok(())
},
set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes());
self.senderkey_pusher.remove(&key).map_err(Into::into)
},
}
}
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes());
self.senderkey_pusher
.get(&senderkey)?
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.transpose()
}
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
self.senderkey_pusher
.scan_prefix(prefix)
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.collect()
}
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next();
let push_key = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
let push_key_string = utils::string_from_bytes(push_key)
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
Ok(push_key_string)
}))
}
}

View file

@ -0,0 +1,75 @@
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::alias::Data for KeyValueDatabase {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())
}
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
let mut prefix = room_id;
prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?;
}
self.alias_roomid.remove(alias.alias().as_bytes())?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
}
Ok(())
}
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.alias_roomid
.get(alias.alias().as_bytes())?
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
})
.transpose()
}
fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
}))
}
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new(
self.alias_roomid
.iter()
.map(|(room_alias_bytes, room_id_bytes)| {
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
Ok((room_id, room_alias_localpart))
}),
)
}
}

View file

@ -0,0 +1,59 @@
use std::{mem::size_of, sync::Arc};
use crate::{utils, KeyValueDatabase, Result};
impl crate::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
}
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
.collect::<Arc<[u64]>>()
});
if let Some(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(vec![key[0]], Arc::clone(&chain));
return Ok(Some(chain));
}
}
Ok(None)
}
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
&key[0].to_be_bytes(),
&auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
}
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(key, auth_chain);
Ok(())
}
}

View file

@ -0,0 +1,23 @@
use ruma::{OwnedRoomId, RoomId};
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
}
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
}))
}
}

View file

@ -0,0 +1,53 @@
use ruma::{DeviceId, RoomId, UserId};
use crate::{KeyValueDatabase, Result};
impl crate::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())
}
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;
}
Ok(())
}
}

View file

@ -0,0 +1,76 @@
use ruma::{OwnedRoomId, RoomId};
use tracing::error;
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};
// Look for PDUs in that room.
Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
}))
}
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.bannedroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.")
})?
.try_into()
.map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids")
})?;
Ok(room_id)
},
))
}
}

View file

@ -0,0 +1,21 @@
mod alias;
mod auth_chain;
mod directory;
mod lazy_load;
mod metadata;
mod outlier;
mod pdu_metadata;
mod read_receipt;
mod search;
mod short;
mod state;
mod state_accessor;
mod state_cache;
mod state_compressor;
mod threads;
mod timeline;
mod user;
use crate::KeyValueDatabase;
impl crate::rooms::Data for KeyValueDatabase {}

View file

@ -0,0 +1,28 @@
use ruma::{CanonicalJsonObject, EventId};
use crate::{Error, KeyValueDatabase, PduEvent, Result};
impl crate::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
)
}
}

View file

@ -0,0 +1,79 @@
use std::{mem, sync::Arc};
use ruma::{EventId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result};
impl crate::rooms::pdu_metadata::Data for KeyValueDatabase {
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?;
Ok(())
}
fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
current.extend_from_slice(&0_u64.to_be_bytes());
u64::MAX.saturating_sub(x).saturating_sub(1)
},
};
current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new(
self.tofrom_relation
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((PduCount::Normal(from), pdu))
}),
))
}
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
Ok(())
}
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
}
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[])
}
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
}
}

View file

@ -0,0 +1,120 @@
use std::mem;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry
if let Some((old, _)) = self
.readreceiptid_readreceipt
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) {
// This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?;
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
)?;
Ok(())
}
fn readreceipts_since<'a>(
&'a self, room_id: &RoomId, since: u64,
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since
Box::new(
self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id = UserId::parse(
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
json.remove("room_id");
Ok((
user_id,
count,
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
))
}),
)
}
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
}
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
))
})
}
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastprivatereadupdate
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
}

View file

@ -0,0 +1,64 @@
use ruma::RoomId;
use crate::{services, utils, KeyValueDatabase, Result};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
impl crate::rooms::search::Data for KeyValueDatabase {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
(key, Vec::new())
});
self.tokenids.insert_batch(&mut batch)
}
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.map(str::to_lowercase)
.collect();
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
let Some(common_elements) = utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
}) else {
return Ok(None);
};
Ok(Some((Box::new(common_elements), words)))
}
}

View file

@ -0,0 +1,164 @@
use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId};
use tracing::warn;
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
Ok(short)
}
fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids
.iter()
.map(|id| id.as_bytes())
.collect::<Vec<&[u8]>>();
for (i, short) in self
.eventid_shorteventid
.multi_get(&keys)?
.iter()
.enumerate()
{
match short {
Some(short) => ret.push(
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
),
None => {
let short = services().globals.next_count()?;
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
self.shorteventid_eventid
.insert(&short.to_be_bytes(), keys[i])?;
debug_assert!(ret.len() == i, "position of result must match input");
ret.push(short);
},
}
}
Ok(ret)
}
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&statekey_vec)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()?;
Ok(short)
}
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey
};
Ok(short)
}
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
Ok(event_id)
}
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes)
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
let result = (event_type, state_key);
Ok(result)
}
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
})
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose()
}
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
})
}
}

View file

@ -0,0 +1,73 @@
use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::{utils, Error, KeyValueDatabase, Result};
impl crate::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
fn set_room_state(
&self,
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(())
}
}

View file

@ -0,0 +1,165 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
#[async_trait]
impl crate::rooms::state_accessor::Data for KeyValueDatabase {
#[allow(unused_qualifications)] // async traits
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
#[allow(unused_qualifications)] // async traits
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let (_, eventid) = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
}
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services()
.rooms
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
})
.transpose()
})
}
/// Returns the full room state.
#[allow(unused_qualifications)] // async traits
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
}

View file

@ -0,0 +1,626 @@
use std::{collections::HashSet, sync::Arc};
use itertools::Itertools;
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use tracing::error;
use crate::{
appservice::RegistrationInfo,
services, user_is_local,
utils::{self},
Error, KeyValueDatabase, Result,
};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
impl crate::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
}
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
}
Ok(())
}
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
let mut real_users = HashSet::new();
for joined in self.room_members(room_id).filter_map(Result::ok) {
joined_servers.insert(joined.server_name().to_owned());
if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) {
real_users.insert(joined);
}
joinedcount = joinedcount.saturating_add(1);
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
invitedcount = invitedcount.saturating_add(1);
}
self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
self.our_real_users_cache
.write()
.unwrap()
.insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe = self
.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.cloned();
if let Some(users) = maybe {
Ok(users)
} else {
self.update_joined_count(room_id)?;
Ok(Arc::clone(
self.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.unwrap(),
))
}
}
#[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
Ok(b)
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
services().globals.server_name(),
)
.ok();
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self
.room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
Ok(in_room)
}
}
/// Makes a user forget a room.
#[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(())
}
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self))]
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
}))
}
#[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self))]
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
}))
}
/// Returns an iterator over all joined members of a room.
#[tracing::instrument(skip(self))]
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
}))
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self))]
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self))]
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self))]
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))]
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
}
#[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
))
})
}
#[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
.transpose()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))]
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
}),
)
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))]
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self))]
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
self.roomid_inviteviaservers
.get(&key)?
.map(|servers| {
let state = serde_json::from_slice(&servers).map_err(|e| {
error!("Invalid state in userroomid_leftstate: {e}");
Error::bad_database("Invalid state in userroomid_leftstate.")
})?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(servers.to_owned().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
Ok(())
}
}

View file

@ -0,0 +1,60 @@
use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{rooms::state_compressor::data::StateDiff, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
} else {
None
};
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
Ok(StateDiff {
parent,
added: Arc::new(added),
removed: Arc::new(removed),
})
}
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
}
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
}
}

View file

@ -0,0 +1,75 @@
use std::mem;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
impl crate::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(until - 1).to_be_bytes());
Ok(Box::new(
self.threadid_userids
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((count, pdu))
}),
))
}
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
Ok(())
}
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
})
.filter_map(Result::ok)
.collect(),
))
} else {
Ok(None)
}
}
}

View file

@ -0,0 +1,295 @@
use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use tracing::error;
use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result};
impl crate::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.unwrap()
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
}) {
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
}
},
hash_map::Entry::Occupied(o) => Ok(*o.get()),
}
}
/// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)
}
/// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
/// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self
.get_non_outlier_pdu(event_id)?
.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)?
.map(Arc::new)
{
Ok(Some(pdu))
} else {
Ok(None)
}
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.lasttimelinecount_cache
.lock()
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
}
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
}
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
}
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn pdus_after<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> {
let mut notifies_batch = Vec::new();
let mut highlights_batch = Vec::new();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
}
self.userroomid_notificationcount
.increment_batch(&mut notifies_batch.into_iter())?;
self.userroomid_highlightcount
.increment_batch(&mut highlights_batch.into_iter())?;
Ok(())
}
}
/// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 =
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)))
} else {
Ok(PduCount::Normal(last_u64))
}
}
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}

View file

@ -0,0 +1,137 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{services, utils, Error, KeyValueDatabase, Result};
impl crate::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
})
}
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
}
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastnotificationread
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
})
.transpose()
}
fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
.saturating_add(1); // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not
// reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}),
))
}
}

View file

@ -0,0 +1,191 @@
use ruma::{ServerName, UserId};
use crate::{
sending::{Destination, SendingEvent},
services, utils, Error, KeyValueDatabase, Result,
};
impl crate::sending::Data for KeyValueDatabase {
fn active_requests<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a> {
Box::new(
self.servercurrentevent_data
.iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
)
}
fn active_requests_for<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a> {
let prefix = destination.get_prefix();
Box::new(
self.servercurrentevent_data
.scan_prefix(prefix)
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
}
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?;
}
Ok(())
}
fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap();
}
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
self.servernameevent_data.remove(&key).unwrap();
}
Ok(())
}
fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> {
let mut batch = Vec::new();
let mut keys = Vec::new();
for (destination, event) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
} else {
&[]
};
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
Ok(keys)
}
fn queued_requests<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> {
let prefix = destination.get_prefix();
return Box::new(
self.servernameevent_data
.scan_prefix(prefix)
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
}
fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> {
for (e, key) in events {
if key.is_empty() {
continue;
}
let value = if let SendingEvent::Edu(value) = &e {
&**value
} else {
&[]
};
self.servercurrentevent_data.insert(key, value)?;
self.servernameevent_data.remove(key)?;
}
Ok(())
}
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
}
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
})
}
}
#[tracing::instrument(skip(key))]
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Appservice(server),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
} else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id =
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
(
Destination::Push(user_id, pushkey_string),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value)
},
)
} else {
let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Normal(
ServerName::parse(server)
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
})
}

View file

@ -0,0 +1,32 @@
use ruma::{DeviceId, TransactionId, UserId};
use crate::{KeyValueDatabase, Result};
impl crate::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?;
Ok(())
}
fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction
self.userdevicetxnid_response.get(&key)
}
}

View file

@ -0,0 +1,68 @@
use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, UserId,
};
use crate::{Error, KeyValueDatabase, Result};
impl crate::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()> {
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
self.userdevicesessionid_uiaarequest
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(ToOwned::to_owned)
}
fn update_uiaa_session(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
}
Ok(())
}
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
&self
.userdevicesessionid_uiaainfo
.get(&userdevicesessionid)?
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
}
}

View file

@ -0,0 +1,898 @@
use std::{collections::BTreeMap, mem::size_of};
use argon2::{password_hash::SaltString, PasswordHasher};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use tracing::warn;
use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result};
impl crate::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) }
/// Check if account is deactivated
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self
.userid_password
.get(user_id.as_bytes())?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))?
.is_empty())
}
/// Returns the number of users registered on this server.
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
/// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid
.get(token.as_bytes())?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
let device_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
Ok(Some((
UserId::parse(
utils::string_from_bytes(user_bytes)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?,
utils::string_from_bytes(device_bytes)
.map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?,
)))
})
}
/// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
}))
}
/// Returns a list of local users as list of usernames.
///
/// A user account is considered `local` if the length of it's password is
/// greater then zero.
fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self
.userid_password
.iter()
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw))
.collect();
Ok(users)
}
/// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
})?))
})
}
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())
} else {
Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Password does not meet the requirements.",
))
}
} else {
self.userid_password.insert(user_id.as_bytes(), b"")?;
Ok(())
}
}
/// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
))
})
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to nofify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?;
} else {
self.userid_displayname.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the `avatar_url` of a user.
fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> {
self.userid_avatarurl
.get(user_id.as_bytes())?
.map(|bytes| {
let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| {
warn!("Avatar URL in db is invalid: {}", e);
Error::bad_database("Avatar URL in db is invalid.")
})?;
let mxc_uri: OwnedMxcUri = s_bytes.into();
Ok(mxc_uri)
})
.transpose()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
} else {
self.userid_avatarurl.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the blurhash of a user.
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash
.get(user_id.as_bytes())?
.map(|bytes| {
let s = utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
Ok(s)
})
.transpose()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash {
self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
} else {
self.userid_blurhash.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Adds a new device to a user.
fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()> {
// This method should never be called for nonexistent users. We shouldn't assert
// though...
if !self.exists(user_id)? {
warn!("Called create_device for non-existent user {} in database", user_id);
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
}
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(&Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: None, // TODO
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
})
.expect("Device::to_string never fails."),
)?;
self.set_token(user_id, device_id, token)?;
Ok(())
}
/// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.userdeviceid_token.remove(&userdeviceid)?;
self.token_userdeviceid.remove(&old_token)?;
}
// Remove todevice events
let mut prefix = userdeviceid.clone();
prefix.push(0xFF);
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(&key)?;
}
// TODO: Remove onetimekeys
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?;
Ok(())
}
/// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
// All devices have metadata
Box::new(
self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
.into())
}),
)
}
/// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// should not be None, but we shouldn't assert either lol...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
// Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.token_userdeviceid.remove(&old_token)?;
// It will be removed from userdeviceid_token by the insert later
}
// Assign token to user device combination
self.userdeviceid_token
.insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(())
}
fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
// All devices have metadata
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&key)?.is_none() {
warn!(
"Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \
database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore)
key.extend_from_slice(
serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
.as_bytes(),
);
self.onetimekeyid_onetimekeys.insert(
&key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
)?;
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
})
}
fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.push(b'"'); // Annoying quotation mark
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
prefix.push(b':');
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys
.scan_prefix(prefix)
.next()
.map(|(key, value)| {
self.onetimekeyid_onetimekeys.remove(&key)?;
Ok((
serde_json::from_slice(
key.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
)
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
))
})
.transpose()
}
fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let mut counts = BTreeMap::new();
for algorithm in self
.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
.algorithm(),
)
}) {
*counts.entry(algorithm?).or_default() += uint!(1);
}
Ok(counts)
}
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.keyid_key.insert(
&userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?;
self.mark_device_key_update(user_id)?;
Ok(())
}
fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key
if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))?
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained more than one key.",
));
}
let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
self.userid_selfsigningkeyid
.insert(user_id.as_bytes(), &self_signing_key_key)?;
}
// User-signing key
if let Some(user_signing_key) = user_signing_key {
let mut user_signing_key_ids = user_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))?
.keys
.into_values();
let user_signing_key_id = user_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?;
if user_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"User signing key contained more than one key.",
));
}
let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key
.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
self.userid_usersigningkeyid
.insert(user_id.as_bytes(), &user_signing_key_key)?;
}
if notify {
self.mark_device_key_update(user_id)?;
}
Ok(())
}
fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
) -> Result<()> {
let mut key = target_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
&self
.keyid_key
.get(&key)?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?,
)
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
let signatures = cross_signing_key
.get_mut("signatures")
.ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))?
.as_object_mut()
.ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))?
.entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into());
signatures
.as_object_mut()
.ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))?
.insert(signature.0, signature.1.into());
self.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
)?;
self.mark_device_key_update(target_id)?;
Ok(())
}
fn keys_changed<'a>(
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut start = prefix.clone();
start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes());
let to = to.unwrap_or(u64::MAX);
Box::new(
self.keychangeid_userid
.iter_from(&start, false)
.take_while(move |(k, _)| {
k.starts_with(&prefix)
&& if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) {
if let Ok(c) = utils::u64_from_bytes(current) {
c <= to
} else {
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
false
}
} else {
warn!("BadDatabase: Could not parse keychangeid_userid");
false
}
})
.map(|(_, bytes)| {
UserId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
}),
)
}
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes();
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
// Don't send key updates to unencrypted rooms
if services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
{
continue;
}
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
Ok(())
}
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
))
})
}
fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let master_key = master_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
if master_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Master key contained more than one key.",
));
}
let mut master_key_key = prefix.clone();
master_key_key.extend_from_slice(master_key_id.as_bytes());
Ok((master_key_key, master_key))
}
fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?;
Ok(Some(Raw::from_json(
serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"),
)))
})
}
fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_masterkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_selfsigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?,
))
})
})
}
fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value,
) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into());
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
self.todeviceid_events.insert(&key, &value)?;
Ok(())
}
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push(
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
}
Ok(events)
}
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes());
for (key, _) in self
.todeviceid_events
.iter_from(&last, true) // this includes last
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(key, _)| {
Ok::<_, Error>((
key.clone(),
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
))
})
.filter_map(Result::ok)
.take_while(|&(_, count)| count <= until)
{
self.todeviceid_events.remove(&key)?;
}
Ok(())
}
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \
metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"),
)?;
Ok(())
}
/// Get device metadata.
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?))
})
}
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
.map(Some)
})
}
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
Box::new(
self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes)
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
}),
)
}
/// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
Ok(filter_id)
}
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?;
if let Some(raw) = raw {
serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db."))
} else {
Ok(None)
}
}
}
/// Will only return with Some(username) if the password was not empty and the
/// username could be successfully parsed.
/// If `utils::string_from_bytes`(...) returns an error that username will be
/// skipped and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
// A valid password is not empty
if password.is_empty() {
None
} else {
match utils::string_from_bytes(username) {
Ok(u) => Some(u),
Err(e) => {
warn!("Failed to parse username while calling get_local_users(): {}", e.to_string());
None
},
}
}
}
/// Calculate a new hash for the given password
fn calculate_password_hash(password: &str) -> Result<String, argon2::password_hash::Error> {
let salt = SaltString::generate(rand::thread_rng());
services()
.globals
.argon
.hash_password(password.as_bytes(), &salt)
.map(|it| it.to_string())
}

View file

@ -1,6 +1,6 @@
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>,

View file

@ -1,7 +1,7 @@
mod data;
use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime};
pub(crate) use data::Data;
pub use data::Data;
use image::imageops::FilterType;
use ruma::{OwnedMxcUri, OwnedUserId};
use serde::Serialize;
@ -15,37 +15,37 @@ use tracing::{debug, error};
use crate::{services, utils, Error, Result};
#[derive(Debug)]
pub(crate) struct FileMeta {
pub struct FileMeta {
#[allow(dead_code)]
pub(crate) content_disposition: Option<String>,
pub(crate) content_type: Option<String>,
pub(crate) file: Vec<u8>,
pub content_disposition: Option<String>,
pub content_type: Option<String>,
pub file: Vec<u8>,
}
#[derive(Serialize, Default)]
pub(crate) struct UrlPreviewData {
pub struct UrlPreviewData {
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
pub(crate) title: Option<String>,
pub title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
pub(crate) description: Option<String>,
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
pub(crate) image: Option<String>,
pub image: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
pub(crate) image_size: Option<usize>,
pub image_size: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
pub(crate) image_width: Option<u32>,
pub image_width: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
pub(crate) image_height: Option<u32>,
pub image_height: Option<u32>,
}
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub(crate) url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
pub struct Service {
pub db: Arc<dyn Data>,
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
}
impl Service {
/// Uploads a file.
pub(crate) async fn create(
pub async fn create(
&self, sender_user: Option<OwnedUserId>, mxc: String, content_disposition: Option<&str>,
content_type: Option<&str>, file: &[u8],
) -> Result<()> {
@ -79,7 +79,7 @@ impl Service {
}
/// Deletes a file in the database and from the media directory via an MXC
pub(crate) async fn delete(&self, mxc: String) -> Result<()> {
pub async fn delete(&self, mxc: String) -> Result<()> {
if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) {
for key in keys {
let file_path;
@ -116,7 +116,7 @@ impl Service {
/// Uploads or replaces a file thumbnail.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn upload_thumbnail(
pub async fn upload_thumbnail(
&self, sender_user: Option<OwnedUserId>, mxc: String, content_disposition: Option<&str>,
content_type: Option<&str>, width: u32, height: u32, file: &[u8],
) -> Result<()> {
@ -149,7 +149,7 @@ impl Service {
}
/// Downloads a file.
pub(crate) async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) {
let path;
@ -182,7 +182,7 @@ impl Service {
/// Deletes all remote only media files in the given at or after
/// time/duration. Returns a u32 with the amount of media files deleted.
pub(crate) async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
if let Ok(all_keys) = self.db.get_all_media_keys() {
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
Ok(duration) => {
@ -296,7 +296,7 @@ impl Service {
/// Returns width, height of the thumbnail and whether it should be cropped.
/// Returns None when the server should send the original file.
pub(crate) fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
match (width, height) {
(0..=32, 0..=32) => Some((32, 32, true)),
(0..=96, 0..=96) => Some((96, 96, true)),
@ -320,7 +320,7 @@ impl Service {
///
/// For width,height <= 96 the server uses another thumbnailing algorithm
/// which crops the image afterwards.
pub(crate) async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
let (width, height, crop) = self
.thumbnail_properties(width, height)
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
@ -467,16 +467,16 @@ impl Service {
}
}
pub(crate) async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }
pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }
/// TODO: use this?
#[allow(dead_code)]
pub(crate) async fn remove_url_preview(&self, url: &str) -> Result<()> {
pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
// TODO: also remove the downloaded image
self.db.remove_url_preview(url)
}
pub(crate) async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("valid system time");
@ -548,9 +548,9 @@ mod tests {
fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> { todo!() }
}
static DB: MockedKVDatabase = MockedKVDatabase;
let db: Arc<MockedKVDatabase> = Arc::new(MockedKVDatabase);
let media = Service {
db: &DB,
db,
url_preview_mutex: RwLock::new(HashMap::new()),
};

View file

@ -1,298 +1,50 @@
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, Mutex as StdMutex},
};
pub(crate) mod key_value;
pub mod pdu;
pub mod services;
use lru_cache::LruCache;
use tokio::sync::{broadcast, Mutex, RwLock};
pub mod account_data;
pub mod admin;
pub mod appservice;
pub mod globals;
pub mod key_backups;
pub mod media;
pub mod presence;
pub mod pusher;
pub mod rooms;
pub mod sending;
pub mod transaction_ids;
pub mod uiaa;
pub mod users;
use crate::{Config, LogLevelReloadHandles, Result};
extern crate conduit_core as conduit;
extern crate conduit_database as database;
use std::sync::RwLock;
pub(crate) mod account_data;
pub(crate) mod admin;
pub(crate) mod appservice;
pub(crate) mod globals;
pub(crate) mod key_backups;
pub(crate) mod media;
pub(crate) mod pdu;
pub(crate) mod presence;
pub(crate) mod pusher;
pub(crate) mod rooms;
pub(crate) mod sending;
pub(crate) mod transaction_ids;
pub(crate) mod uiaa;
pub(crate) mod users;
pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, PduCount, Result};
pub(crate) use database::KeyValueDatabase;
pub use globals::{server_is_ours, user_is_local};
pub use pdu::PduEvent;
pub use services::Services;
pub(crate) struct Services<'a> {
pub(crate) appservice: appservice::Service,
pub(crate) pusher: pusher::Service,
pub(crate) rooms: rooms::Service,
pub(crate) transaction_ids: transaction_ids::Service,
pub(crate) uiaa: uiaa::Service,
pub(crate) users: users::Service,
pub(crate) account_data: account_data::Service,
pub(crate) presence: Arc<presence::Service>,
pub(crate) admin: Arc<admin::Service>,
pub(crate) globals: globals::Service<'a>,
pub(crate) key_backups: key_backups::Service,
pub(crate) media: media::Service,
pub(crate) sending: Arc<sending::Service>,
pub(crate) use crate as service;
conduit::mod_ctor! {}
conduit::mod_dtor! {}
pub static SERVICES: RwLock<Option<&'static Services>> = RwLock::new(None);
#[must_use]
pub fn services() -> &'static Services {
SERVICES
.read()
.expect("SERVICES locked for reading")
.expect("SERVICES initialized with Services instance")
}
impl Services<'_> {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
pub(crate) fn build<
D: appservice::Data
+ pusher::Data
+ rooms::Data
+ transaction_ids::Data
+ uiaa::Data
+ users::Data
+ account_data::Data
+ presence::Data
+ globals::Data
+ key_backups::Data
+ media::Data
+ sending::Data
+ 'static,
>(
db: &'static D, config: &Config, tracing_reload_handle: LogLevelReloadHandles,
) -> Result<Self> {
Ok(Self {
appservice: appservice::Service::build(db)?,
pusher: pusher::Service {
db,
},
rooms: rooms::Service {
alias: rooms::alias::Service {
db,
},
auth_chain: rooms::auth_chain::Service {
db,
},
directory: rooms::directory::Service {
db,
},
event_handler: rooms::event_handler::Service,
lazy_loading: rooms::lazy_loading::Service {
db,
lazy_load_waiting: Mutex::new(HashMap::new()),
},
metadata: rooms::metadata::Service {
db,
},
outlier: rooms::outlier::Service {
db,
},
pdu_metadata: rooms::pdu_metadata::Service {
db,
},
read_receipt: rooms::read_receipt::Service {
db,
},
search: rooms::search::Service {
db,
},
short: rooms::short::Service {
db,
},
state: rooms::state::Service {
db,
},
state_accessor: rooms::state_accessor::Service {
db,
server_visibility_cache: StdMutex::new(LruCache::new(
(f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier)
as usize,
)),
user_visibility_cache: StdMutex::new(LruCache::new(
(f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier)
as usize,
)),
},
state_cache: rooms::state_cache::Service {
db,
},
state_compressor: rooms::state_compressor::Service {
db,
stateinfo_cache: StdMutex::new(LruCache::new(
(f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
},
timeline: rooms::timeline::Service {
db,
lasttimelinecount_cache: Mutex::new(HashMap::new()),
},
threads: rooms::threads::Service {
db,
},
typing: rooms::typing::Service {
typing: RwLock::new(BTreeMap::new()),
last_typing_update: RwLock::new(BTreeMap::new()),
typing_update_sender: broadcast::channel(100).0,
},
spaces: rooms::spaces::Service {
roomid_spacehierarchy_cache: Mutex::new(LruCache::new(
(f64::from(config.roomid_spacehierarchy_cache_capacity)
* config.conduit_cache_capacity_modifier) as usize,
)),
},
user: rooms::user::Service {
db,
},
},
transaction_ids: transaction_ids::Service {
db,
},
uiaa: uiaa::Service {
db,
},
users: users::Service {
db,
connections: StdMutex::new(BTreeMap::new()),
},
account_data: account_data::Service {
db,
},
presence: presence::Service::build(db, config),
admin: admin::Service::build(),
key_backups: key_backups::Service {
db,
},
media: media::Service {
db,
url_preview_mutex: RwLock::new(HashMap::new()),
},
sending: sending::Service::build(db, config),
globals: globals::Service::load(db, config, tracing_reload_handle)?,
})
}
async fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let server_visibility_cache = self
.rooms
.state_accessor
.server_visibility_cache
.lock()
.unwrap()
.len();
let user_visibility_cache = self
.rooms
.state_accessor
.user_visibility_cache
.lock()
.unwrap()
.len();
let stateinfo_cache = self
.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.len();
let lasttimelinecount_cache = self
.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.len();
let roomid_spacehierarchy_cache = self
.rooms
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.len();
let resolver_overrides_cache = self.globals.resolver.overrides.read().unwrap().len();
let resolver_destinations_cache = self.globals.resolver.destinations.read().await.len();
let bad_event_ratelimiter = self.globals.bad_event_ratelimiter.read().await.len();
let bad_query_ratelimiter = self.globals.bad_query_ratelimiter.read().await.len();
let bad_signature_ratelimiter = self.globals.bad_signature_ratelimiter.read().await.len();
format!(
"\
lazy_load_waiting: {lazy_load_waiting}
server_visibility_cache: {server_visibility_cache}
user_visibility_cache: {user_visibility_cache}
stateinfo_cache: {stateinfo_cache}
lasttimelinecount_cache: {lasttimelinecount_cache}
roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}
resolver_overrides_cache: {resolver_overrides_cache}
resolver_destinations_cache: {resolver_destinations_cache}
bad_event_ratelimiter: {bad_event_ratelimiter}
bad_query_ratelimiter: {bad_query_ratelimiter}
bad_signature_ratelimiter: {bad_signature_ratelimiter}
"
)
}
async fn clear_caches(&self, amount: u32) {
if amount > 0 {
self.rooms
.lazy_loading
.lazy_load_waiting
.lock()
.await
.clear();
}
if amount > 1 {
self.rooms
.state_accessor
.server_visibility_cache
.lock()
.unwrap()
.clear();
}
if amount > 2 {
self.rooms
.state_accessor
.user_visibility_cache
.lock()
.unwrap()
.clear();
}
if amount > 3 {
self.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.clear();
}
if amount > 4 {
self.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.clear();
}
if amount > 5 {
self.rooms
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.clear();
}
if amount > 6 {
self.globals.resolver.overrides.write().unwrap().clear();
self.globals.resolver.destinations.write().await.clear();
}
if amount > 7 {
self.globals.resolver.resolver.clear_cache();
}
if amount > 8 {
self.globals.bad_event_ratelimiter.write().await.clear();
}
if amount > 9 {
self.globals.bad_query_ratelimiter.write().await.clear();
}
if amount > 10 {
self.globals.bad_signature_ratelimiter.write().await.clear();
}
}
#[inline]
pub fn available() -> bool {
SERVICES
.read()
.expect("SERVICES locked for reading")
.is_some()
}

View file

@ -23,40 +23,40 @@ use crate::{services, Error};
/// Content hashes of a PDU.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct EventHash {
pub struct EventHash {
/// The SHA-256 hash.
pub(crate) sha256: String,
pub sha256: String,
}
#[derive(Clone, Deserialize, Serialize, Debug)]
pub(crate) struct PduEvent {
pub(crate) event_id: Arc<EventId>,
pub(crate) room_id: OwnedRoomId,
pub(crate) sender: OwnedUserId,
pub struct PduEvent {
pub event_id: Arc<EventId>,
pub room_id: OwnedRoomId,
pub sender: OwnedUserId,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) origin: Option<String>,
pub(crate) origin_server_ts: UInt,
pub origin: Option<String>,
pub origin_server_ts: UInt,
#[serde(rename = "type")]
pub(crate) kind: TimelineEventType,
pub(crate) content: Box<RawJsonValue>,
pub kind: TimelineEventType,
pub content: Box<RawJsonValue>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) state_key: Option<String>,
pub(crate) prev_events: Vec<Arc<EventId>>,
pub(crate) depth: UInt,
pub(crate) auth_events: Vec<Arc<EventId>>,
pub state_key: Option<String>,
pub prev_events: Vec<Arc<EventId>>,
pub depth: UInt,
pub auth_events: Vec<Arc<EventId>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) redacts: Option<Arc<EventId>>,
pub redacts: Option<Arc<EventId>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub(crate) unsigned: Option<Box<RawJsonValue>>,
pub(crate) hashes: EventHash,
pub unsigned: Option<Box<RawJsonValue>>,
pub hashes: EventHash,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub(crate) signatures: Option<Box<RawJsonValue>>, /* BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId,
* String>> */
pub signatures: Option<Box<RawJsonValue>>, /* BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId,
* String>> */
}
impl PduEvent {
#[tracing::instrument(skip(self))]
pub(crate) fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> {
pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> {
self.unsigned = None;
let mut content = serde_json::from_str(self.content.get())
@ -76,7 +76,7 @@ impl PduEvent {
Ok(())
}
pub(crate) fn remove_transaction_id(&mut self) -> crate::Result<()> {
pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
if let Some(unsigned) = &self.unsigned {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get())
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
@ -87,7 +87,7 @@ impl PduEvent {
Ok(())
}
pub(crate) fn add_age(&mut self) -> crate::Result<()> {
pub fn add_age(&mut self) -> crate::Result<()> {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
.unsigned
.as_ref()
@ -118,7 +118,7 @@ impl PduEvent {
/// > serving
/// > such events over the Client-Server API.
#[must_use]
pub(crate) fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) {
pub fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) {
if self.kind == TimelineEventType::RoomRedaction {
if let Ok(mut content) = serde_json::from_str::<RoomRedactionEventContent>(self.content.get()) {
if let Some(redacts) = content.redacts {
@ -137,7 +137,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> {
pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> {
let (redacts, content) = self.copy_redacts();
let mut json = json!({
"content": content,
@ -162,7 +162,7 @@ impl PduEvent {
/// This only works for events that are also AnyRoomEvents.
#[tracing::instrument(skip(self))]
pub(crate) fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> {
pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> {
let (redacts, content) = self.copy_redacts();
let mut json = json!({
"content": content,
@ -187,7 +187,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_room_event(&self) -> Raw<AnyTimelineEvent> {
pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> {
let (redacts, content) = self.copy_redacts();
let mut json = json!({
"content": content,
@ -212,7 +212,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> {
pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> {
let (redacts, content) = self.copy_redacts();
let mut json = json!({
"content": content,
@ -237,7 +237,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_state_event(&self) -> Raw<AnyStateEvent> {
pub fn to_state_event(&self) -> Raw<AnyStateEvent> {
let mut json = json!({
"content": self.content,
"type": self.kind,
@ -256,7 +256,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> {
pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> {
let mut json = json!({
"content": self.content,
"type": self.kind,
@ -274,7 +274,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> {
pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> {
let json = json!({
"content": self.content,
"type": self.kind,
@ -286,7 +286,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> {
pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> {
let json = json!({
"content": self.content,
"type": self.kind,
@ -299,7 +299,7 @@ impl PduEvent {
}
#[tracing::instrument(skip(self))]
pub(crate) fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
let mut json = json!({
"content": self.content,
"type": self.kind,
@ -320,7 +320,7 @@ impl PduEvent {
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
#[tracing::instrument]
pub(crate) fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
@ -357,7 +357,7 @@ impl PduEvent {
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
}
pub(crate) fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> {
pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> {
json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
@ -405,7 +405,7 @@ impl Ord for PduEvent {
///
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
/// CanonicalJsonValue>`.
pub(crate) fn gen_event_id_canonical_json(
pub fn gen_event_id_canonical_json(
pdu: &RawJsonValue, room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
@ -426,11 +426,11 @@ pub(crate) fn gen_event_id_canonical_json(
/// Build the start of a PDU in order to add it to the Database.
#[derive(Debug, Deserialize)]
pub(crate) struct PduBuilder {
pub struct PduBuilder {
#[serde(rename = "type")]
pub(crate) event_type: TimelineEventType,
pub(crate) content: Box<RawJsonValue>,
pub(crate) unsigned: Option<BTreeMap<String, serde_json::Value>>,
pub(crate) state_key: Option<String>,
pub(crate) redacts: Option<Arc<EventId>>,
pub event_type: TimelineEventType,
pub content: Box<RawJsonValue>,
pub unsigned: Option<BTreeMap<String, serde_json::Value>>,
pub state_key: Option<String>,
pub redacts: Option<Arc<EventId>>,
}

View file

@ -2,7 +2,7 @@ use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
/// Returns the latest presence event for the given user.
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>>;

View file

@ -2,7 +2,7 @@ mod data;
use std::{sync::Arc, time::Duration};
pub(crate) use data::Data;
pub use data::Data;
use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{
events::presence::{PresenceEvent, PresenceEventContent},
@ -10,19 +10,19 @@ use ruma::{
OwnedUserId, UInt, UserId,
};
use serde::{Deserialize, Serialize};
use tokio::{sync::Mutex, time::sleep};
use tokio::{sync::Mutex, task::JoinHandle, time::sleep};
use tracing::{debug, error};
use crate::{
services,
utils::{self, user_id::user_is_local},
services, user_is_local,
utils::{self},
Config, Error, Result,
};
/// Represents data required to be kept in order to implement the presence
/// specification.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct Presence {
pub struct Presence {
state: PresenceState,
currently_active: bool,
last_active_ts: u64,
@ -30,9 +30,8 @@ pub(crate) struct Presence {
}
impl Presence {
pub(crate) fn new(
state: PresenceState, currently_active: bool, last_active_ts: u64, status_msg: Option<String>,
) -> Self {
#[must_use]
pub fn new(state: PresenceState, currently_active: bool, last_active_ts: u64, status_msg: Option<String>) -> Self {
Self {
state,
currently_active,
@ -41,21 +40,21 @@ impl Presence {
}
}
pub(crate) fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
pub fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
let presence = Self::from_json_bytes(bytes)?;
presence.to_presence_event(user_id)
}
pub(crate) fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
}
pub(crate) fn to_json_bytes(&self) -> Result<Vec<u8>> {
pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
}
/// Creates a PresenceEvent from available data.
pub(crate) fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> {
pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> {
let now = utils::millis_since_unix_epoch();
let last_active_ago = if self.currently_active {
None
@ -77,37 +76,55 @@ impl Presence {
}
}
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub(crate) timer_sender: loole::Sender<(OwnedUserId, Duration)>,
pub struct Service {
pub db: Arc<dyn Data>,
pub timer_sender: loole::Sender<(OwnedUserId, Duration)>,
timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>,
handler_join: Mutex<Option<JoinHandle<()>>>,
timeout_remote_users: bool,
}
impl Service {
pub(crate) fn build(db: &'static dyn Data, config: &Config) -> Arc<Self> {
pub fn build(db: Arc<dyn Data>, config: &Config) -> Arc<Self> {
let (timer_sender, timer_receiver) = loole::unbounded();
Arc::new(Self {
db,
timer_sender,
timer_receiver: Mutex::new(timer_receiver),
handler_join: Mutex::new(None),
timeout_remote_users: config.presence_timeout_remote_users,
})
}
pub(crate) fn start_handler(self: &Arc<Self>) {
pub async fn start_handler(self: &Arc<Self>) {
let self_ = Arc::clone(self);
tokio::spawn(async move {
let handle = services().server.runtime().spawn(async move {
self_
.handler()
.await
.expect("Failed to start presence handler");
});
_ = self.handler_join.lock().await.insert(handle);
}
pub async fn close(&self) {
self.interrupt();
if let Some(handler_join) = self.handler_join.lock().await.take() {
if let Err(e) = handler_join.await {
error!("Failed to shutdown: {e:?}");
}
}
}
pub fn interrupt(&self) {
if !self.timer_sender.is_closed() {
self.timer_sender.close();
}
}
/// Returns the latest presence event for the given user.
pub(crate) fn get_presence(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> {
pub fn get_presence(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> {
if let Some((_, presence)) = self.db.get_presence(user_id)? {
Ok(Some(presence))
} else {
@ -117,7 +134,7 @@ impl Service {
/// Pings the presence of the given user in the given room, setting the
/// specified state.
pub(crate) fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> {
pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> {
const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000;
let last_presence = self.db.get_presence(user_id)?;
@ -146,7 +163,7 @@ impl Service {
}
/// Adds a presence event which will be saved until a new event replaces it.
pub(crate) fn set_presence(
pub fn set_presence(
&self, user_id: &UserId, state: &PresenceState, currently_active: Option<bool>, last_active_ago: Option<UInt>,
status_msg: Option<String>,
) -> Result<()> {
@ -179,11 +196,11 @@ impl Service {
///
/// TODO: Why is this not used?
#[allow(dead_code)]
pub(crate) fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) }
pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) }
/// Returns the most recent presence updates that happened after the event
/// with id `since`.
pub(crate) fn presence_since(&self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)>> {
pub fn presence_since(&self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + '_> {
self.db.presence_since(since)
}
@ -191,24 +208,16 @@ impl Service {
let mut presence_timers = FuturesUnordered::new();
let receiver = self.timer_receiver.lock().await;
loop {
debug_assert!(!receiver.is_closed(), "channel error");
tokio::select! {
event = receiver.recv_async() => {
match event {
Ok((user_id, timeout)) => {
debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len());
presence_timers.push(presence_timer(user_id, timeout));
}
Err(e) => {
// generally shouldn't happen
error!("Failed to receive presence timer through channel: {e}");
}
}
}
Some(user_id) = presence_timers.next() => {
process_presence_timer(&user_id)?;
}
Some(user_id) = presence_timers.next() => process_presence_timer(&user_id)?,
event = receiver.recv_async() => match event {
Err(_e) => return Ok(()),
Ok((user_id, timeout)) => {
debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len());
presence_timers.push(presence_timer(user_id, timeout));
},
},
}
}
}

View file

@ -5,7 +5,7 @@ use ruma::{
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;

View file

@ -1,8 +1,8 @@
mod data;
use std::{fmt::Debug, mem};
use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut;
pub(crate) use data::Data;
pub use data::Data;
use ipaddress::IPAddress;
use ruma::{
api::{
@ -24,27 +24,28 @@ use tracing::{info, trace, warn};
use crate::{debug_info, services, Error, PduEvent, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
pub(crate) fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
self.db.set_pusher(sender, pusher)
}
pub(crate) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
self.db.get_pusher(sender, pushkey)
}
pub(crate) fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }
pub(crate) fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
#[must_use]
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + '_> {
self.db.get_pushkeys(sender)
}
#[tracing::instrument(skip(self, dest, request))]
pub(crate) async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::IncomingResponse>
pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
@ -131,7 +132,7 @@ impl Service {
}
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
pub(crate) async fn send_push_notice(
pub async fn send_push_notice(
&self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
) -> Result<()> {
let mut notify = None;
@ -176,7 +177,7 @@ impl Service {
}
#[tracing::instrument(skip(self, user, ruleset, pdu))]
pub(crate) fn get_actions<'a>(
pub fn get_actions<'a>(
&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
) -> Result<&'a [Action]> {

View file

@ -2,7 +2,7 @@ use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
/// Creates or updates the alias to the given room id.
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;

View file

@ -1,37 +1,37 @@
mod data;
pub(crate) use data::Data;
use std::sync::Arc;
pub use data::Data;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::Result;
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
self.db.set_alias(alias, room_id)
}
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) }
#[tracing::instrument(skip(self))]
pub(crate) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) }
pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) }
#[tracing::instrument(skip(self))]
pub(crate) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.db.resolve_local_alias(alias)
}
#[tracing::instrument(skip(self))]
pub(crate) fn local_aliases_for_room<'a>(
pub fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
self.db.local_aliases_for_room(room_id)
}
#[tracing::instrument(skip(self))]
pub(crate) fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
self.db.all_local_aliases()
}
}

View file

@ -2,7 +2,7 @@ use std::sync::Arc;
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<[u64]>>>;
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()>;
}

View file

@ -4,18 +4,18 @@ use std::{
sync::Arc,
};
pub(crate) use data::Data;
pub use data::Data;
use ruma::{api::client::error::ErrorKind, EventId, RoomId};
use tracing::{debug, error, trace, warn};
use crate::{services, utils::debug_slice_truncated, Error, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
pub(crate) async fn event_ids_iter<'a>(
pub async fn event_ids_iter<'a>(
&self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
@ -31,7 +31,7 @@ impl Service {
}
#[tracing::instrument(skip(self), fields(starting_events = debug_slice_truncated(starting_events, 5)))]
pub(crate) async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<u64>> {
pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<u64>> {
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new();
@ -172,18 +172,18 @@ impl Service {
Ok(found)
}
pub(crate) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
self.db.get_cached_eventid_authchain(key)
}
#[tracing::instrument(skip(self))]
pub(crate) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> {
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> {
self.db
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
}
#[tracing::instrument(skip(self))]
pub(crate) fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> {
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> {
self.db
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
}

View file

@ -2,7 +2,7 @@ use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
/// Adds the room to the public room directory
fn set_public(&self, room_id: &RoomId) -> Result<()>;

View file

@ -1,24 +1,26 @@
mod data;
pub(crate) use data::Data;
use std::sync::Arc;
pub use data::Data;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
#[tracing::instrument(skip(self))]
pub(crate) fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
#[tracing::instrument(skip(self))]
pub(crate) fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
#[tracing::instrument(skip(self))]
pub(crate) fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
}

View file

@ -1,11 +1,17 @@
mod parse_incoming_pdu;
mod signing_keys;
pub struct Service;
use std::{
cmp,
collections::{hash_map, HashSet},
collections::{hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::Arc,
time::{Duration, Instant},
};
use futures_util::Future;
pub use parse_incoming_pdu::parse_incoming_pdu;
use ruma::{
api::{
client::error::ErrorKind,
@ -24,14 +30,7 @@ use tokio::sync::RwLock;
use tracing::{debug, error, info, trace, warn};
use super::state_compressor::CompressedStateEvent;
use crate::{
debug_error, debug_info,
service::{pdu, Arc, BTreeMap, HashMap, Result},
services, Error, PduEvent,
};
mod signing_keys;
pub(crate) struct Service;
use crate::{debug_error, debug_info, pdu, services, Error, PduEvent, Result};
// We use some AsyncRecursiveType hacks here so we can call async funtion
// recursively.
@ -70,7 +69,7 @@ impl Service {
/// 14. Check if the event passes auth based on the "current state" of the
/// room, if not soft fail it
#[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")]
pub(crate) async fn handle_incoming_pdu<'a>(
pub async fn handle_incoming_pdu<'a>(
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
@ -207,7 +206,7 @@ impl Service {
skip(self, origin, event_id, room_id, pub_key_map, eventid_info, create_event, first_pdu_in_room),
name = "prev"
)]
pub(crate) async fn handle_prev_pdu<'a>(
pub async fn handle_prev_pdu<'a>(
&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
@ -427,7 +426,7 @@ impl Service {
})
}
pub(crate) async fn upgrade_outlier_to_timeline_pdu(
pub async fn upgrade_outlier_to_timeline_pdu(
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
@ -748,7 +747,7 @@ impl Service {
// TODO: if we know the prev_events of the incoming event we can avoid the
// request and build the state from a known point and resolve if > 1 prev_event
#[tracing::instrument(skip_all, name = "state")]
pub(crate) async fn state_at_incoming_degree_one(
pub async fn state_at_incoming_degree_one(
&self, incoming_pdu: &Arc<PduEvent>,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
let prev_event = &*incoming_pdu.prev_events[0];
@ -796,7 +795,7 @@ impl Service {
}
#[tracing::instrument(skip_all, name = "state")]
pub(crate) async fn state_at_incoming_resolved(
pub async fn state_at_incoming_resolved(
&self, incoming_pdu: &Arc<PduEvent>, room_id: &RoomId, room_version_id: &RoomVersionId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Calculating state at event using state res");
@ -988,7 +987,7 @@ impl Service {
/// b. Look at outlier pdu tree
/// c. Ask origin server over federation
/// d. TODO: Ask other servers over federation?
pub(crate) fn fetch_and_handle_outliers<'a>(
pub fn fetch_and_handle_outliers<'a>(
&'a self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId,
room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveCanonicalJsonVec<'a> {
@ -1275,7 +1274,7 @@ impl Service {
/// Returns Ok if the acl allows the server
#[tracing::instrument(skip_all)]
pub(crate) fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
let acl_event = if let Some(acl) =
services()
.rooms

View file

@ -0,0 +1,31 @@
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId};
use serde_json::value::RawValue as RawJsonValue;
use tracing::warn;
use crate::{service::pdu::gen_event_id_canonical_json, services, Error, Result};
pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let room_id: OwnedRoomId = value
.get("room_id")
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?;
let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else {
return Err(Error::Err(format!("Server is not in room {room_id}")));
};
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Could not convert event to canonical json.",
));
};
Ok((event_id, value, room_id))
}

View file

@ -1,5 +1,5 @@
use std::{
collections::HashSet,
collections::{BTreeMap, HashMap, HashSet},
time::{Duration, SystemTime},
};
@ -21,13 +21,10 @@ use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard};
use tracing::{debug, error, info, trace, warn};
use crate::{
service::{BTreeMap, HashMap, Result},
services, Error,
};
use crate::{services, Error, Result};
impl super::Service {
pub(crate) async fn fetch_required_signing_keys<'a, E>(
pub async fn fetch_required_signing_keys<'a, E>(
&'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()>
where
@ -265,7 +262,7 @@ impl super::Service {
Ok(())
}
pub(crate) async fn fetch_join_signing_keys(
pub async fn fetch_join_signing_keys(
&self, event: &create_join_event::v2::Response, room_version: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
@ -342,7 +339,7 @@ impl super::Service {
/// Search the DB for the signing keys of the given server, if we don't have
/// them fetch them from the server and save to our DB.
#[tracing::instrument(skip_all)]
pub(crate) async fn fetch_signing_keys_for_server(
pub async fn fetch_signing_keys_for_server(
&self, origin: &ServerName, signature_ids: Vec<String>,
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));

View file

@ -2,7 +2,7 @@ use ruma::{DeviceId, RoomId, UserId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool>;

View file

@ -1,24 +1,25 @@
mod data;
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
pub(crate) use data::Data;
pub use data::Data;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use super::timeline::PduCount;
use crate::Result;
use crate::{PduCount, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
#[allow(clippy::type_complexity)]
pub(crate) lazy_load_waiting:
Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
pub lazy_load_waiting: Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
}
impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn lazy_load_was_sent_before(
pub fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
self.db
@ -26,7 +27,7 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub(crate) async fn lazy_load_mark_sent(
pub async fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
@ -37,7 +38,7 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub(crate) async fn lazy_load_confirm_delivery(
pub async fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount,
) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&(
@ -56,7 +57,7 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub(crate) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
self.db.lazy_load_reset(user_id, device_id, room_id)
}
}

View file

@ -2,7 +2,7 @@ use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn exists(&self, room_id: &RoomId) -> Result<bool>;
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
fn is_disabled(&self, room_id: &RoomId) -> Result<bool>;

View file

@ -1,32 +1,36 @@
mod data;
pub(crate) use data::Data;
use std::sync::Arc;
pub use data::Data;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
/// Checks if a room exists.
#[tracing::instrument(skip(self))]
pub(crate) fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) }
pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) }
pub(crate) fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { self.db.iter_ids() }
#[must_use]
pub fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { self.db.iter_ids() }
pub(crate) fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { self.db.is_disabled(room_id) }
pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { self.db.is_disabled(room_id) }
pub(crate) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
self.db.disable_room(room_id, disabled)
}
pub(crate) fn is_banned(&self, room_id: &RoomId) -> Result<bool> { self.db.is_banned(room_id) }
pub fn is_banned(&self, room_id: &RoomId) -> Result<bool> { self.db.is_banned(room_id) }
pub(crate) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) }
pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) }
pub(crate) fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
#[must_use]
pub fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
self.db.list_banned_rooms()
}
}

View file

@ -1,25 +1,25 @@
pub(crate) mod alias;
pub(crate) mod auth_chain;
pub(crate) mod directory;
pub(crate) mod event_handler;
pub(crate) mod lazy_loading;
pub(crate) mod metadata;
pub(crate) mod outlier;
pub(crate) mod pdu_metadata;
pub(crate) mod read_receipt;
pub(crate) mod search;
pub(crate) mod short;
pub(crate) mod spaces;
pub(crate) mod state;
pub(crate) mod state_accessor;
pub(crate) mod state_cache;
pub(crate) mod state_compressor;
pub(crate) mod threads;
pub(crate) mod timeline;
pub(crate) mod typing;
pub(crate) mod user;
pub mod alias;
pub mod auth_chain;
pub mod directory;
pub mod event_handler;
pub mod lazy_loading;
pub mod metadata;
pub mod outlier;
pub mod pdu_metadata;
pub mod read_receipt;
pub mod search;
pub mod short;
pub mod spaces;
pub mod state;
pub mod state_accessor;
pub mod state_cache;
pub mod state_compressor;
pub mod threads;
pub mod timeline;
pub mod typing;
pub mod user;
pub(crate) trait Data:
pub trait Data:
alias::Data
+ auth_chain::Data
+ directory::Data
@ -40,25 +40,25 @@ pub(crate) trait Data:
{
}
pub(crate) struct Service {
pub(crate) alias: alias::Service,
pub(crate) auth_chain: auth_chain::Service,
pub(crate) directory: directory::Service,
pub(crate) event_handler: event_handler::Service,
pub(crate) lazy_loading: lazy_loading::Service,
pub(crate) metadata: metadata::Service,
pub(crate) outlier: outlier::Service,
pub(crate) pdu_metadata: pdu_metadata::Service,
pub(crate) read_receipt: read_receipt::Service,
pub(crate) search: search::Service,
pub(crate) short: short::Service,
pub(crate) state: state::Service,
pub(crate) state_accessor: state_accessor::Service,
pub(crate) state_cache: state_cache::Service,
pub(crate) state_compressor: state_compressor::Service,
pub(crate) timeline: timeline::Service,
pub(crate) threads: threads::Service,
pub(crate) typing: typing::Service,
pub(crate) spaces: spaces::Service,
pub(crate) user: user::Service,
pub struct Service {
pub alias: alias::Service,
pub auth_chain: auth_chain::Service,
pub directory: directory::Service,
pub event_handler: event_handler::Service,
pub lazy_loading: lazy_loading::Service,
pub metadata: metadata::Service,
pub outlier: outlier::Service,
pub pdu_metadata: pdu_metadata::Service,
pub read_receipt: read_receipt::Service,
pub search: search::Service,
pub short: short::Service,
pub state: state::Service,
pub state_accessor: state_accessor::Service,
pub state_cache: state_cache::Service,
pub state_compressor: state_compressor::Service,
pub timeline: timeline::Service,
pub threads: threads::Service,
pub typing: typing::Service,
pub spaces: spaces::Service,
pub user: user::Service,
}

View file

@ -2,7 +2,7 @@ use ruma::{CanonicalJsonObject, EventId};
use crate::{PduEvent, Result};
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>;

View file

@ -1,17 +1,19 @@
mod data;
pub(crate) use data::Data;
use std::sync::Arc;
pub use data::Data;
use ruma::{CanonicalJsonObject, EventId};
use crate::{PduEvent, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
impl Service {
/// Returns the pdu from the outlier tree.
pub(crate) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.db.get_outlier_pdu_json(event_id)
}
@ -19,13 +21,11 @@ impl Service {
///
/// TODO: use this?
#[allow(dead_code)]
pub(crate) fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.db.get_outlier_pdu(event_id)
}
pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { self.db.get_outlier_pdu(event_id) }
/// Append the PDU as an outlier.
#[tracing::instrument(skip(self, pdu))]
pub(crate) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.db.add_pdu_outlier(event_id, pdu)
}
}

View file

@ -2,9 +2,9 @@ use std::sync::Arc;
use ruma::{EventId, RoomId, UserId};
use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
use crate::{PduCount, PduEvent, Result};
pub(crate) trait Data: Send + Sync {
pub trait Data: Send + Sync {
fn add_relation(&self, from: u64, to: u64) -> Result<()>;
#[allow(clippy::type_complexity)]
fn relations_until<'a>(

View file

@ -1,7 +1,8 @@
mod data;
use std::sync::Arc;
pub(crate) use data::Data;
pub use data::Data;
use ruma::{
api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType},
@ -9,11 +10,10 @@ use ruma::{
};
use serde::Deserialize;
use super::timeline::PduCount;
use crate::{services, PduEvent, Result};
use crate::{services, PduCount, PduEvent, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub struct Service {
pub db: Arc<dyn Data>,
}
#[derive(Clone, Debug, Deserialize)]
@ -28,7 +28,7 @@ struct ExtractRelatesToEventId {
impl Service {
#[tracing::instrument(skip(self, from, to))]
pub(crate) fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
_ => {
@ -40,7 +40,7 @@ impl Service {
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn paginate_relations_with_filter(
pub fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option<TimelineEventType>,
filter_rel_type: &Option<RelationType>, from: &Option<String>, to: &Option<String>, limit: &Option<UInt>,
recurse: bool, dir: Direction,
@ -174,7 +174,7 @@ impl Service {
}
}
pub(crate) fn relations_until<'a>(
pub fn relations_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?;
@ -216,22 +216,18 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, event_ids))]
pub(crate) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
self.db.mark_as_referenced(room_id, event_ids)
}
#[tracing::instrument(skip(self))]
pub(crate) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
self.db.is_event_referenced(room_id, event_id)
}
#[tracing::instrument(skip(self))]
pub(crate) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.db.mark_event_soft_failed(event_id)
}
pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) }
#[tracing::instrument(skip(self))]
pub(crate) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.db.is_event_soft_failed(event_id)
}
pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { self.db.is_event_soft_failed(event_id) }
}

Some files were not shown because too many files have changed in this diff Show more