continuwuity/src/service/users/mod.rs
Jason Volk a0b28aa602 Database Refactor
combine service/users data w/ mod unit

split sliding sync related out of service/users

instrument database entry points

remove increment crap from database interface

de-wrap all database get() calls

de-wrap all database insert() calls

de-wrap all database remove() calls

refactor database interface for async streaming

add query key serializer for database

implement Debug for result handle

add query deserializer for database

add deserialization trait for option handle

start a stream utils suite

de-wrap/asyncify/type-query count_one_time_keys()

de-wrap/asyncify users count

add admin query users command suite

de-wrap/asyncify users exists

de-wrap/partially asyncify user filter related

asyncify/de-wrap users device/keys related

asyncify/de-wrap user auth/misc related

asyncify/de-wrap users blurhash

asyncify/de-wrap account_data get; merge Data into Service

partial asyncify/de-wrap uiaa; merge Data into Service

partially asyncify/de-wrap transaction_ids get; merge Data into Service

partially asyncify/de-wrap key_backups; merge Data into Service

asyncify/de-wrap pusher service getters; merge Data into Service

asyncify/de-wrap rooms alias getters/some iterators

asyncify/de-wrap rooms directory getters/iterator

partially asyncify/de-wrap rooms lazy-loading

partially asyncify/de-wrap rooms metadata

asyncify/dewrap rooms outlier

asyncify/dewrap rooms pdu_metadata

dewrap/partially asyncify rooms read receipt

de-wrap rooms search service

de-wrap/partially asyncify rooms user service

partial de-wrap rooms state_compressor

de-wrap rooms state_cache

de-wrap room state et al

de-wrap rooms timeline service

additional users device/keys related

de-wrap/asyncify sender

asyncify services

refactor database to TryFuture/TryStream

refactor services for TryFuture/TryStream

asyncify api handlers

additional asyncification for admin module

abstract stream related; support reverse streams

additional stream conversions

asyncify state-res related

Signed-off-by: Jason Volk <jason@zemos.net>
2024-10-21 22:07:37 +00:00

1006 lines
32 KiB
Rust

use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc};
use conduit::{
debug_warn, err, utils,
utils::{stream::TryIgnore, string::Unquoted, ReadyExt, TryReadyExt},
warn, Err, Error, Result, Server,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use crate::{admin, globals, rooms, Dep};
/// User service: accounts, passwords, devices, access tokens, encryption
/// keys, to-device events, filters, OpenID tokens and profile data.
pub struct Service {
// Handles to the other services and the server this service calls into.
services: Services,
// Database maps read and written by this service.
db: Data,
}
/// Dependencies of the user service.
struct Services {
// Read for `config.openid_token_ttl` when creating OpenID tokens.
server: Arc<Server>,
// Used for `user_is_admin`.
admin: Dep<admin::Service>,
// Used for `user_is_local` and `next_count`.
globals: Dep<globals::Service>,
// Used for `room_state_get` when checking whether a room is encrypted.
state_accessor: Dep<rooms::state_accessor::Service>,
// Used for `rooms_joined` when recording key changes.
state_cache: Dep<rooms::state_cache::Service>,
}
/// Database maps used by the user service.
///
/// The layouts noted below are the ones written by the methods of this
/// module; composite keys join their parts with a `0xFF` byte.
struct Data {
// `<user or room id> 0xFF <count>` -> id of the user whose keys changed.
keychangeid_userid: Arc<Map>,
// `<user id> 0xFF <device id or cross-signing key id>` -> key JSON.
keyid_key: Arc<Map>,
// `<user id> 0xFF <device id> 0xFF <JSON-quoted key id>` -> one-time key JSON.
onetimekeyid_onetimekeys: Arc<Map>,
// OpenID token -> big-endian u64 expiry (ms) followed by the user id bytes.
openidtoken_expiresatuserid: Arc<Map>,
// `<user id> 0xFF <device id> 0xFF <count>` -> to-device event JSON.
todeviceid_events: Arc<Map>,
// Access token -> `<user id> 0xFF <device id>`.
token_userdeviceid: Arc<Map>,
// `<user id> 0xFF <device id>` -> `Device` JSON.
userdeviceid_metadata: Arc<Map>,
// `<user id> 0xFF <device id>` -> access token.
userdeviceid_token: Arc<Map>,
// `<user id> 0xFF <filter id>` -> `FilterDefinition` JSON.
userfilterid_filter: Arc<Map>,
// user id -> avatar URL.
userid_avatarurl: Arc<Map>,
// user id -> blurhash.
userid_blurhash: Arc<Map>,
// user id -> counter bumped by `increment` on device changes.
userid_devicelistversion: Arc<Map>,
// user id -> display name.
userid_displayname: Arc<Map>,
// user id -> big-endian count of the last one-time key add/take.
userid_lastonetimekeyupdate: Arc<Map>,
// user id -> key into `keyid_key` for the master key.
userid_masterkeyid: Arc<Map>,
// user id -> password hash; an empty value marks a deactivated account.
userid_password: Arc<Map>,
// user id -> key into `keyid_key` for the self-signing key.
userid_selfsigningkeyid: Arc<Map>,
// user id -> key into `keyid_key` for the user-signing key.
userid_usersigningkeyid: Arc<Map>,
// `<user id> 0xFF <profile key>` -> profile value.
useridprofilekey_value: Arc<Map>,
}
impl crate::Service for Service {
    /// Assembles the service: resolves the dependent services by name and
    /// clones a handle to each database map this module works with.
    fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
        let services = Services {
            server: args.server.clone(),
            admin: args.depend::<admin::Service>("admin"),
            globals: args.depend::<globals::Service>("globals"),
            state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
            state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
        };

        let db = Data {
            keychangeid_userid: args.db["keychangeid_userid"].clone(),
            keyid_key: args.db["keyid_key"].clone(),
            onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(),
            openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(),
            todeviceid_events: args.db["todeviceid_events"].clone(),
            token_userdeviceid: args.db["token_userdeviceid"].clone(),
            userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(),
            userdeviceid_token: args.db["userdeviceid_token"].clone(),
            userfilterid_filter: args.db["userfilterid_filter"].clone(),
            userid_avatarurl: args.db["userid_avatarurl"].clone(),
            userid_blurhash: args.db["userid_blurhash"].clone(),
            userid_devicelistversion: args.db["userid_devicelistversion"].clone(),
            userid_displayname: args.db["userid_displayname"].clone(),
            userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(),
            userid_masterkeyid: args.db["userid_masterkeyid"].clone(),
            userid_password: args.db["userid_password"].clone(),
            userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
            userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
            useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
        };

        Ok(Arc::new(Self {
            services,
            db,
        }))
    }

    /// Name of this service, derived from the module path.
    fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Check if a user is an admin.
///
/// Forwards the question to the admin service's `user_is_admin`.
#[inline]
pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await }
/// Create a new user account on this homeserver.
///
/// The account is created solely by storing its password entry through
/// `set_password`; the result of that call is returned unchanged.
#[inline]
pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
self.set_password(user_id, password)
}
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
// Remove all associated devices
self.all_device_ids(user_id)
.for_each(|device_id| self.remove_device(user_id, device_id))
.await;
// Set the password to "" to indicate a deactivated account. Hashes will never
// result in an empty string, so the user will not be able to log in again.
// Systems like changing the password without logging in should check if the
// account is deactivated.
self.set_password(user_id, None)?;
// TODO: Unhook 3PID
Ok(())
}
/// Check if a user has an account on this homeserver.
///
/// True when the `userid_password` lookup for `user_id` succeeds; any lookup
/// error is reported as `false`.
#[inline]
pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.qry(user_id).await.is_ok() }
/// Check if account is deactivated.
///
/// An account counts as deactivated when its stored password value is
/// empty. A failed lookup yields a `NotFound` request error.
pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
    match self.db.userid_password.qry(user_id).await {
        Ok(password) => Ok(password.is_empty()),
        Err(_) => Err(err!(Request(NotFound("User does not exist.")))),
    }
}
/// Check if account is active, infallible.
///
/// Only an account that is found and not deactivated is active; a failed
/// lookup counts as inactive.
pub async fn is_active(&self, user_id: &UserId) -> bool {
    matches!(self.is_deactivated(user_id).await, Ok(false))
}
/// Check if the user passes `globals.user_is_local` and its account is
/// active, infallible. The account lookup is skipped for non-local users.
pub async fn is_active_local(&self, user_id: &UserId) -> bool {
    if !self.services.globals.user_is_local(user_id) {
        return false;
    }

    self.is_active(user_id).await
}
/// Returns the number of users registered on this server, as reported by
/// `count()` on the `userid_password` map.
#[inline]
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
/// Find out which user and device an access token belongs to.
///
/// Looks the token up in `token_userdeviceid` and deserializes the stored
/// value into `(user id, device id)`. A failed lookup or deserialization is
/// returned as the error.
pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
self.db.token_userdeviceid.qry(token).await.deserialized()
}
/// Returns a stream of owned user ids for all users on this homeserver
/// (offered for compatibility). Each item of `stream()` is converted with
/// `to_owned`.
// Named `iter` but returns a `Stream`, hence the clippy allowances below.
#[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)]
pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ { self.stream().map(ToOwned::to_owned) }
/// Returns a stream over the ids of all users on this homeserver.
///
/// Items are the keys of `userid_password`; errors are routed through
/// `ignore_err` and never reach the caller.
pub fn stream(&self) -> impl Stream<Item = &UserId> + Send { self.db.userid_password.keys().ignore_err() }
/// Returns a stream of local users' ids.
///
/// A user account is considered `local` if the length of its password is
/// greater than zero.
pub fn list_local_users(&self) -> impl Stream<Item = &UserId> + Send + '_ {
    self.db
        .userid_password
        .stream()
        .ignore_err()
        .ready_filter_map(|(user_id, password): (&UserId, &[u8])| {
            if password.is_empty() {
                None
            } else {
                Some(user_id)
            }
        })
}
/// Returns the password hash for the given user.
///
/// This is the value stored in `userid_password`, deserialized as a
/// `String`. NOTE(review): `set_password(None)` stores an empty value, so a
/// deactivated account presumably yields an empty string here — confirm.
pub async fn password_hash(&self, user_id: &UserId) -> Result<String> {
self.db.userid_password.qry(user_id).await.deserialized()
}
/// Hash and set the user's password to the Argon2 hash
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = utils::hash::password(password) {
self.db
.userid_password
.insert(user_id.as_bytes(), hash.as_bytes());
Ok(())
} else {
Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Password does not meet the requirements.",
))
}
} else {
self.db.userid_password.insert(user_id.as_bytes(), b"");
Ok(())
}
}
/// Returns the displayname of a user on this homeserver.
///
/// Reads `userid_displayname`; a failed lookup or deserialization is
/// returned as the error.
pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
self.db.userid_displayname.qry(user_id).await.deserialized()
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to notify all rooms of this change.
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
    match displayname {
        Some(name) => {
            self.db
                .userid_displayname
                .insert(user_id.as_bytes(), name.as_bytes());
        },
        None => {
            self.db.userid_displayname.remove(user_id.as_bytes());
        },
    }
}
/// Get the `avatar_url` of a user.
///
/// Reads `userid_avatarurl`; a failed lookup or deserialization is returned
/// as the error.
pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
self.db.userid_avatarurl.qry(user_id).await.deserialized()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
    match avatar_url {
        Some(url) => {
            self.db
                .userid_avatarurl
                .insert(user_id.as_bytes(), url.to_string().as_bytes());
        },
        None => {
            self.db.userid_avatarurl.remove(user_id.as_bytes());
        },
    }
}
/// Get the blurhash of a user.
///
/// Reads `userid_blurhash`; a failed lookup or deserialization is returned
/// as the error.
pub async fn blurhash(&self, user_id: &UserId) -> Result<String> {
self.db.userid_blurhash.qry(user_id).await.deserialized()
}
/// Sets a new blurhash or removes it if blurhash is None.
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) {
    match blurhash {
        Some(hash) => {
            self.db
                .userid_blurhash
                .insert(user_id.as_bytes(), hash.as_bytes());
        },
        None => {
            self.db.userid_blurhash.remove(user_id.as_bytes());
        },
    }
}
/// Adds a new device to a user.
///
/// Bumps the user's device list version, stores the device metadata (with
/// the current time as `last_seen_ts`) and assigns the access token.
///
/// # Errors
///
/// `InvalidParam` when the user does not exist; otherwise whatever
/// `set_token` returns.
pub async fn create_device(
    &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
    client_ip: Option<String>,
) -> Result<()> {
    const SEP: &[u8] = &[0xFF];

    // This method should never be called for nonexistent users. We shouldn't assert
    // though...
    let user_exists = self.exists(user_id).await;
    if !user_exists {
        warn!("Called create_device for non-existent user {} in database", user_id);
        return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
    }

    let userdeviceid = [user_id.as_bytes(), SEP, device_id.as_bytes()].concat();

    increment(&self.db.userid_devicelistversion, user_id.as_bytes());

    let device = Device {
        device_id: device_id.into(),
        display_name: initial_device_display_name,
        last_seen_ip: client_ip,
        last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
    };
    let metadata = serde_json::to_vec(&device).expect("Device::to_string never fails.");
    self.db
        .userdeviceid_metadata
        .insert(&userdeviceid, &metadata);

    self.set_token(user_id, device_id, token).await
}
/// Removes a device from a user.
///
/// Drops the device's access token (both directions of the mapping), its
/// queued to-device events and its metadata, and bumps the user's device
/// list version.
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
    const SEP: &[u8] = &[0xFF];
    let userdeviceid = [user_id.as_bytes(), SEP, device_id.as_bytes()].concat();

    // Remove tokens
    if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
        self.db.userdeviceid_token.remove(&userdeviceid);
        self.db.token_userdeviceid.remove(&old_token);
    }

    // Remove todevice events
    let todevice_prefix = (user_id, device_id, Interfix);
    self.db
        .todeviceid_events
        .keys_raw_prefix(&todevice_prefix)
        .ignore_err()
        .ready_for_each(|event_key| self.db.todeviceid_events.remove(event_key))
        .await;

    // TODO: Remove onetimekeys

    increment(&self.db.userid_devicelistversion, user_id.as_bytes());

    self.db.userdeviceid_metadata.remove(&userdeviceid);
}
/// Returns a stream over all device ids of this user.
///
/// Scans the keys of `userdeviceid_metadata` under the `(user_id, Interfix)`
/// prefix and keeps only the device-id component of each key.
pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &DeviceId> + Send + 'a {
let prefix = (user_id, Interfix);
self.db
.userdeviceid_metadata
.keys_prefix(&prefix)
.ignore_err()
// The user-id part of the key is discarded via `Ignore`.
.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
}
/// Replaces the access token of one device.
///
/// The previous token, if any, stops resolving to the device; the new token
/// is stored in both `userdeviceid_token` and `token_userdeviceid`.
///
/// # Errors
///
/// A `Database` error when the device has no metadata entry.
pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
    const SEP: &[u8] = &[0xFF];
    let query = (user_id, device_id);

    // should not be None, but we shouldn't assert either lol...
    let has_metadata = self.db.userdeviceid_metadata.qry(&query).await.is_ok();
    if !has_metadata {
        return Err!(Database(error!(
            ?user_id,
            ?device_id,
            "User does not exist or device has no metadata."
        )));
    }

    // Remove old token; its `userdeviceid_token` entry is overwritten by the
    // insert below.
    if let Ok(previous) = self.db.userdeviceid_token.qry(&query).await {
        self.db.token_userdeviceid.remove(&previous);
    }

    // Assign token to user device combination
    let userdeviceid = [user_id.as_bytes(), SEP, device_id.as_bytes()].concat();
    self.db
        .userdeviceid_token
        .insert(&userdeviceid, token.as_bytes());
    self.db
        .token_userdeviceid
        .insert(token.as_bytes(), &userdeviceid);

    Ok(())
}
/// Stores a one-time key for a device and records the update count for the
/// user in `userid_lastonetimekeyupdate`.
///
/// # Errors
///
/// A `Database` error when the device has no metadata entry, or the error
/// of `globals.next_count()`.
pub async fn add_one_time_key(
    &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
    one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> {
    const SEP: &[u8] = &[0xFF];

    // All devices have metadata
    // Only existing devices should be able to call this, but we shouldn't assert
    // either...
    let device = (user_id, device_id);
    if self.db.userdeviceid_metadata.qry(&device).await.is_err() {
        return Err!(Database(error!(
            ?user_id,
            ?device_id,
            "User does not exist or device has no metadata."
        )));
    }

    // TODO: Use DeviceKeyId::to_string when it's available (and update everything,
    // because there are no wrapping quotation marks anymore)
    let key_id_json = serde_json::to_string(one_time_key_key).expect("DeviceKeyId::to_string always works");
    let key = [
        user_id.as_bytes(),
        SEP,
        device_id.as_bytes(),
        SEP,
        key_id_json.as_bytes(),
    ]
    .concat();

    let value = serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works");
    self.db.onetimekeyid_onetimekeys.insert(&key, &value);

    let count = self.services.globals.next_count()?;
    self.db
        .userid_lastonetimekeyupdate
        .insert(user_id.as_bytes(), &count.to_be_bytes());

    Ok(())
}
/// Returns the value stored for the user in `userid_lastonetimekeyupdate`,
/// or 0 when the lookup or deserialization fails.
pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
self.db
.userid_lastonetimekeyupdate
.qry(user_id)
.await
.deserialized()
.unwrap_or(0)
}
/// Removes and returns one stored one-time key of the given algorithm for a
/// device.
///
/// The user's `userid_lastonetimekeyupdate` entry is set to a fresh count
/// before the lookup, whether or not a key is found.
///
/// # Errors
///
/// - `Request(NotFound)` when no key with the requested algorithm is stored.
/// - `Database` when the stored key id or key value cannot be parsed. This
///   is returned to the caller rather than panicking; the malformed entry
///   has already been removed from the map at that point.
/// - The error of `globals.next_count()`.
pub async fn take_one_time_key(
    &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<(OwnedDeviceKeyId, Raw<OneTimeKey>)> {
    self.db
        .userid_lastonetimekeyupdate
        .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());

    let mut prefix = user_id.as_bytes().to_vec();
    prefix.push(0xFF);
    prefix.extend_from_slice(device_id.as_bytes());
    prefix.push(0xFF);
    prefix.push(b'"'); // Annoying quotation mark
    prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
    prefix.push(b':');

    // The closure runs only for the first entry because the stream is polled
    // once via `next()`.
    self.db
        .onetimekeyid_onetimekeys
        .raw_stream_prefix(&prefix)
        .ignore_err()
        .map(|(key, val)| -> Result<(OwnedDeviceKeyId, Raw<OneTimeKey>)> {
            self.db.onetimekeyid_onetimekeys.remove(key);

            // The key id is the last 0xFF-separated component of the db key.
            let key_id = key
                .rsplit(|&b| b == 0xFF)
                .next()
                .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?;

            let key_id: OwnedDeviceKeyId = serde_json::from_slice(key_id)
                .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?;

            let val: Raw<OneTimeKey> = serde_json::from_slice(val)
                .map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?;

            Ok((key_id, val))
        })
        .next()
        .await
        .ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))?
}
/// Counts the stored one-time keys of a device, grouped by key algorithm.
///
/// Scans `onetimekeyid_onetimekeys` under the `(user_id, device_id)` prefix.
/// Each count saturates instead of overflowing.
///
/// # Panics
///
/// Panics when a stored key id does not parse as a `DeviceKeyId`.
pub async fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId,
) -> BTreeMap<DeviceKeyAlgorithm, UInt> {
// Only the third key component (the key id, read as `Unquoted`) is
// deserialized; the other components and the value are skipped.
type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
let mut algorithm_counts = BTreeMap::<DeviceKeyAlgorithm, UInt>::new();
let query = (user_id, device_id);
self.db
.onetimekeyid_onetimekeys
.stream_prefix(&query)
.ignore_err()
.ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| {
let device_key_id: &DeviceKeyId = device_key_id
.as_str()
.try_into()
.expect("Invalid DeviceKeyID in database");
let count: &mut UInt = algorithm_counts
.entry(device_key_id.algorithm())
.or_default();
*count = count.saturating_add(1_u32.into());
})
.await;
algorithm_counts
}
/// Stores the device keys of a device in `keyid_key` and records a key
/// change for the user.
pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
    const SEP: &[u8] = &[0xFF];

    let userdeviceid = [user_id.as_bytes(), SEP, device_id.as_bytes()].concat();
    let serialized = serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works");
    self.db.keyid_key.insert(&userdeviceid, &serialized);

    self.mark_device_key_update(user_id).await;
}
pub async fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let (master_key_key, _) = parse_master_key(user_id, master_key)?;
self.db
.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes());
self.db
.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key);
// Self-signing key
if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key
.deserialize()
.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained more than one key.",
));
}
let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self.db
.keyid_key
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
self.db
.userid_selfsigningkeyid
.insert(user_id.as_bytes(), &self_signing_key_key);
}
// User-signing key
if let Some(user_signing_key) = user_signing_key {
let mut user_signing_key_ids = user_signing_key
.deserialize()
.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
.keys
.into_values();
let user_signing_key_id = user_signing_key_ids
.next()
.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
if user_signing_key_ids.next().is_some() {
return Err!(Request(InvalidParam("User signing key contained more than one key.")));
}
let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
self.db
.keyid_key
.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes());
self.db
.userid_usersigningkeyid
.insert(user_id.as_bytes(), &user_signing_key_key);
}
if notify {
self.mark_device_key_update(user_id).await;
}
Ok(())
}
/// Adds a signature made by `sender_id` to a key of `target_id` stored in
/// `keyid_key`, then records a key change for `target_id`.
///
/// `signature` is `(entry name, entry value)` inserted into the sender's
/// object under the key's `signatures` field; the sender's object is created
/// when missing.
///
/// # Errors
///
/// `InvalidParam` when the key does not exist; `Database` errors when the
/// stored key is not valid JSON or its `signatures` field is missing or not
/// an object.
pub async fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
) -> Result<()> {
let key = (target_id, key_id);
let mut cross_signing_key: serde_json::Value = self
.db
.keyid_key
.qry(&key)
.await
.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))?
.deserialized_json()
.map_err(|e| err!(Database("key in keyid_key is invalid. {e:?}")))?;
// Get (or create) the signature object belonging to the sender.
let signatures = cross_signing_key
.get_mut("signatures")
.ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))?
.as_object_mut()
.ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))?
.entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into());
signatures
.as_object_mut()
.ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))?
.insert(signature.0, signature.1.into());
// Write the modified key back under `<target id> 0xFF <key id>`.
let mut key = target_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
self.db.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
);
self.mark_device_key_update(target_id).await;
Ok(())
}
/// Streams the user ids recorded in `keychangeid_userid` for the given user
/// or room id.
///
/// The scan starts at count `from + 1` (saturating) and stops at the first
/// entry whose id part differs from `user_or_room_id` or whose count exceeds
/// `to` (`u64::MAX` when `to` is `None`).
pub fn keys_changed<'a>(
&'a self, user_or_room_id: &'a str, from: u64, to: Option<u64>,
) -> impl Stream<Item = &UserId> + Send + 'a {
type KeyVal<'a> = ((&'a str, u64), &'a UserId);
let to = to.unwrap_or(u64::MAX);
let start = (user_or_room_id, from.saturating_add(1));
self.db
.keychangeid_userid
.stream_from(&start)
.ignore_err()
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to)
.map(|((..), user_id): KeyVal<'_>| user_id)
}
/// Records a key change for `user_id` in `keychangeid_userid`.
///
/// One entry is written per joined room that has a `RoomEncryption` state
/// event, plus one entry under the user's own id. All entries share the same
/// count.
///
/// # Panics
///
/// Panics when `globals.next_count()` fails.
pub async fn mark_device_key_update(&self, user_id: &UserId) {
    let count = self.services.globals.next_count().unwrap().to_be_bytes();

    // Writes `<id> 0xFF <count>` -> user id.
    let record_change = |id: &[u8]| {
        let mut key = id.to_vec();
        key.push(0xFF);
        key.extend_from_slice(&count);
        self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
    };

    let rooms_joined = self.services.state_cache.rooms_joined(user_id);
    pin_mut!(rooms_joined);
    while let Some(room_id) = rooms_joined.next().await {
        // Don't send key updates to unencrypted rooms
        let encrypted = self
            .services
            .state_accessor
            .room_state_get(room_id, &StateEventType::RoomEncryption, "")
            .await
            .is_ok();

        if encrypted {
            record_change(room_id.as_bytes());
        }
    }

    record_change(user_id.as_bytes());
}
/// Returns the device keys stored in `keyid_key` for `(user_id, device_id)`,
/// parsed from JSON. A failed lookup or parse is returned as the error.
pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
let key_id = (user_id, device_id);
self.db.keyid_key.qry(&key_id).await.deserialized_json()
}
/// Loads the key stored in `keyid_key` under `key_id` and returns it with
/// its signatures filtered by `clean_signatures`.
///
/// # Errors
///
/// Lookup, JSON parsing, signature cleaning and re-serialization errors are
/// all returned to the caller.
pub async fn get_key<F>(
    &self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
    F: Fn(&UserId) -> bool + Send + Sync,
{
    let stored = self
        .db
        .keyid_key
        .qry(key_id)
        .await
        .deserialized_json::<serde_json::Value>()?;

    let visible = clean_signatures(stored, sender_user, user_id, allowed_signatures)?;

    serde_json::value::to_raw_value(&visible)
        .map(Raw::from_json)
        .map_err(Into::into)
}
/// Returns the user's master key with signatures filtered for the caller.
///
/// Resolves the key id through `userid_masterkeyid`, then defers to
/// `get_key`. A failed id lookup is returned as the error.
pub async fn get_master_key<F>(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let key_id = self.db.userid_masterkeyid.qry(user_id).await?;
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
.await
}
/// Returns the user's self-signing key with signatures filtered for the
/// caller.
///
/// Resolves the key id through `userid_selfsigningkeyid`, then defers to
/// `get_key`. A failed id lookup is returned as the error.
pub async fn get_self_signing_key<F>(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let key_id = self.db.userid_selfsigningkeyid.qry(user_id).await?;
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
.await
}
/// Returns the user's user-signing key.
///
/// Resolves the key id through `userid_usersigningkeyid` and reads the key
/// from `keyid_key`. Unlike `get_master_key` and `get_self_signing_key`,
/// this does not go through `get_key`, so no signature filtering is applied.
pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
let key_id = self.db.userid_usersigningkeyid.qry(user_id).await?;
self.db.keyid_key.qry(&*key_id).await.deserialized_json()
}
/// Queues a to-device event for a device.
///
/// The event is stored in `todeviceid_events` under
/// `<target user> 0xFF <target device> 0xFF <count>` as a JSON object with
/// the fields `type`, `sender` and `content`.
///
/// # Panics
///
/// Panics when `globals.next_count()` fails.
pub async fn add_to_device_event(
    &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
    content: serde_json::Value,
) {
    const SEP: &[u8] = &[0xFF];

    let count = self.services.globals.next_count().unwrap().to_be_bytes();
    let key = [
        target_user_id.as_bytes(),
        SEP,
        target_device_id.as_bytes(),
        SEP,
        count.as_slice(),
    ]
    .concat();

    let json: serde_json::Map<String, serde_json::Value> = [
        ("type".to_owned(), serde_json::Value::String(event_type.to_owned())),
        ("sender".to_owned(), serde_json::Value::String(sender.to_string())),
        ("content".to_owned(), content),
    ]
    .into_iter()
    .collect();

    let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
    self.db.todeviceid_events.insert(&key, &value);
}
/// Streams the queued to-device events of a device.
///
/// Scans `todeviceid_events` under the `(user_id, device_id, Interfix)`
/// prefix and parses each value as JSON. Entries that fail to read or parse
/// become errors, which are routed through `ignore_err` and never reach the
/// caller.
pub fn get_to_device_events<'a>(
&'a self, user_id: &'a UserId, device_id: &'a DeviceId,
) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
let prefix = (user_id, device_id, Interfix);
self.db
.todeviceid_events
.stream_raw_prefix(&prefix)
.ready_and_then(|(_, val)| serde_json::from_slice(val).map_err(Into::into))
.ignore_err()
}
/// Removes the queued to-device events of a device whose count is at most
/// `until`.
///
/// Keys are walked in reverse starting from `<user> 0xFF <device> 0xFF
/// <until>`; the walk ends at the first key outside the device's prefix.
pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes());
self.db
.todeviceid_events
.rev_raw_keys_from(&last) // this includes last
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|key| {
// The count is taken from the trailing `size_of::<u64>()` bytes of the key.
let len = key.len();
let start = len.saturating_sub(size_of::<u64>());
let count = utils::u64_from_u8(&key[start..len]);
(key, count)
})
.ready_take_while(move |(_, count)| *count <= until)
.ready_for_each(|(key, _)| self.db.todeviceid_events.remove(&key))
.boxed()
.await;
}
/// Replaces the stored metadata of an existing device and bumps the user's
/// device list version.
///
/// # Errors
///
/// A `Database` error when the device has no metadata entry yet.
pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
    const SEP: &[u8] = &[0xFF];

    // Only existing devices should be able to call this, but we shouldn't assert
    // either...
    let query = (user_id, device_id);
    let known = self.db.userdeviceid_metadata.qry(&query).await.is_ok();
    if !known {
        return Err!(Database(error!(
            ?user_id,
            ?device_id,
            "Called update_device_metadata for a non-existent user and/or device"
        )));
    }

    increment(&self.db.userid_devicelistversion, user_id.as_bytes());

    let userdeviceid = [user_id.as_bytes(), SEP, device_id.as_bytes()].concat();
    let metadata = serde_json::to_vec(device).expect("Device::to_string always works");
    self.db
        .userdeviceid_metadata
        .insert(&userdeviceid, &metadata);

    Ok(())
}
/// Get device metadata.
///
/// Reads `userdeviceid_metadata` for `(user_id, device_id)` and parses the
/// value as JSON. A failed lookup or parse is returned as the error.
pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Device> {
self.db
.userdeviceid_metadata
.qry(&(user_id, device_id))
.await
.deserialized_json()
}
/// Returns the value stored for the user in `userid_devicelistversion`, the
/// counter bumped whenever a device is created, removed or updated. A failed
/// lookup or deserialization is returned as the error.
pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
self.db
.userid_devicelistversion
.qry(user_id)
.await
.deserialized()
}
/// Streams the metadata of all devices of a user.
///
/// Scans `userdeviceid_metadata` under the `(user_id, Interfix)` prefix and
/// parses each value as a `Device`. Entries that fail to read or parse
/// become errors, which are routed through `ignore_err` and never reach the
/// caller.
pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = Device> + Send + 'a {
self.db
.userdeviceid_metadata
.stream_raw_prefix(&(user_id, Interfix))
.ready_and_then(|(_, val)| serde_json::from_slice::<Device>(val).map_err(Into::into))
.ignore_err()
}
/// Creates a new sync filter. Returns the filter id.
///
/// The id comes from `utils::random_string(4)`; the filter is stored as
/// JSON under `<user id> 0xFF <filter id>`.
pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
    const SEP: &[u8] = &[0xFF];

    let filter_id = utils::random_string(4);
    let key = [user_id.as_bytes(), SEP, filter_id.as_bytes()].concat();
    let value = serde_json::to_vec(&filter).expect("filter is valid json");

    self.db.userfilterid_filter.insert(&key, &value);

    filter_id
}
/// Returns the sync filter stored for `(user_id, filter_id)` in
/// `userfilterid_filter`, parsed from JSON. A failed lookup or parse is
/// returned as the error.
pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
self.db
.userfilterid_filter
.qry(&(user_id, filter_id))
.await
.deserialized_json()
}
/// Creates an OpenID token, which can be used to prove that a user has
/// access to an account (primarily for integrations)
///
/// Stores `token` with its expiry time (now + `openid_token_ttl` * 1000 in
/// milliseconds, saturating) as a big-endian u64 followed by the user id
/// bytes. Returns the configured `openid_token_ttl`.
pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
    let expires_in = self.services.server.config.openid_token_ttl;
    let expires_at = utils::millis_since_unix_epoch().saturating_add(expires_in.saturating_mul(1000));

    let value = [expires_at.to_be_bytes().as_slice(), user_id.as_bytes()].concat();

    self.db
        .openidtoken_expiresatuserid
        .insert(token.as_bytes(), value.as_slice());

    Ok(expires_in)
}
/// Find out which user an OpenID access token belongs to.
pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
let Ok(value) = self.db.openidtoken_expiresatuserid.qry(token).await else {
return Err!(Request(Unauthorized("OpenID token is unrecognised")));
};
let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
let expires_at = u64::from_be_bytes(
expires_at_bytes
.try_into()
.map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?,
);
if expires_at < utils::millis_since_unix_epoch() {
debug_warn!("OpenID token is expired, removing");
self.db.openidtoken_expiresatuserid.remove(token.as_bytes());
return Err!(Request(Unauthorized("OpenID token is expired")));
}
let user_string = utils::string_from_bytes(user_bytes)
.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
}
/// Gets a specific user profile key.
///
/// Reads `useridprofilekey_value` for `(user_id, profile_key)`; a failed
/// lookup or deserialization is returned as the error.
pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<serde_json::Value> {
let key = (user_id, profile_key);
self.db
.useridprofilekey_value
.qry(&key)
.await
.deserialized()
}
/// Gets all the user's profile keys and values as a stream of
/// `(profile key, value)` pairs.
///
/// Scans `useridprofilekey_value` under the `(user_id, Interfix)` prefix;
/// errors are routed through `ignore_err` and never reach the caller.
pub fn all_profile_keys<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
// The user-id part of the key is skipped via `Ignore`.
type KeyVal = ((Ignore, String), serde_json::Value);
let prefix = (user_id, Interfix);
self.db
.useridprofilekey_value
.stream_prefix(&prefix)
.ignore_err()
.map(|((_, key), val): KeyVal| (key, val))
}
/// Sets a new profile key value, removes the key if value is None.
///
/// The value is stored as JSON under `<user id> 0xFF <profile key>`.
pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
    const SEP: &[u8] = &[0xFF];

    let key = [user_id.as_bytes(), SEP, profile_key.as_bytes()].concat();

    // TODO: insert to the stable MSC4175 key when it's stable
    match profile_key_value {
        Some(value) => {
            let value = serde_json::to_vec(&value).unwrap();
            self.db.useridprofilekey_value.insert(&key, &value);
        },
        None => {
            self.db.useridprofilekey_value.remove(&key);
        },
    }
}
/// Get the timezone of a user.
///
/// Queries the unstable profile key `us.cloke.msc4175.tz` first and falls
/// back to the stable key `m.tz` when that lookup fails. The result of
/// whichever lookup is used is deserialized as a `String`.
pub async fn timezone(&self, user_id: &UserId) -> Result<String> {
// TODO: transparently migrate unstable key usage to the stable key once MSC4133
// and MSC4175 are stable, likely a remove/insert in this block.
// first check the unstable prefix then check the stable prefix
let unstable_key = (user_id, "us.cloke.msc4175.tz");
let stable_key = (user_id, "m.tz");
self.db
.useridprofilekey_value
.qry(&unstable_key)
.or_else(|_| self.db.useridprofilekey_value.qry(&stable_key))
.await
.deserialized()
}
/// Sets a new timezone or removes it if timezone is None.
///
/// Only the unstable profile key `us.cloke.msc4175.tz` is written.
pub fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) {
    const SEP: &[u8] = &[0xFF];

    let key = [user_id.as_bytes(), SEP, b"us.cloke.msc4175.tz".as_slice()].concat();

    // TODO: insert to the stable MSC4175 key when it's stable
    match timezone {
        Some(tz) => {
            self.db
                .useridprofilekey_value
                .insert(&key, tz.as_bytes());
        },
        None => {
            self.db.useridprofilekey_value.remove(&key);
        },
    }
}
}
pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let master_key = master_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
if master_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Master key contained more than one key.",
));
}
let mut master_key_key = prefix.clone();
master_key_key.extend_from_slice(master_key_id.as_bytes());
Ok((master_key_key, master_key))
}
/// Ensure that a user only sees signatures from themselves and the target user
fn clean_signatures<F>(
mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
) -> Result<serde_json::Value>
where
F: Fn(&UserId) -> bool + Send + Sync,
{
if let Some(signatures) = cross_signing_key
.get_mut("signatures")
.and_then(|v| v.as_object_mut())
{
// Don't allocate for the full size of the current signatures, but require
// at most one resize if nothing is dropped
let new_capacity = signatures.len() / 2;
for (user, signature) in mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) {
let sid =
<&UserId>::try_from(user.as_str()).map_err(|_| Error::bad_database("Invalid user ID in database."))?;
if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
signatures.insert(user, signature);
}
}
}
Ok(cross_signing_key)
}
/// Reads the value stored under `key`, passes it (or `None` when the read
/// fails) to `utils::increment`, and writes the result back.
//TODO: this is an ABA
// The read and the write are separate operations with nothing visible here
// that ties them together, so two concurrent callers can read the same old
// value and one of the increments is lost.
fn increment(db: &Arc<Map>, key: &[u8]) {
let old = db.get(key);
let new = utils::increment(old.ok().as_deref());
db.insert(key, &new);
}