Compare commits

...

10 commits

Author SHA1 Message Date
Jade Ellis
95610499c7
chore: Disable direnv's nix flake interfering with cargo cache
Some checks failed
Release Docker Image / build-image (linux/amd64, release, linux-amd64, base) (push) Has been skipped
Release Docker Image / build-image (linux/arm64, release, linux-arm64, base) (push) Has been skipped
Release Docker Image / merge (push) Has been skipped
Documentation / Build and Deploy Documentation (push) Has been skipped
Checks / Prefligit / prefligit (push) Failing after 5s
Release Docker Image / define-variables (push) Failing after 3s
Checks / Rust / Format (push) Failing after 5s
Checks / Rust / Clippy (push) Failing after 30s
Checks / Rust / Cargo Test (push) Failing after 28s
2025-07-20 16:36:01 +01:00
Jade Ellis
f593cac58a
feat: Enable hardware-lock-elision and deadlock_detection 2025-07-20 16:35:59 +01:00
Jade Ellis
1c985c59f5
refactor: Allow with_lock to return data and take an async closure 2025-07-20 16:34:48 +01:00
Jade Ellis
b635e825d2
refactor: Implement with_lock for lock_api 2025-07-20 16:34:36 +01:00
Jade Ellis
6d29098d1a
refactor: Replace remaining std RwLocks 2025-07-20 16:33:36 +01:00
Jade Ellis
374fb2745c
refactor: Replace remaining std Mutexes 2025-07-20 16:32:48 +01:00
Jade Ellis
a1d616e3e3
refactor: Replace std RwLock with parking_lot 2025-07-20 16:31:55 +01:00
Jade Ellis
30a8c06fd9
refactor: Replace std Mutex with parking_lot 2025-07-20 16:31:02 +01:00
rooot
0631094350
docs(config): warn about federation key query timeout caveat
Signed-off-by: rooot <hey@rooot.gay>
2025-07-20 16:24:56 +01:00
rooot
9051ce63f7
feat(config): introduce federation connection timeout setting
fixes #906

Signed-off-by: rooot <hey@rooot.gay>
2025-07-20 16:24:26 +01:00
46 changed files with 438 additions and 359 deletions

2
.envrc
View file

@ -2,6 +2,6 @@
dotenv_if_exists dotenv_if_exists
use flake ".#${DIRENV_DEVSHELL:-default}" # use flake ".#${DIRENV_DEVSHELL:-default}"
PATH_add bin PATH_add bin

31
Cargo.lock generated
View file

@ -963,10 +963,12 @@ dependencies = [
"itertools 0.14.0", "itertools 0.14.0",
"libc", "libc",
"libloading", "libloading",
"lock_api",
"log", "log",
"maplit", "maplit",
"nix", "nix",
"num-traits", "num-traits",
"parking_lot",
"rand 0.8.5", "rand 0.8.5",
"regex", "regex",
"reqwest", "reqwest",
@ -1657,6 +1659,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "fixedbitset"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.1.2" version = "1.1.2"
@ -3218,10 +3226,13 @@ version = "0.9.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
dependencies = [ dependencies = [
"backtrace",
"cfg-if", "cfg-if",
"libc", "libc",
"petgraph",
"redox_syscall", "redox_syscall",
"smallvec", "smallvec",
"thread-id",
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
@ -3271,6 +3282,16 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "petgraph"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
"indexmap 2.9.0",
]
[[package]] [[package]]
name = "phf" name = "phf"
version = "0.11.3" version = "0.11.3"
@ -4892,6 +4913,16 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "thread-id"
version = "4.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe8f25bbdd100db7e1d34acf7fd2dc59c4bf8f7483f505eaa7d4f12f76cc0ea"
dependencies = [
"libc",
"winapi",
]
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.9" version = "1.1.9"

View file

@ -515,6 +515,14 @@ version = "1.0"
[workspace.dependencies.proc-macro2] [workspace.dependencies.proc-macro2]
version = "1.0" version = "1.0"
[workspace.dependencies.parking_lot]
version = "0.12.4"
features = ["hardware-lock-elision", "deadlock_detection"] # TODO: Check if deadlock_detection has a perf impact, if it does only enable with debug_assertions
# Use this when extending with_lock::WithLock to parking_lot
[workspace.dependencies.lock_api]
version = "0.4.13"
[workspace.dependencies.bytesize] [workspace.dependencies.bytesize]
version = "2.0" version = "2.0"

View file

@ -325,6 +325,15 @@
# #
#well_known_timeout = 10 #well_known_timeout = 10
# Federation client connection timeout (seconds). You should not set this
# to high values, as dead homeservers can significantly slow down
# federation, specifically key retrieval, which will take roughly the
# amount of time you configure here given that a homeserver doesn't
# respond. This will cause most clients to time out /keys/query, causing
# E2EE and device verification to fail.
#
#federation_conn_timeout = 10
# Federation client request timeout (seconds). You most definitely want # Federation client request timeout (seconds). You most definitely want
# this to be high to account for extremely large room joins, slow # this to be high to account for extremely large room joins, slow
# homeservers, your own resources etc. # homeservers, your own resources etc.

View file

@ -26,8 +26,7 @@ pub(super) async fn incoming_federation(&self) -> Result {
.rooms .rooms
.event_handler .event_handler
.federation_handletime .federation_handletime
.read() .read();
.expect("locked");
let mut msg = format!("Handling {} incoming pdus:\n", map.len()); let mut msg = format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() { for (r, (e, i)) in map.iter() {

View file

@ -37,11 +37,7 @@ pub use crate::admin::AdminCommand;
/// Install the admin command processor /// Install the admin command processor
pub async fn init(admin_service: &service::admin::Service) { pub async fn init(admin_service: &service::admin::Service) {
_ = admin_service _ = admin_service.complete.write().insert(processor::complete);
.complete
.write()
.expect("locked for writing")
.insert(processor::complete);
_ = admin_service _ = admin_service
.handle .handle
.write() .write()
@ -52,9 +48,5 @@ pub async fn init(admin_service: &service::admin::Service) {
/// Uninstall the admin command handler /// Uninstall the admin command handler
pub async fn fini(admin_service: &service::admin::Service) { pub async fn fini(admin_service: &service::admin::Service) {
_ = admin_service.handle.write().await.take(); _ = admin_service.handle.write().await.take();
_ = admin_service _ = admin_service.complete.write().take();
.complete
.write()
.expect("locked for writing")
.take();
} }

View file

@ -1,14 +1,8 @@
use std::{ use std::{fmt::Write, mem::take, panic::AssertUnwindSafe, sync::Arc, time::SystemTime};
fmt::Write,
mem::take,
panic::AssertUnwindSafe,
sync::{Arc, Mutex},
time::SystemTime,
};
use clap::{CommandFactory, Parser}; use clap::{CommandFactory, Parser};
use conduwuit::{ use conduwuit::{
Error, Result, debug, error, Error, Result, SyncMutex, debug, error,
log::{ log::{
capture, capture,
capture::Capture, capture::Capture,
@ -123,7 +117,7 @@ async fn process(
let mut output = String::new(); let mut output = String::new();
// Prepend the logs only if any were captured // Prepend the logs only if any were captured
let logs = logs.lock().expect("locked"); let logs = logs.lock();
if logs.lines().count() > 2 { if logs.lines().count() > 2 {
writeln!(&mut output, "{logs}").expect("failed to format logs to command output"); writeln!(&mut output, "{logs}").expect("failed to format logs to command output");
} }
@ -132,7 +126,7 @@ async fn process(
(result, output) (result, output)
} }
fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<Mutex<String>>) { fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<SyncMutex<String>>) {
let env_config = &context.services.server.config.admin_log_capture; let env_config = &context.services.server.config.admin_log_capture;
let env_filter = EnvFilter::try_new(env_config).unwrap_or_else(|e| { let env_filter = EnvFilter::try_new(env_config).unwrap_or_else(|e| {
warn!("admin_log_capture filter invalid: {e:?}"); warn!("admin_log_capture filter invalid: {e:?}");
@ -152,7 +146,7 @@ fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<Mutex<String>>) {
data.level() <= log_level && data.our_modules() && data.scope.contains(&"admin") data.level() <= log_level && data.our_modules() && data.scope.contains(&"admin")
}; };
let logs = Arc::new(Mutex::new( let logs = Arc::new(SyncMutex::new(
collect_stream(|s| markdown_table_head(s)).expect("markdown table header"), collect_stream(|s| markdown_table_head(s)).expect("markdown table header"),
)); ));

View file

@ -110,6 +110,8 @@ tracing-core.workspace = true
tracing-subscriber.workspace = true tracing-subscriber.workspace = true
tracing.workspace = true tracing.workspace = true
url.workspace = true url.workspace = true
parking_lot.workspace = true
lock_api.workspace = true
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix.workspace = true nix.workspace = true

View file

@ -4,7 +4,6 @@ use std::{
cell::OnceCell, cell::OnceCell,
ffi::{CStr, c_char, c_void}, ffi::{CStr, c_char, c_void},
fmt::Debug, fmt::Debug,
sync::RwLock,
}; };
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
@ -13,7 +12,7 @@ use tikv_jemalloc_sys as ffi;
use tikv_jemallocator as jemalloc; use tikv_jemallocator as jemalloc;
use crate::{ use crate::{
Result, err, is_equal_to, is_nonzero, Result, SyncRwLock, err, is_equal_to, is_nonzero,
utils::{math, math::Tried}, utils::{math, math::Tried},
}; };
@ -40,7 +39,7 @@ const MALLOC_CONF_PROF: &str = "";
#[global_allocator] #[global_allocator]
static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc; static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc;
static CONTROL: RwLock<()> = RwLock::new(()); static CONTROL: SyncRwLock<()> = SyncRwLock::new(());
type Name = ArrayVec<u8, NAME_MAX>; type Name = ArrayVec<u8, NAME_MAX>;
type Key = ArrayVec<usize, KEY_SEGS>; type Key = ArrayVec<usize, KEY_SEGS>;
@ -332,7 +331,7 @@ fn set<T>(key: &Key, val: T) -> Result<T>
where where
T: Copy + Debug, T: Copy + Debug,
{ {
let _lock = CONTROL.write()?; let _lock = CONTROL.write();
let res = xchg(key, val)?; let res = xchg(key, val)?;
inc_epoch()?; inc_epoch()?;

View file

@ -412,6 +412,17 @@ pub struct Config {
#[serde(default = "default_well_known_timeout")] #[serde(default = "default_well_known_timeout")]
pub well_known_timeout: u64, pub well_known_timeout: u64,
/// Federation client connection timeout (seconds). You should not set this
/// to high values, as dead homeservers can significantly slow down
/// federation, specifically key retrieval, which will take roughly the
/// amount of time you configure here given that a homeserver doesn't
/// respond. This will cause most clients to time out /keys/query, causing
/// E2EE and device verification to fail.
///
/// default: 10
#[serde(default = "default_federation_conn_timeout")]
pub federation_conn_timeout: u64,
/// Federation client request timeout (seconds). You most definitely want /// Federation client request timeout (seconds). You most definitely want
/// this to be high to account for extremely large room joins, slow /// this to be high to account for extremely large room joins, slow
/// homeservers, your own resources etc. /// homeservers, your own resources etc.
@ -2193,6 +2204,8 @@ fn default_well_known_conn_timeout() -> u64 { 6 }
fn default_well_known_timeout() -> u64 { 10 } fn default_well_known_timeout() -> u64 { 10 }
fn default_federation_conn_timeout() -> u64 { 10 }
fn default_federation_timeout() -> u64 { 25 } fn default_federation_timeout() -> u64 { 25 }
fn default_federation_idle_timeout() -> u64 { 25 } fn default_federation_idle_timeout() -> u64 { 25 }

View file

@ -3,18 +3,15 @@
//! several crates, lower-level information is supplied from each crate during //! several crates, lower-level information is supplied from each crate during
//! static initialization. //! static initialization.
use std::{ use std::{collections::BTreeMap, sync::OnceLock};
collections::BTreeMap,
sync::{Mutex, OnceLock},
};
use crate::utils::exchange; use crate::{SyncMutex, utils::exchange};
/// Raw capture of rustc flags used to build each crate in the project. Informed /// Raw capture of rustc flags used to build each crate in the project. Informed
/// by rustc_flags_capture macro (one in each crate's mod.rs). This is /// by rustc_flags_capture macro (one in each crate's mod.rs). This is
/// done during static initialization which is why it's mutex-protected and pub. /// done during static initialization which is why it's mutex-protected and pub.
/// Should not be written to by anything other than our macro. /// Should not be written to by anything other than our macro.
pub static FLAGS: Mutex<BTreeMap<&str, &[&str]>> = Mutex::new(BTreeMap::new()); pub static FLAGS: SyncMutex<BTreeMap<&str, &[&str]>> = SyncMutex::new(BTreeMap::new());
/// Processed list of enabled features across all project crates. This is /// Processed list of enabled features across all project crates. This is
/// generated from the data in FLAGS. /// generated from the data in FLAGS.
@ -27,7 +24,6 @@ fn init_features() -> Vec<&'static str> {
let mut features = Vec::new(); let mut features = Vec::new();
FLAGS FLAGS
.lock() .lock()
.expect("locked")
.iter() .iter()
.for_each(|(_, flags)| append_features(&mut features, flags)); .for_each(|(_, flags)| append_features(&mut features, flags));

View file

@ -40,7 +40,6 @@ where
self.state self.state
.active .active
.read() .read()
.expect("shared lock")
.iter() .iter()
.filter(|capture| filter(self, capture, event, &ctx)) .filter(|capture| filter(self, capture, event, &ctx))
.for_each(|capture| handle(self, capture, event, &ctx)); .for_each(|capture| handle(self, capture, event, &ctx));
@ -55,7 +54,7 @@ where
let mut visitor = Visitor { values: Values::new() }; let mut visitor = Visitor { values: Values::new() };
event.record(&mut visitor); event.record(&mut visitor);
let mut closure = capture.closure.lock().expect("exclusive lock"); let mut closure = capture.closure.lock();
closure(Data { closure(Data {
layer, layer,
event, event,

View file

@ -4,7 +4,7 @@ pub mod layer;
pub mod state; pub mod state;
pub mod util; pub mod util;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
pub use data::Data; pub use data::Data;
use guard::Guard; use guard::Guard;
@ -12,6 +12,8 @@ pub use layer::{Layer, Value};
pub use state::State; pub use state::State;
pub use util::*; pub use util::*;
use crate::SyncMutex;
pub type Filter = dyn Fn(Data<'_>) -> bool + Send + Sync + 'static; pub type Filter = dyn Fn(Data<'_>) -> bool + Send + Sync + 'static;
pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static; pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static;
@ -19,7 +21,7 @@ pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static;
pub struct Capture { pub struct Capture {
state: Arc<State>, state: Arc<State>,
filter: Option<Box<Filter>>, filter: Option<Box<Filter>>,
closure: Mutex<Box<Closure>>, closure: SyncMutex<Box<Closure>>,
} }
impl Capture { impl Capture {
@ -34,7 +36,7 @@ impl Capture {
Arc::new(Self { Arc::new(Self {
state: state.clone(), state: state.clone(),
filter: filter.map(|p| -> Box<Filter> { Box::new(p) }), filter: filter.map(|p| -> Box<Filter> { Box::new(p) }),
closure: Mutex::new(Box::new(closure)), closure: SyncMutex::new(Box::new(closure)),
}) })
} }

View file

@ -1,10 +1,11 @@
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use super::Capture; use super::Capture;
use crate::SyncRwLock;
/// Capture layer state. /// Capture layer state.
pub struct State { pub struct State {
pub(super) active: RwLock<Vec<Arc<Capture>>>, pub(super) active: SyncRwLock<Vec<Arc<Capture>>>,
} }
impl Default for State { impl Default for State {
@ -13,17 +14,14 @@ impl Default for State {
impl State { impl State {
#[must_use] #[must_use]
pub fn new() -> Self { Self { active: RwLock::new(Vec::new()) } } pub fn new() -> Self { Self { active: SyncRwLock::new(Vec::new()) } }
pub(super) fn add(&self, capture: &Arc<Capture>) { pub(super) fn add(&self, capture: &Arc<Capture>) {
self.active self.active.write().push(capture.clone());
.write()
.expect("locked for writing")
.push(capture.clone());
} }
pub(super) fn del(&self, capture: &Arc<Capture>) { pub(super) fn del(&self, capture: &Arc<Capture>) {
let mut vec = self.active.write().expect("locked for writing"); let mut vec = self.active.write();
if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) { if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) {
vec.swap_remove(pos); vec.swap_remove(pos);
} }

View file

@ -1,31 +1,31 @@
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use super::{ use super::{
super::{Level, fmt}, super::{Level, fmt},
Closure, Data, Closure, Data,
}; };
use crate::Result; use crate::{Result, SyncMutex};
pub fn fmt_html<S>(out: Arc<Mutex<S>>) -> Box<Closure> pub fn fmt_html<S>(out: Arc<SyncMutex<S>>) -> Box<Closure>
where where
S: std::fmt::Write + Send + 'static, S: std::fmt::Write + Send + 'static,
{ {
fmt(fmt::html, out) fmt(fmt::html, out)
} }
pub fn fmt_markdown<S>(out: Arc<Mutex<S>>) -> Box<Closure> pub fn fmt_markdown<S>(out: Arc<SyncMutex<S>>) -> Box<Closure>
where where
S: std::fmt::Write + Send + 'static, S: std::fmt::Write + Send + 'static,
{ {
fmt(fmt::markdown, out) fmt(fmt::markdown, out)
} }
pub fn fmt<F, S>(fun: F, out: Arc<Mutex<S>>) -> Box<Closure> pub fn fmt<F, S>(fun: F, out: Arc<SyncMutex<S>>) -> Box<Closure>
where where
F: Fn(&mut S, &Level, &str, &str) -> Result<()> + Send + Sync + Copy + 'static, F: Fn(&mut S, &Level, &str, &str) -> Result<()> + Send + Sync + Copy + 'static,
S: std::fmt::Write + Send + 'static, S: std::fmt::Write + Send + 'static,
{ {
Box::new(move |data| call(fun, &mut *out.lock().expect("locked"), &data)) Box::new(move |data| call(fun, &mut *out.lock(), &data))
} }
fn call<F, S>(fun: F, out: &mut S, data: &Data<'_>) fn call<F, S>(fun: F, out: &mut S, data: &Data<'_>)

View file

@ -1,11 +1,8 @@
use std::{ use std::{collections::HashMap, sync::Arc};
collections::HashMap,
sync::{Arc, Mutex},
};
use tracing_subscriber::{EnvFilter, reload}; use tracing_subscriber::{EnvFilter, reload};
use crate::{Result, error}; use crate::{Result, SyncMutex, error};
/// We need to store a reload::Handle value, but can't name it's type explicitly /// We need to store a reload::Handle value, but can't name it's type explicitly
/// because the S type parameter depends on the subscriber's previous layers. In /// because the S type parameter depends on the subscriber's previous layers. In
@ -35,7 +32,7 @@ impl<L: Clone, S> ReloadHandle<L> for reload::Handle<L, S> {
#[derive(Clone)] #[derive(Clone)]
pub struct LogLevelReloadHandles { pub struct LogLevelReloadHandles {
handles: Arc<Mutex<HandleMap>>, handles: Arc<SyncMutex<HandleMap>>,
} }
type HandleMap = HashMap<String, Handle>; type HandleMap = HashMap<String, Handle>;
@ -43,16 +40,12 @@ type Handle = Box<dyn ReloadHandle<EnvFilter> + Send + Sync>;
impl LogLevelReloadHandles { impl LogLevelReloadHandles {
pub fn add(&self, name: &str, handle: Handle) { pub fn add(&self, name: &str, handle: Handle) {
self.handles self.handles.lock().insert(name.into(), handle);
.lock()
.expect("locked")
.insert(name.into(), handle);
} }
pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> { pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> {
self.handles self.handles
.lock() .lock()
.expect("locked")
.iter() .iter()
.filter(|(name, _)| names.is_some_and(|names| names.contains(&name.as_str()))) .filter(|(name, _)| names.is_some_and(|names| names.contains(&name.as_str())))
.for_each(|(_, handle)| { .for_each(|(_, handle)| {
@ -66,7 +59,6 @@ impl LogLevelReloadHandles {
pub fn current(&self, name: &str) -> Option<EnvFilter> { pub fn current(&self, name: &str) -> Option<EnvFilter> {
self.handles self.handles
.lock() .lock()
.expect("locked")
.get(name) .get(name)
.map(|handle| handle.current())? .map(|handle| handle.current())?
} }

View file

@ -28,6 +28,7 @@ pub use info::{
pub use matrix::{ pub use matrix::{
Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res, Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res,
}; };
pub use parking_lot::{Mutex as SyncMutex, RwLock as SyncRwLock};
pub use server::Server; pub use server::Server;
pub use utils::{ctor, dtor, implement, result, result::Result}; pub use utils::{ctor, dtor, implement, result, result::Result};

View file

@ -1,12 +1,8 @@
use std::{ use std::{fmt::Debug, hash::Hash, sync::Arc};
fmt::Debug,
hash::Hash,
sync::{Arc, TryLockError::WouldBlock},
};
use tokio::sync::OwnedMutexGuard as Omg; use tokio::sync::OwnedMutexGuard as Omg;
use crate::{Result, err}; use crate::{Result, SyncMutex, err};
/// Map of Mutexes /// Map of Mutexes
pub struct MutexMap<Key, Val> { pub struct MutexMap<Key, Val> {
@ -19,7 +15,7 @@ pub struct Guard<Key, Val> {
} }
type Map<Key, Val> = Arc<MapMutex<Key, Val>>; type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>; type MapMutex<Key, Val> = SyncMutex<HashMap<Key, Val>>;
type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>; type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
type Value<Val> = Arc<tokio::sync::Mutex<Val>>; type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
@ -45,7 +41,6 @@ where
let val = self let val = self
.map .map
.lock() .lock()
.expect("locked")
.entry(k.try_into().expect("failed to construct key")) .entry(k.try_into().expect("failed to construct key"))
.or_default() .or_default()
.clone(); .clone();
@ -66,7 +61,6 @@ where
let val = self let val = self
.map .map
.lock() .lock()
.expect("locked")
.entry(k.try_into().expect("failed to construct key")) .entry(k.try_into().expect("failed to construct key"))
.or_default() .or_default()
.clone(); .clone();
@ -87,10 +81,7 @@ where
let val = self let val = self
.map .map
.try_lock() .try_lock()
.map_err(|e| match e { .ok_or_else(|| err!("would block"))?
| WouldBlock => err!("would block"),
| _ => panic!("{e:?}"),
})?
.entry(k.try_into().expect("failed to construct key")) .entry(k.try_into().expect("failed to construct key"))
.or_default() .or_default()
.clone(); .clone();
@ -102,13 +93,13 @@ where
} }
#[must_use] #[must_use]
pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } pub fn contains(&self, k: &Key) -> bool { self.map.lock().contains_key(k) }
#[must_use] #[must_use]
pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } pub fn is_empty(&self) -> bool { self.map.lock().is_empty() }
#[must_use] #[must_use]
pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } pub fn len(&self) -> usize { self.map.lock().len() }
} }
impl<Key, Val> Default for MutexMap<Key, Val> impl<Key, Val> Default for MutexMap<Key, Val>
@ -123,7 +114,7 @@ impl<Key, Val> Drop for Guard<Key, Val> {
#[tracing::instrument(name = "unlock", level = "trace", skip_all)] #[tracing::instrument(name = "unlock", level = "trace", skip_all)]
fn drop(&mut self) { fn drop(&mut self) {
if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
self.map.lock().expect("locked").retain(|_, val| { self.map.lock().retain(|_, val| {
!Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2 !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2
}); });
} }

View file

@ -1,65 +1,212 @@
//! Traits for explicitly scoping the lifetime of locks. //! Traits for explicitly scoping the lifetime of locks.
use std::sync::{Arc, Mutex}; use std::{
future::Future,
sync::{Arc, Mutex},
};
pub trait WithLock<T> { pub trait WithLock<T: ?Sized> {
/// Acquires a lock and executes the given closure with the locked data. /// Acquires a lock and executes the given closure with the locked data,
fn with_lock<F>(&self, f: F) /// returning the result.
fn with_lock<R, F>(&self, f: F) -> R
where where
F: FnMut(&mut T); F: FnMut(&mut T) -> R;
} }
impl<T> WithLock<T> for Mutex<T> { impl<T> WithLock<T> for Mutex<T> {
fn with_lock<F>(&self, mut f: F) fn with_lock<R, F>(&self, mut f: F) -> R
where where
F: FnMut(&mut T), F: FnMut(&mut T) -> R,
{ {
// The locking and unlocking logic is hidden inside this function. // The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock().unwrap(); let mut data_guard = self.lock().unwrap();
f(&mut data_guard); f(&mut data_guard)
// Lock is released here when `data_guard` goes out of scope. // Lock is released here when `data_guard` goes out of scope.
} }
} }
impl<T> WithLock<T> for Arc<Mutex<T>> { impl<T> WithLock<T> for Arc<Mutex<T>> {
fn with_lock<F>(&self, mut f: F) fn with_lock<R, F>(&self, mut f: F) -> R
where where
F: FnMut(&mut T), F: FnMut(&mut T) -> R,
{ {
// The locking and unlocking logic is hidden inside this function. // The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock().unwrap(); let mut data_guard = self.lock().unwrap();
f(&mut data_guard); f(&mut data_guard)
// Lock is released here when `data_guard` goes out of scope.
}
}
impl<R: lock_api::RawMutex, T: ?Sized> WithLock<T> for lock_api::Mutex<R, T> {
fn with_lock<Ret, F>(&self, mut f: F) -> Ret
where
F: FnMut(&mut T) -> Ret,
{
// The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock();
f(&mut data_guard)
// Lock is released here when `data_guard` goes out of scope.
}
}
impl<R: lock_api::RawMutex, T: ?Sized> WithLock<T> for Arc<lock_api::Mutex<R, T>> {
fn with_lock<Ret, F>(&self, mut f: F) -> Ret
where
F: FnMut(&mut T) -> Ret,
{
// The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock();
f(&mut data_guard)
// Lock is released here when `data_guard` goes out of scope. // Lock is released here when `data_guard` goes out of scope.
} }
} }
pub trait WithLockAsync<T> { pub trait WithLockAsync<T> {
/// Acquires a lock and executes the given closure with the locked data. /// Acquires a lock and executes the given closure with the locked data,
fn with_lock<F>(&self, f: F) -> impl Future<Output = ()> /// returning the result.
fn with_lock<R, F>(&self, f: F) -> impl Future<Output = R>
where where
F: FnMut(&mut T); F: FnMut(&mut T) -> R;
/// Acquires a lock and executes the given async closure with the locked
/// data.
fn with_lock_async<R, F>(&self, f: F) -> impl std::future::Future<Output = R>
where
F: AsyncFnMut(&mut T) -> R;
} }
impl<T> WithLockAsync<T> for futures::lock::Mutex<T> { impl<T> WithLockAsync<T> for futures::lock::Mutex<T> {
async fn with_lock<F>(&self, mut f: F) async fn with_lock<R, F>(&self, mut f: F) -> R
where where
F: FnMut(&mut T), F: FnMut(&mut T) -> R,
{ {
// The locking and unlocking logic is hidden inside this function. // The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock().await; let mut data_guard = self.lock().await;
f(&mut data_guard); f(&mut data_guard)
// Lock is released here when `data_guard` goes out of scope.
}
async fn with_lock_async<R, F>(&self, mut f: F) -> R
where
F: AsyncFnMut(&mut T) -> R,
{
// The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock().await;
f(&mut data_guard).await
// Lock is released here when `data_guard` goes out of scope. // Lock is released here when `data_guard` goes out of scope.
} }
} }
impl<T> WithLockAsync<T> for Arc<futures::lock::Mutex<T>> { impl<T> WithLockAsync<T> for Arc<futures::lock::Mutex<T>> {
async fn with_lock<F>(&self, mut f: F) async fn with_lock<R, F>(&self, mut f: F) -> R
where where
F: FnMut(&mut T), F: FnMut(&mut T) -> R,
{ {
// The locking and unlocking logic is hidden inside this function. // The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock().await; let mut data_guard = self.lock().await;
f(&mut data_guard); f(&mut data_guard)
// Lock is released here when `data_guard` goes out of scope.
}
async fn with_lock_async<R, F>(&self, mut f: F) -> R
where
F: AsyncFnMut(&mut T) -> R,
{
// The locking and unlocking logic is hidden inside this function.
let mut data_guard = self.lock().await;
f(&mut data_guard).await
// Lock is released here when `data_guard` goes out of scope. // Lock is released here when `data_guard` goes out of scope.
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_with_lock_return_value() {
let mutex = Mutex::new(5);
let result = mutex.with_lock(|v| {
*v += 1;
*v * 2
});
assert_eq!(result, 12);
let value = mutex.lock().unwrap();
assert_eq!(*value, 6);
}
#[test]
fn test_with_lock_unit_return() {
let mutex = Mutex::new(10);
mutex.with_lock(|v| {
*v += 2;
});
let value = mutex.lock().unwrap();
assert_eq!(*value, 12);
}
#[test]
fn test_with_lock_arc_mutex() {
let mutex = Arc::new(Mutex::new(1));
let result = mutex.with_lock(|v| {
*v *= 10;
*v
});
assert_eq!(result, 10);
assert_eq!(*mutex.lock().unwrap(), 10);
}
#[tokio::test]
async fn test_with_lock_async_return_value() {
use futures::lock::Mutex as AsyncMutex;
let mutex = AsyncMutex::new(7);
let result = mutex
.with_lock(|v| {
*v += 3;
*v * 2
})
.await;
assert_eq!(result, 20);
let value = mutex.lock().await;
assert_eq!(*value, 10);
}
#[tokio::test]
async fn test_with_lock_async_unit_return() {
use futures::lock::Mutex as AsyncMutex;
let mutex = AsyncMutex::new(100);
mutex
.with_lock(|v| {
*v -= 50;
})
.await;
let value = mutex.lock().await;
assert_eq!(*value, 50);
}
#[tokio::test]
async fn test_with_lock_async_closure() {
use futures::lock::Mutex as AsyncMutex;
let mutex = AsyncMutex::new(1);
mutex
.with_lock_async(async |v| {
*v += 9;
})
.await;
let value = mutex.lock().await;
assert_eq!(*value, 10);
}
#[tokio::test]
async fn test_with_lock_async_arc_mutex() {
use futures::lock::Mutex as AsyncMutex;
let mutex = Arc::new(AsyncMutex::new(2));
mutex
.with_lock_async(async |v: &mut i32| {
*v *= 5;
})
.await;
let value = mutex.lock().await;
assert_eq!(*value, 10);
}
}

View file

@ -71,7 +71,7 @@ pub fn backup_count(&self) -> Result<usize> {
fn backup_engine(&self) -> Result<BackupEngine> { fn backup_engine(&self) -> Result<BackupEngine> {
let path = self.backup_path()?; let path = self.backup_path()?;
let options = BackupEngineOptions::new(path).map_err(map_err)?; let options = BackupEngineOptions::new(path).map_err(map_err)?;
BackupEngine::open(&options, &*self.ctx.env.lock()?).map_err(map_err) BackupEngine::open(&options, &self.ctx.env.lock()).map_err(map_err)
} }
#[implement(Engine)] #[implement(Engine)]

View file

@ -232,7 +232,7 @@ fn get_cache(ctx: &Context, desc: &Descriptor) -> Option<Cache> {
cache_opts.set_num_shard_bits(shard_bits); cache_opts.set_num_shard_bits(shard_bits);
cache_opts.set_capacity(size); cache_opts.set_capacity(size);
let mut caches = ctx.col_cache.lock().expect("locked"); let mut caches = ctx.col_cache.lock();
match desc.cache_disp { match desc.cache_disp {
| CacheDisp::Unique if desc.cache_size == 0 => None, | CacheDisp::Unique if desc.cache_size == 0 => None,
| CacheDisp::Unique => { | CacheDisp::Unique => {

View file

@ -1,9 +1,6 @@
use std::{ use std::{collections::BTreeMap, sync::Arc};
collections::BTreeMap,
sync::{Arc, Mutex},
};
use conduwuit::{Result, Server, debug, utils::math::usize_from_f64}; use conduwuit::{Result, Server, SyncMutex, debug, utils::math::usize_from_f64};
use rocksdb::{Cache, Env, LruCacheOptions}; use rocksdb::{Cache, Env, LruCacheOptions};
use crate::{or_else, pool::Pool}; use crate::{or_else, pool::Pool};
@ -14,9 +11,9 @@ use crate::{or_else, pool::Pool};
/// These assets are housed in the shared Context. /// These assets are housed in the shared Context.
pub(crate) struct Context { pub(crate) struct Context {
pub(crate) pool: Arc<Pool>, pub(crate) pool: Arc<Pool>,
pub(crate) col_cache: Mutex<BTreeMap<String, Cache>>, pub(crate) col_cache: SyncMutex<BTreeMap<String, Cache>>,
pub(crate) row_cache: Mutex<Cache>, pub(crate) row_cache: SyncMutex<Cache>,
pub(crate) env: Mutex<Env>, pub(crate) env: SyncMutex<Env>,
pub(crate) server: Arc<Server>, pub(crate) server: Arc<Server>,
} }
@ -68,7 +65,7 @@ impl Drop for Context {
debug!("Closing frontend pool"); debug!("Closing frontend pool");
self.pool.close(); self.pool.close();
let mut env = self.env.lock().expect("locked"); let mut env = self.env.lock();
debug!("Shutting down background threads"); debug!("Shutting down background threads");
env.set_high_priority_background_threads(0); env.set_high_priority_background_threads(0);

View file

@ -9,7 +9,7 @@ use crate::or_else;
#[implement(Engine)] #[implement(Engine)]
pub fn memory_usage(&self) -> Result<String> { pub fn memory_usage(&self) -> Result<String> {
let mut res = String::new(); let mut res = String::new();
let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()?])) let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()]))
.or_else(or_else)?; .or_else(or_else)?;
let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0; let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0;
writeln!( writeln!(
@ -19,10 +19,10 @@ pub fn memory_usage(&self) -> Result<String> {
mibs(stats.mem_table_total), mibs(stats.mem_table_total),
mibs(stats.mem_table_unflushed), mibs(stats.mem_table_unflushed),
mibs(stats.mem_table_readers_total), mibs(stats.mem_table_readers_total),
mibs(u64::try_from(self.ctx.row_cache.lock()?.get_usage())?), mibs(u64::try_from(self.ctx.row_cache.lock().get_usage())?),
)?; )?;
for (name, cache) in &*self.ctx.col_cache.lock()? { for (name, cache) in &*self.ctx.col_cache.lock() {
writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?; writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?;
} }

View file

@ -23,11 +23,7 @@ pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<S
let config = &server.config; let config = &server.config;
let path = &config.database_path; let path = &config.database_path;
let db_opts = db_options( let db_opts = db_options(config, &ctx.env.lock(), &ctx.row_cache.lock())?;
config,
&ctx.env.lock().expect("environment locked"),
&ctx.row_cache.lock().expect("row cache locked"),
)?;
let cfds = Self::configure_cfds(&ctx, &db_opts, desc)?; let cfds = Self::configure_cfds(&ctx, &db_opts, desc)?;
let num_cfds = cfds.len(); let num_cfds = cfds.len();

View file

@ -3,7 +3,7 @@ mod configure;
use std::{ use std::{
mem::take, mem::take,
sync::{ sync::{
Arc, Mutex, Arc,
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
}, },
thread, thread,
@ -12,7 +12,7 @@ use std::{
use async_channel::{QueueStrategy, Receiver, RecvError, Sender}; use async_channel::{QueueStrategy, Receiver, RecvError, Sender};
use conduwuit::{ use conduwuit::{
Error, Result, Server, debug, err, error, implement, Error, Result, Server, SyncMutex, debug, err, error, implement,
result::DebugInspect, result::DebugInspect,
smallvec::SmallVec, smallvec::SmallVec,
trace, trace,
@ -31,7 +31,7 @@ use crate::{Handle, Map, keyval::KeyBuf, stream};
pub(crate) struct Pool { pub(crate) struct Pool {
server: Arc<Server>, server: Arc<Server>,
queues: Vec<Sender<Cmd>>, queues: Vec<Sender<Cmd>>,
workers: Mutex<Vec<JoinHandle<()>>>, workers: SyncMutex<Vec<JoinHandle<()>>>,
topology: Vec<usize>, topology: Vec<usize>,
busy: AtomicUsize, busy: AtomicUsize,
queued_max: AtomicUsize, queued_max: AtomicUsize,
@ -115,7 +115,7 @@ impl Drop for Pool {
#[implement(Pool)] #[implement(Pool)]
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub(crate) fn close(&self) { pub(crate) fn close(&self) {
let workers = take(&mut *self.workers.lock().expect("locked")); let workers = take(&mut *self.workers.lock());
let senders = self.queues.iter().map(Sender::sender_count).sum::<usize>(); let senders = self.queues.iter().map(Sender::sender_count).sum::<usize>();
@ -154,7 +154,7 @@ pub(crate) fn close(&self) {
#[implement(Pool)] #[implement(Pool)]
fn spawn_until(self: &Arc<Self>, recv: &[Receiver<Cmd>], count: usize) -> Result { fn spawn_until(self: &Arc<Self>, recv: &[Receiver<Cmd>], count: usize) -> Result {
let mut workers = self.workers.lock().expect("locked"); let mut workers = self.workers.lock();
while workers.len() < count { while workers.len() < count {
self.clone().spawn_one(&mut workers, recv)?; self.clone().spawn_one(&mut workers, recv)?;
} }

View file

@ -2,12 +2,12 @@ use std::{
collections::{HashMap, hash_map}, collections::{HashMap, hash_map},
future::Future, future::Future,
pin::Pin, pin::Pin,
sync::RwLock,
}; };
use conduwuit::SyncRwLock;
use tokio::sync::watch; use tokio::sync::watch;
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>; type Watcher = SyncRwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
#[derive(Default)] #[derive(Default)]
pub(crate) struct Watchers { pub(crate) struct Watchers {
@ -19,7 +19,7 @@ impl Watchers {
&'a self, &'a self,
prefix: &[u8], prefix: &[u8],
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { let mut rx = match self.watchers.write().entry(prefix.to_vec()) {
| hash_map::Entry::Occupied(o) => o.get().1.clone(), | hash_map::Entry::Occupied(o) => o.get().1.clone(),
| hash_map::Entry::Vacant(v) => { | hash_map::Entry::Vacant(v) => {
let (tx, rx) = watch::channel(()); let (tx, rx) = watch::channel(());
@ -35,7 +35,7 @@ impl Watchers {
} }
pub(crate) fn wake(&self, key: &[u8]) { pub(crate) fn wake(&self, key: &[u8]) {
let watchers = self.watchers.read().unwrap(); let watchers = self.watchers.read();
let mut triggered = Vec::new(); let mut triggered = Vec::new();
for length in 0..=key.len() { for length in 0..=key.len() {
if watchers.contains_key(&key[..length]) { if watchers.contains_key(&key[..length]) {
@ -46,7 +46,7 @@ impl Watchers {
drop(watchers); drop(watchers);
if !triggered.is_empty() { if !triggered.is_empty() {
let mut watchers = self.watchers.write().unwrap(); let mut watchers = self.watchers.write();
for prefix in triggered { for prefix in triggered {
if let Some(tx) = watchers.remove(prefix) { if let Some(tx) = watchers.remove(prefix) {
tx.0.send(()).expect("channel should still be open"); tx.0.send(()).expect("channel should still be open");

View file

@ -15,13 +15,13 @@ pub(super) fn flags_capture(args: TokenStream) -> TokenStream {
#[conduwuit_core::ctor] #[conduwuit_core::ctor]
fn _set_rustc_flags() { fn _set_rustc_flags() {
conduwuit_core::info::rustc::FLAGS.lock().expect("locked").insert(#crate_name, &RUSTC_FLAGS); conduwuit_core::info::rustc::FLAGS.lock().insert(#crate_name, &RUSTC_FLAGS);
} }
// static strings have to be yanked on module unload // static strings have to be yanked on module unload
#[conduwuit_core::dtor] #[conduwuit_core::dtor]
fn _unset_rustc_flags() { fn _unset_rustc_flags() {
conduwuit_core::info::rustc::FLAGS.lock().expect("locked").remove(#crate_name); conduwuit_core::info::rustc::FLAGS.lock().remove(#crate_name);
} }
}; };

View file

@ -1,11 +1,8 @@
#![cfg(feature = "console")] #![cfg(feature = "console")]
use std::{ use std::{collections::VecDeque, sync::Arc};
collections::VecDeque,
sync::{Arc, Mutex},
};
use conduwuit::{Server, debug, defer, error, log, log::is_systemd_mode}; use conduwuit::{Server, SyncMutex, debug, defer, error, log, log::is_systemd_mode};
use futures::future::{AbortHandle, Abortable}; use futures::future::{AbortHandle, Abortable};
use ruma::events::room::message::RoomMessageEventContent; use ruma::events::room::message::RoomMessageEventContent;
use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; use rustyline_async::{Readline, ReadlineError, ReadlineEvent};
@ -17,10 +14,10 @@ use crate::{Dep, admin};
pub struct Console { pub struct Console {
server: Arc<Server>, server: Arc<Server>,
admin: Dep<admin::Service>, admin: Dep<admin::Service>,
worker_join: Mutex<Option<JoinHandle<()>>>, worker_join: SyncMutex<Option<JoinHandle<()>>>,
input_abort: Mutex<Option<AbortHandle>>, input_abort: SyncMutex<Option<AbortHandle>>,
command_abort: Mutex<Option<AbortHandle>>, command_abort: SyncMutex<Option<AbortHandle>>,
history: Mutex<VecDeque<String>>, history: SyncMutex<VecDeque<String>>,
output: MadSkin, output: MadSkin,
} }
@ -50,7 +47,7 @@ impl Console {
} }
pub async fn start(self: &Arc<Self>) { pub async fn start(self: &Arc<Self>) {
let mut worker_join = self.worker_join.lock().expect("locked"); let mut worker_join = self.worker_join.lock();
if worker_join.is_none() { if worker_join.is_none() {
let self_ = Arc::clone(self); let self_ = Arc::clone(self);
_ = worker_join.insert(self.server.runtime().spawn(self_.worker())); _ = worker_join.insert(self.server.runtime().spawn(self_.worker()));
@ -60,7 +57,7 @@ impl Console {
pub async fn close(self: &Arc<Self>) { pub async fn close(self: &Arc<Self>) {
self.interrupt(); self.interrupt();
let Some(worker_join) = self.worker_join.lock().expect("locked").take() else { let Some(worker_join) = self.worker_join.lock().take() else {
return; return;
}; };
@ -70,22 +67,18 @@ impl Console {
pub fn interrupt(self: &Arc<Self>) { pub fn interrupt(self: &Arc<Self>) {
self.interrupt_command(); self.interrupt_command();
self.interrupt_readline(); self.interrupt_readline();
self.worker_join self.worker_join.lock().as_ref().map(JoinHandle::abort);
.lock()
.expect("locked")
.as_ref()
.map(JoinHandle::abort);
} }
pub fn interrupt_readline(self: &Arc<Self>) { pub fn interrupt_readline(self: &Arc<Self>) {
if let Some(input_abort) = self.input_abort.lock().expect("locked").take() { if let Some(input_abort) = self.input_abort.lock().take() {
debug!("Interrupting console readline..."); debug!("Interrupting console readline...");
input_abort.abort(); input_abort.abort();
} }
} }
pub fn interrupt_command(self: &Arc<Self>) { pub fn interrupt_command(self: &Arc<Self>) {
if let Some(command_abort) = self.command_abort.lock().expect("locked").take() { if let Some(command_abort) = self.command_abort.lock().take() {
debug!("Interrupting console command..."); debug!("Interrupting console command...");
command_abort.abort(); command_abort.abort();
} }
@ -120,7 +113,7 @@ impl Console {
} }
debug!("session ending"); debug!("session ending");
self.worker_join.lock().expect("locked").take(); self.worker_join.lock().take();
} }
async fn readline(self: &Arc<Self>) -> Result<ReadlineEvent, ReadlineError> { async fn readline(self: &Arc<Self>) -> Result<ReadlineEvent, ReadlineError> {
@ -135,9 +128,9 @@ impl Console {
let (abort, abort_reg) = AbortHandle::new_pair(); let (abort, abort_reg) = AbortHandle::new_pair();
let future = Abortable::new(future, abort_reg); let future = Abortable::new(future, abort_reg);
_ = self.input_abort.lock().expect("locked").insert(abort); _ = self.input_abort.lock().insert(abort);
defer! {{ defer! {{
_ = self.input_abort.lock().expect("locked").take(); _ = self.input_abort.lock().take();
}} }}
let Ok(result) = future.await else { let Ok(result) = future.await else {
@ -158,9 +151,9 @@ impl Console {
let (abort, abort_reg) = AbortHandle::new_pair(); let (abort, abort_reg) = AbortHandle::new_pair();
let future = Abortable::new(future, abort_reg); let future = Abortable::new(future, abort_reg);
_ = self.command_abort.lock().expect("locked").insert(abort); _ = self.command_abort.lock().insert(abort);
defer! {{ defer! {{
_ = self.command_abort.lock().expect("locked").take(); _ = self.command_abort.lock().take();
}} }}
_ = future.await; _ = future.await;
@ -184,20 +177,15 @@ impl Console {
} }
fn set_history(&self, readline: &mut Readline) { fn set_history(&self, readline: &mut Readline) {
self.history self.history.lock().iter().rev().for_each(|entry| {
.lock() readline
.expect("locked") .add_history_entry(entry.clone())
.iter() .expect("added history entry");
.rev() });
.for_each(|entry| {
readline
.add_history_entry(entry.clone())
.expect("added history entry");
});
} }
fn add_history(&self, line: String) { fn add_history(&self, line: String) {
let mut history = self.history.lock().expect("locked"); let mut history = self.history.lock();
history.push_front(line); history.push_front(line);
history.truncate(HISTORY_LIMIT); history.truncate(HISTORY_LIMIT);
} }

View file

@ -5,11 +5,11 @@ mod grant;
use std::{ use std::{
pin::Pin, pin::Pin,
sync::{Arc, RwLock as StdRwLock, Weak}, sync::{Arc, Weak},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Err, utils}; use conduwuit::{Err, SyncRwLock, utils};
use conduwuit_core::{ use conduwuit_core::{
Error, Event, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder, Error, Event, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder,
}; };
@ -36,7 +36,7 @@ pub struct Service {
services: Services, services: Services,
channel: (Sender<CommandInput>, Receiver<CommandInput>), channel: (Sender<CommandInput>, Receiver<CommandInput>),
pub handle: RwLock<Option<Processor>>, pub handle: RwLock<Option<Processor>>,
pub complete: StdRwLock<Option<Completer>>, pub complete: SyncRwLock<Option<Completer>>,
#[cfg(feature = "console")] #[cfg(feature = "console")]
pub console: Arc<console::Console>, pub console: Arc<console::Console>,
} }
@ -50,7 +50,7 @@ struct Services {
state_cache: Dep<rooms::state_cache::Service>, state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>, state_accessor: Dep<rooms::state_accessor::Service>,
account_data: Dep<account_data::Service>, account_data: Dep<account_data::Service>,
services: StdRwLock<Option<Weak<crate::Services>>>, services: SyncRwLock<Option<Weak<crate::Services>>>,
media: Dep<crate::media::Service>, media: Dep<crate::media::Service>,
} }
@ -105,7 +105,7 @@ impl crate::Service for Service {
}, },
channel: loole::bounded(COMMAND_QUEUE_LIMIT), channel: loole::bounded(COMMAND_QUEUE_LIMIT),
handle: RwLock::new(None), handle: RwLock::new(None),
complete: StdRwLock::new(None), complete: SyncRwLock::new(None),
#[cfg(feature = "console")] #[cfg(feature = "console")]
console: console::Console::new(&args), console: console::Console::new(&args),
})) }))
@ -312,10 +312,7 @@ impl Service {
/// Invokes the tab-completer to complete the command. When unavailable, /// Invokes the tab-completer to complete the command. When unavailable,
/// None is returned. /// None is returned.
pub fn complete_command(&self, command: &str) -> Option<String> { pub fn complete_command(&self, command: &str) -> Option<String> {
self.complete self.complete.read().map(|complete| complete(command))
.read()
.expect("locked for reading")
.map(|complete| complete(command))
} }
async fn handle_signal(&self, sig: &'static str) { async fn handle_signal(&self, sig: &'static str) {
@ -338,17 +335,13 @@ impl Service {
} }
async fn process_command(&self, command: CommandInput) -> ProcessorResult { async fn process_command(&self, command: CommandInput) -> ProcessorResult {
let handle = &self let handle_guard = self.handle.read().await;
.handle let handle = handle_guard.as_ref().expect("Admin module is not loaded");
.read()
.await
.expect("Admin module is not loaded");
let services = self let services = self
.services .services
.services .services
.read() .read()
.expect("locked")
.as_ref() .as_ref()
.and_then(Weak::upgrade) .and_then(Weak::upgrade)
.expect("Services self-reference not initialized."); .expect("Services self-reference not initialized.");
@ -523,7 +516,7 @@ impl Service {
/// Sets the self-reference to crate::Services which will provide context to /// Sets the self-reference to crate::Services which will provide context to
/// the admin commands. /// the admin commands.
pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) { pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) {
let receiver = &mut *self.services.services.write().expect("locked for writing"); let receiver = &mut *self.services.services.write();
let weak = services.map(Arc::downgrade); let weak = services.map(Arc::downgrade);
*receiver = weak; *receiver = weak;
} }

View file

@ -66,6 +66,7 @@ impl crate::Service for Service {
federation: base(config)? federation: base(config)?
.dns_resolver(resolver.resolver.hooked.clone()) .dns_resolver(resolver.resolver.hooked.clone())
.connect_timeout(Duration::from_secs(config.federation_conn_timeout))
.read_timeout(Duration::from_secs(config.federation_timeout)) .read_timeout(Duration::from_secs(config.federation_timeout))
.pool_max_idle_per_host(config.federation_idle_per_host.into()) .pool_max_idle_per_host(config.federation_idle_per_host.into())
.pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout))
@ -74,6 +75,7 @@ impl crate::Service for Service {
synapse: base(config)? synapse: base(config)?
.dns_resolver(resolver.resolver.hooked.clone()) .dns_resolver(resolver.resolver.hooked.clone())
.connect_timeout(Duration::from_secs(config.federation_conn_timeout))
.read_timeout(Duration::from_secs(305)) .read_timeout(Duration::from_secs(305))
.pool_max_idle_per_host(0) .pool_max_idle_per_host(0)
.redirect(redirect::Policy::limited(3)) .redirect(redirect::Policy::limited(3))
@ -81,6 +83,7 @@ impl crate::Service for Service {
sender: base(config)? sender: base(config)?
.dns_resolver(resolver.resolver.hooked.clone()) .dns_resolver(resolver.resolver.hooked.clone())
.connect_timeout(Duration::from_secs(config.federation_conn_timeout))
.read_timeout(Duration::from_secs(config.sender_timeout)) .read_timeout(Duration::from_secs(config.sender_timeout))
.timeout(Duration::from_secs(config.sender_timeout)) .timeout(Duration::from_secs(config.sender_timeout))
.pool_max_idle_per_host(1) .pool_max_idle_per_host(1)

View file

@ -1,11 +1,11 @@
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use conduwuit::{Result, utils}; use conduwuit::{Result, SyncRwLock, utils};
use database::{Database, Deserialized, Map}; use database::{Database, Deserialized, Map};
pub struct Data { pub struct Data {
global: Arc<Map>, global: Arc<Map>,
counter: RwLock<u64>, counter: SyncRwLock<u64>,
pub(super) db: Arc<Database>, pub(super) db: Arc<Database>,
} }
@ -16,25 +16,21 @@ impl Data {
let db = &args.db; let db = &args.db;
Self { Self {
global: db["global"].clone(), global: db["global"].clone(),
counter: RwLock::new( counter: SyncRwLock::new(Self::stored_count(&db["global"]).unwrap_or_default()),
Self::stored_count(&db["global"]).expect("initialized global counter"),
),
db: args.db.clone(), db: args.db.clone(),
} }
} }
pub fn next_count(&self) -> Result<u64> { pub fn next_count(&self) -> Result<u64> {
let _cork = self.db.cork(); let _cork = self.db.cork();
let mut lock = self.counter.write().expect("locked"); let mut lock = self.counter.write();
let counter: &mut u64 = &mut lock; let counter: &mut u64 = &mut lock;
debug_assert!( debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"), *counter == Self::stored_count(&self.global).unwrap_or_default(),
"counter mismatch" "counter mismatch"
); );
*counter = counter *counter = counter.checked_add(1).unwrap_or(*counter);
.checked_add(1)
.expect("counter must not overflow u64");
self.global.insert(COUNTER, counter.to_be_bytes()); self.global.insert(COUNTER, counter.to_be_bytes());
@ -43,10 +39,10 @@ impl Data {
#[inline] #[inline]
pub fn current_count(&self) -> u64 { pub fn current_count(&self) -> u64 {
let lock = self.counter.read().expect("locked"); let lock = self.counter.read();
let counter: &u64 = &lock; let counter: &u64 = &lock;
debug_assert!( debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"), *counter == Self::stored_count(&self.global).unwrap_or_default(),
"counter mismatch" "counter mismatch"
); );

View file

@ -1,14 +1,9 @@
mod data; mod data;
use std::{ use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
collections::HashMap,
fmt::Write,
sync::{Arc, RwLock},
time::Instant,
};
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Result, Server, error, utils::bytes::pretty}; use conduwuit::{Result, Server, SyncRwLock, error, utils::bytes::pretty};
use data::Data; use data::Data;
use regex::RegexSet; use regex::RegexSet;
use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, ServerName, UserId}; use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, ServerName, UserId};
@ -19,7 +14,7 @@ pub struct Service {
pub db: Data, pub db: Data,
server: Arc<Server>, server: Arc<Server>,
pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>, pub bad_event_ratelimiter: Arc<SyncRwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub server_user: OwnedUserId, pub server_user: OwnedUserId,
pub admin_alias: OwnedRoomAliasId, pub admin_alias: OwnedRoomAliasId,
pub turn_secret: String, pub turn_secret: String,
@ -62,7 +57,7 @@ impl crate::Service for Service {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db, db,
server: args.server.clone(), server: args.server.clone(),
bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_event_ratelimiter: Arc::new(SyncRwLock::new(HashMap::new())),
admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &args.server.name)) admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &args.server.name))
.expect("#admins:server_name is valid alias name"), .expect("#admins:server_name is valid alias name"),
server_user: UserId::parse_with_server_name( server_user: UserId::parse_with_server_name(
@ -76,7 +71,7 @@ impl crate::Service for Service {
} }
async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result { async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read()?.iter().fold( let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read().iter().fold(
(0_usize, 0_usize), (0_usize, 0_usize),
|(mut count, mut bytes), (event_id, _)| { |(mut count, mut bytes), (event_id, _)| {
bytes = bytes.saturating_add(event_id.capacity()); bytes = bytes.saturating_add(event_id.capacity());
@ -91,12 +86,7 @@ impl crate::Service for Service {
Ok(()) Ok(())
} }
async fn clear_cache(&self) { async fn clear_cache(&self) { self.bad_event_ratelimiter.write().clear(); }
self.bad_event_ratelimiter
.write()
.expect("locked for writing")
.clear();
}
fn name(&self) -> &str { service::make_name(std::module_path!()) } fn name(&self) -> &str { service::make_name(std::module_path!()) }
} }

View file

@ -58,7 +58,6 @@ impl Manager {
let services: Vec<Arc<dyn Service>> = self let services: Vec<Arc<dyn Service>> = self
.service .service
.read() .read()
.expect("locked for reading")
.values() .values()
.map(|val| val.0.upgrade()) .map(|val| val.0.upgrade())
.map(|arc| arc.expect("services available for manager startup")) .map(|arc| arc.expect("services available for manager startup"))

View file

@ -1,9 +1,6 @@
use std::{ use std::{mem::size_of, sync::Arc};
mem::size_of,
sync::{Arc, Mutex},
};
use conduwuit::{Err, Result, err, utils, utils::math::usize_from_f64}; use conduwuit::{Err, Result, SyncMutex, err, utils, utils::math::usize_from_f64};
use database::Map; use database::Map;
use lru_cache::LruCache; use lru_cache::LruCache;
@ -11,7 +8,7 @@ use crate::rooms::short::ShortEventId;
pub(super) struct Data { pub(super) struct Data {
shorteventid_authchain: Arc<Map>, shorteventid_authchain: Arc<Map>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>, pub(super) auth_chain_cache: SyncMutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
} }
impl Data { impl Data {
@ -23,7 +20,7 @@ impl Data {
.expect("valid cache size"); .expect("valid cache size");
Self { Self {
shorteventid_authchain: db["shorteventid_authchain"].clone(), shorteventid_authchain: db["shorteventid_authchain"].clone(),
auth_chain_cache: Mutex::new(LruCache::new(cache_size)), auth_chain_cache: SyncMutex::new(LruCache::new(cache_size)),
} }
} }
@ -34,12 +31,7 @@ impl Data {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Check RAM cache // Check RAM cache
if let Some(result) = self if let Some(result) = self.auth_chain_cache.lock().get_mut(key) {
.auth_chain_cache
.lock()
.expect("cache locked")
.get_mut(key)
{
return Ok(Arc::clone(result)); return Ok(Arc::clone(result));
} }
@ -63,7 +55,6 @@ impl Data {
// Cache in RAM // Cache in RAM
self.auth_chain_cache self.auth_chain_cache
.lock() .lock()
.expect("cache locked")
.insert(vec![key[0]], Arc::clone(&chain)); .insert(vec![key[0]], Arc::clone(&chain));
Ok(chain) Ok(chain)
@ -84,9 +75,6 @@ impl Data {
} }
// Cache in RAM // Cache in RAM
self.auth_chain_cache self.auth_chain_cache.lock().insert(key, auth_chain);
.lock()
.expect("cache locked")
.insert(key, auth_chain);
} }
} }

View file

@ -248,10 +248,10 @@ pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[ShortEventId]) {
#[implement(Service)] #[implement(Service)]
pub fn get_cache_usage(&self) -> (usize, usize) { pub fn get_cache_usage(&self) -> (usize, usize) {
let cache = self.db.auth_chain_cache.lock().expect("locked"); let cache = self.db.auth_chain_cache.lock();
(cache.len(), cache.capacity()) (cache.len(), cache.capacity())
} }
#[implement(Service)] #[implement(Service)]
pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); } pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().clear(); }

View file

@ -41,7 +41,6 @@ where
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.write() .write()
.expect("locked")
.entry(id) .entry(id)
{ {
| hash_map::Entry::Vacant(e) => { | hash_map::Entry::Vacant(e) => {
@ -76,7 +75,6 @@ where
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
.expect("locked")
.get(&*next_id) .get(&*next_id)
{ {
// Exponential backoff // Exponential backoff
@ -187,7 +185,6 @@ where
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
.expect("locked")
.get(&*next_id) .get(&*next_id)
{ {
// Exponential backoff // Exponential backoff

View file

@ -160,7 +160,6 @@ pub async fn handle_incoming_pdu<'a>(
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.write() .write()
.expect("locked")
.entry(prev_id.into()) .entry(prev_id.into())
{ {
| hash_map::Entry::Vacant(e) => { | hash_map::Entry::Vacant(e) => {
@ -181,13 +180,11 @@ pub async fn handle_incoming_pdu<'a>(
let start_time = Instant::now(); let start_time = Instant::now();
self.federation_handletime self.federation_handletime
.write() .write()
.expect("locked")
.insert(room_id.into(), (event_id.to_owned(), start_time)); .insert(room_id.into(), (event_id.to_owned(), start_time));
defer! {{ defer! {{
self.federation_handletime self.federation_handletime
.write() .write()
.expect("locked")
.remove(room_id); .remove(room_id);
}}; }};

View file

@ -42,7 +42,6 @@ where
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
.expect("locked")
.get(prev_id) .get(prev_id)
{ {
// Exponential backoff // Exponential backoff
@ -70,13 +69,11 @@ where
let start_time = Instant::now(); let start_time = Instant::now();
self.federation_handletime self.federation_handletime
.write() .write()
.expect("locked")
.insert(room_id.into(), ((*prev_id).to_owned(), start_time)); .insert(room_id.into(), ((*prev_id).to_owned(), start_time));
defer! {{ defer! {{
self.federation_handletime self.federation_handletime
.write() .write()
.expect("locked")
.remove(room_id); .remove(room_id);
}}; }};

View file

@ -10,15 +10,10 @@ mod resolve_state;
mod state_at_incoming; mod state_at_incoming;
mod upgrade_outlier_pdu; mod upgrade_outlier_pdu;
use std::{ use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
collections::HashMap,
fmt::Write,
sync::{Arc, RwLock as StdRwLock},
time::Instant,
};
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, utils::MutexMap}; use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, SyncRwLock, utils::MutexMap};
use ruma::{ use ruma::{
OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
events::room::create::RoomCreateEventContent, events::room::create::RoomCreateEventContent,
@ -28,7 +23,7 @@ use crate::{Dep, globals, rooms, sending, server_keys};
pub struct Service { pub struct Service {
pub mutex_federation: RoomMutexMap, pub mutex_federation: RoomMutexMap,
pub federation_handletime: StdRwLock<HandleTimeMap>, pub federation_handletime: SyncRwLock<HandleTimeMap>,
services: Services, services: Services,
} }
@ -81,11 +76,7 @@ impl crate::Service for Service {
let mutex_federation = self.mutex_federation.len(); let mutex_federation = self.mutex_federation.len();
writeln!(out, "federation_mutex: {mutex_federation}")?; writeln!(out, "federation_mutex: {mutex_federation}")?;
let federation_handletime = self let federation_handletime = self.federation_handletime.read().len();
.federation_handletime
.read()
.expect("locked for reading")
.len();
writeln!(out, "federation_handletime: {federation_handletime}")?; writeln!(out, "federation_handletime: {federation_handletime}")?;
Ok(()) Ok(())

View file

@ -1,13 +1,10 @@
mod update; mod update;
mod via; mod via;
use std::{ use std::{collections::HashMap, sync::Arc};
collections::HashMap,
sync::{Arc, RwLock},
};
use conduwuit::{ use conduwuit::{
Result, implement, Result, SyncRwLock, implement,
result::LogErr, result::LogErr,
utils::{ReadyExt, stream::TryIgnore}, utils::{ReadyExt, stream::TryIgnore},
warn, warn,
@ -54,14 +51,14 @@ struct Data {
userroomid_knockedstate: Arc<Map>, userroomid_knockedstate: Arc<Map>,
} }
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>; type AppServiceInRoomCache = SyncRwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>); type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>);
type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>); type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>);
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
appservice_in_room_cache: RwLock::new(HashMap::new()), appservice_in_room_cache: SyncRwLock::new(HashMap::new()),
services: Services { services: Services {
account_data: args.depend::<account_data::Service>("account_data"), account_data: args.depend::<account_data::Service>("account_data"),
config: args.depend::<config::Service>("config"), config: args.depend::<config::Service>("config"),
@ -99,7 +96,6 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
if let Some(cached) = self if let Some(cached) = self
.appservice_in_room_cache .appservice_in_room_cache
.read() .read()
.expect("locked")
.get(room_id) .get(room_id)
.and_then(|map| map.get(&appservice.registration.id)) .and_then(|map| map.get(&appservice.registration.id))
.copied() .copied()
@ -124,7 +120,6 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
self.appservice_in_room_cache self.appservice_in_room_cache
.write() .write()
.expect("locked")
.entry(room_id.into()) .entry(room_id.into())
.or_default() .or_default()
.insert(appservice.registration.id.clone(), in_room); .insert(appservice.registration.id.clone(), in_room);
@ -134,19 +129,14 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
#[implement(Service)] #[implement(Service)]
pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) { pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) {
let cache = self.appservice_in_room_cache.read().expect("locked"); let cache = self.appservice_in_room_cache.read();
(cache.len(), cache.capacity()) (cache.len(), cache.capacity())
} }
#[implement(Service)] #[implement(Service)]
#[tracing::instrument(level = "debug", skip_all)] #[tracing::instrument(level = "debug", skip_all)]
pub fn clear_appservice_in_room_cache(&self) { pub fn clear_appservice_in_room_cache(&self) { self.appservice_in_room_cache.write().clear(); }
self.appservice_in_room_cache
.write()
.expect("locked")
.clear();
}
/// Returns an iterator of all servers participating in this room. /// Returns an iterator of all servers participating in this room.
#[implement(Service)] #[implement(Service)]

View file

@ -211,10 +211,7 @@ pub async fn update_joined_count(&self, room_id: &RoomId) {
self.db.serverroomids.put_raw(serverroom_id, []); self.db.serverroomids.put_raw(serverroom_id, []);
} }
self.appservice_in_room_cache self.appservice_in_room_cache.write().remove(room_id);
.write()
.expect("locked")
.remove(room_id);
} }
/// Direct DB function to directly mark a user as joined. It is not /// Direct DB function to directly mark a user as joined. It is not

View file

@ -2,12 +2,12 @@ use std::{
collections::{BTreeSet, HashMap}, collections::{BTreeSet, HashMap},
fmt::{Debug, Write}, fmt::{Debug, Write},
mem::size_of, mem::size_of,
sync::{Arc, Mutex}, sync::Arc,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{ use conduwuit::{
Result, Result, SyncMutex,
arrayvec::ArrayVec, arrayvec::ArrayVec,
at, checked, err, expected, implement, utils, at, checked, err, expected, implement, utils,
utils::{bytes, math::usize_from_f64, stream::IterStream}, utils::{bytes, math::usize_from_f64, stream::IterStream},
@ -23,7 +23,7 @@ use crate::{
}; };
pub struct Service { pub struct Service {
pub stateinfo_cache: Mutex<StateInfoLruCache>, pub stateinfo_cache: SyncMutex<StateInfoLruCache>,
db: Data, db: Data,
services: Services, services: Services,
} }
@ -86,7 +86,7 @@ impl crate::Service for Service {
async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result { async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
let (cache_len, ents) = { let (cache_len, ents) = {
let cache = self.stateinfo_cache.lock().expect("locked"); let cache = self.stateinfo_cache.lock();
let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold( let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
HashMap::new(), HashMap::new(),
|mut ents, ssi| { |mut ents, ssi| {
@ -110,7 +110,7 @@ impl crate::Service for Service {
Ok(()) Ok(())
} }
async fn clear_cache(&self) { self.stateinfo_cache.lock().expect("locked").clear(); } async fn clear_cache(&self) { self.stateinfo_cache.lock().clear(); }
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
} }
@ -123,7 +123,7 @@ pub async fn load_shortstatehash_info(
&self, &self,
shortstatehash: ShortStateHash, shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> { ) -> Result<ShortStateInfoVec> {
if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) { if let Some(r) = self.stateinfo_cache.lock().get_mut(&shortstatehash) {
return Ok(r.clone()); return Ok(r.clone());
} }
@ -152,7 +152,7 @@ async fn cache_shortstatehash_info(
shortstatehash: ShortStateHash, shortstatehash: ShortStateHash,
stack: ShortStateInfoVec, stack: ShortStateInfoVec,
) -> Result { ) -> Result {
self.stateinfo_cache.lock()?.insert(shortstatehash, stack); self.stateinfo_cache.lock().insert(shortstatehash, stack);
Ok(()) Ok(())
} }

View file

@ -3,11 +3,13 @@ use std::{
collections::BTreeMap, collections::BTreeMap,
fmt::Write, fmt::Write,
ops::Deref, ops::Deref,
sync::{Arc, OnceLock, RwLock, Weak}, sync::{Arc, OnceLock, Weak},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{Err, Result, Server, err, error::inspect_log, utils::string::SplitInfallible}; use conduwuit::{
Err, Result, Server, SyncRwLock, err, error::inspect_log, utils::string::SplitInfallible,
};
use database::Database; use database::Database;
/// Abstract interface for a Service /// Abstract interface for a Service
@ -62,7 +64,7 @@ pub(crate) struct Dep<T: Service + Send + Sync> {
name: &'static str, name: &'static str,
} }
pub(crate) type Map = RwLock<MapType>; pub(crate) type Map = SyncRwLock<MapType>;
pub(crate) type MapType = BTreeMap<MapKey, MapVal>; pub(crate) type MapType = BTreeMap<MapKey, MapVal>;
pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>); pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>);
pub(crate) type MapKey = String; pub(crate) type MapKey = String;
@ -143,15 +145,12 @@ pub(crate) fn get<T>(map: &Map, name: &str) -> Option<Arc<T>>
where where
T: Any + Send + Sync + Sized, T: Any + Send + Sync + Sized,
{ {
map.read() map.read().get(name).map(|(_, s)| {
.expect("locked for reading") s.upgrade().map(|s| {
.get(name) s.downcast::<T>()
.map(|(_, s)| { .expect("Service must be correctly downcast.")
s.upgrade().map(|s| { })
s.downcast::<T>() })?
.expect("Service must be correctly downcast.")
})
})?
} }
/// Reference a Service by name. Returns Err if the Service does not exist or /// Reference a Service by name. Returns Err if the Service does not exist or
@ -160,21 +159,18 @@ pub(crate) fn try_get<T>(map: &Map, name: &str) -> Result<Arc<T>>
where where
T: Any + Send + Sync + Sized, T: Any + Send + Sync + Sized,
{ {
map.read() map.read().get(name).map_or_else(
.expect("locked for reading") || Err!("Service {name:?} does not exist or has not been built yet."),
.get(name) |(_, s)| {
.map_or_else( s.upgrade().map_or_else(
|| Err!("Service {name:?} does not exist or has not been built yet."), || Err!("Service {name:?} no longer exists."),
|(_, s)| { |s| {
s.upgrade().map_or_else( s.downcast::<T>()
|| Err!("Service {name:?} no longer exists."), .map_err(|_| err!("Service {name:?} must be correctly downcast."))
|s| { },
s.downcast::<T>() )
.map_err(|_| err!("Service {name:?} must be correctly downcast.")) },
}, )
)
},
)
} }
/// Utility for service implementations; see Service::name() in the trait. /// Utility for service implementations; see Service::name() in the trait.

View file

@ -1,10 +1,8 @@
use std::{ use std::{any::Any, collections::BTreeMap, sync::Arc};
any::Any,
collections::BTreeMap,
sync::{Arc, RwLock},
};
use conduwuit::{Result, Server, debug, debug_info, info, trace, utils::stream::IterStream}; use conduwuit::{
Result, Server, SyncRwLock, debug, debug_info, info, trace, utils::stream::IterStream,
};
use database::Database; use database::Database;
use futures::{Stream, StreamExt, TryStreamExt}; use futures::{Stream, StreamExt, TryStreamExt};
use tokio::sync::Mutex; use tokio::sync::Mutex;
@ -52,7 +50,7 @@ impl Services {
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> { pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
let db = Database::open(&server).await?; let db = Database::open(&server).await?;
let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new())); let service: Arc<Map> = Arc::new(SyncRwLock::new(BTreeMap::new()));
macro_rules! build { macro_rules! build {
($tyname:ty) => {{ ($tyname:ty) => {{
let built = <$tyname>::build(Args { let built = <$tyname>::build(Args {
@ -193,7 +191,7 @@ impl Services {
fn interrupt(&self) { fn interrupt(&self) {
debug!("Interrupting services..."); debug!("Interrupting services...");
for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() { for (name, (service, ..)) in self.service.read().iter() {
if let Some(service) = service.upgrade() { if let Some(service) = service.upgrade() {
trace!("Interrupting {name}"); trace!("Interrupting {name}");
service.interrupt(); service.interrupt();
@ -205,7 +203,6 @@ impl Services {
fn services(&self) -> impl Stream<Item = Arc<dyn Service>> + Send { fn services(&self) -> impl Stream<Item = Arc<dyn Service>> + Send {
self.service self.service
.read() .read()
.expect("locked for reading")
.values() .values()
.filter_map(|val| val.0.upgrade()) .filter_map(|val| val.0.upgrade())
.collect::<Vec<_>>() .collect::<Vec<_>>()
@ -233,10 +230,9 @@ impl Services {
#[allow(clippy::needless_pass_by_value)] #[allow(clippy::needless_pass_by_value)]
fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) { fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) {
let name = s.name(); let name = s.name();
let len = map.read().expect("locked for reading").len(); let len = map.read().len();
trace!("built service #{len}: {name:?}"); trace!("built service #{len}: {name:?}");
map.write() map.write()
.expect("locked for writing")
.insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a))); .insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a)));
} }

View file

@ -2,10 +2,10 @@ mod watch;
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
sync::{Arc, Mutex, Mutex as StdMutex}, sync::Arc,
}; };
use conduwuit::{Result, Server}; use conduwuit::{Result, Server, SyncMutex};
use database::Map; use database::Map;
use ruma::{ use ruma::{
OwnedDeviceId, OwnedRoomId, OwnedUserId, OwnedDeviceId, OwnedRoomId, OwnedUserId,
@ -62,11 +62,11 @@ struct SnakeSyncCache {
extensions: v5::request::Extensions, extensions: v5::request::Extensions,
} }
type DbConnections<K, V> = Mutex<BTreeMap<K, V>>; type DbConnections<K, V> = SyncMutex<BTreeMap<K, V>>;
type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String);
type DbConnectionsVal = Arc<Mutex<SlidingSyncCache>>; type DbConnectionsVal = Arc<SyncMutex<SlidingSyncCache>>;
type SnakeConnectionsKey = (OwnedUserId, OwnedDeviceId, Option<String>); type SnakeConnectionsKey = (OwnedUserId, OwnedDeviceId, Option<String>);
type SnakeConnectionsVal = Arc<Mutex<SnakeSyncCache>>; type SnakeConnectionsVal = Arc<SyncMutex<SnakeSyncCache>>;
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@ -90,8 +90,8 @@ impl crate::Service for Service {
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
typing: args.depend::<rooms::typing::Service>("rooms::typing"), typing: args.depend::<rooms::typing::Service>("rooms::typing"),
}, },
connections: StdMutex::new(BTreeMap::new()), connections: SyncMutex::new(BTreeMap::new()),
snake_connections: StdMutex::new(BTreeMap::new()), snake_connections: SyncMutex::new(BTreeMap::new()),
})) }))
} }
@ -100,22 +100,19 @@ impl crate::Service for Service {
impl Service { impl Service {
pub fn snake_connection_cached(&self, key: &SnakeConnectionsKey) -> bool { pub fn snake_connection_cached(&self, key: &SnakeConnectionsKey) -> bool {
self.snake_connections self.snake_connections.lock().contains_key(key)
.lock()
.expect("locked")
.contains_key(key)
} }
pub fn forget_snake_sync_connection(&self, key: &SnakeConnectionsKey) { pub fn forget_snake_sync_connection(&self, key: &SnakeConnectionsKey) {
self.snake_connections.lock().expect("locked").remove(key); self.snake_connections.lock().remove(key);
} }
pub fn remembered(&self, key: &DbConnectionsKey) -> bool { pub fn remembered(&self, key: &DbConnectionsKey) -> bool {
self.connections.lock().expect("locked").contains_key(key) self.connections.lock().contains_key(key)
} }
pub fn forget_sync_request_connection(&self, key: &DbConnectionsKey) { pub fn forget_sync_request_connection(&self, key: &DbConnectionsKey) {
self.connections.lock().expect("locked").remove(key); self.connections.lock().remove(key);
} }
pub fn update_snake_sync_request_with_cache( pub fn update_snake_sync_request_with_cache(
@ -123,13 +120,13 @@ impl Service {
snake_key: &SnakeConnectionsKey, snake_key: &SnakeConnectionsKey,
request: &mut v5::Request, request: &mut v5::Request,
) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> { ) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
let mut cache = self.snake_connections.lock().expect("locked"); let mut cache = self.snake_connections.lock();
let cached = Arc::clone( let cached = Arc::clone(
cache cache
.entry(snake_key.clone()) .entry(snake_key.clone())
.or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
); );
let cached = &mut cached.lock().expect("locked"); let cached = &mut cached.lock();
drop(cache); drop(cache);
//v5::Request::try_from_http_request(req, path_args); //v5::Request::try_from_http_request(req, path_args);
@ -232,16 +229,16 @@ impl Service {
}; };
let key = into_db_key(key.0.clone(), key.1.clone(), conn_id); let key = into_db_key(key.0.clone(), key.1.clone(), conn_id);
let mut cache = self.connections.lock().expect("locked"); let mut cache = self.connections.lock();
let cached = Arc::clone(cache.entry(key).or_insert_with(|| { let cached = Arc::clone(cache.entry(key).or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache { Arc::new(SyncMutex::new(SlidingSyncCache {
lists: BTreeMap::new(), lists: BTreeMap::new(),
subscriptions: BTreeMap::new(), subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(), known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(), extensions: ExtensionsConfig::default(),
})) }))
})); }));
let cached = &mut cached.lock().expect("locked"); let cached = &mut cached.lock();
drop(cache); drop(cache);
for (list_id, list) in &mut request.lists { for (list_id, list) in &mut request.lists {
@ -328,16 +325,16 @@ impl Service {
key: &DbConnectionsKey, key: &DbConnectionsKey,
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
) { ) {
let mut cache = self.connections.lock().expect("locked"); let mut cache = self.connections.lock();
let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| { let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache { Arc::new(SyncMutex::new(SlidingSyncCache {
lists: BTreeMap::new(), lists: BTreeMap::new(),
subscriptions: BTreeMap::new(), subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(), known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(), extensions: ExtensionsConfig::default(),
})) }))
})); }));
let cached = &mut cached.lock().expect("locked"); let cached = &mut cached.lock();
drop(cache); drop(cache);
cached.subscriptions = subscriptions; cached.subscriptions = subscriptions;
@ -350,16 +347,16 @@ impl Service {
new_cached_rooms: BTreeSet<OwnedRoomId>, new_cached_rooms: BTreeSet<OwnedRoomId>,
globalsince: u64, globalsince: u64,
) { ) {
let mut cache = self.connections.lock().expect("locked"); let mut cache = self.connections.lock();
let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| { let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache { Arc::new(SyncMutex::new(SlidingSyncCache {
lists: BTreeMap::new(), lists: BTreeMap::new(),
subscriptions: BTreeMap::new(), subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(), known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(), extensions: ExtensionsConfig::default(),
})) }))
})); }));
let cached = &mut cached.lock().expect("locked"); let cached = &mut cached.lock();
drop(cache); drop(cache);
for (room_id, lastsince) in cached for (room_id, lastsince) in cached
@ -386,13 +383,13 @@ impl Service {
globalsince: u64, globalsince: u64,
) { ) {
assert!(key.2.is_some(), "Some(conn_id) required for this call"); assert!(key.2.is_some(), "Some(conn_id) required for this call");
let mut cache = self.snake_connections.lock().expect("locked"); let mut cache = self.snake_connections.lock();
let cached = Arc::clone( let cached = Arc::clone(
cache cache
.entry(key.clone()) .entry(key.clone())
.or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
); );
let cached = &mut cached.lock().expect("locked"); let cached = &mut cached.lock();
drop(cache); drop(cache);
for (room_id, lastsince) in cached for (room_id, lastsince) in cached
@ -416,13 +413,13 @@ impl Service {
key: &SnakeConnectionsKey, key: &SnakeConnectionsKey,
subscriptions: BTreeMap<OwnedRoomId, v5::request::RoomSubscription>, subscriptions: BTreeMap<OwnedRoomId, v5::request::RoomSubscription>,
) { ) {
let mut cache = self.snake_connections.lock().expect("locked"); let mut cache = self.snake_connections.lock();
let cached = Arc::clone( let cached = Arc::clone(
cache cache
.entry(key.clone()) .entry(key.clone())
.or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
); );
let cached = &mut cached.lock().expect("locked"); let cached = &mut cached.lock();
drop(cache); drop(cache);
cached.subscriptions = subscriptions; cached.subscriptions = subscriptions;

View file

@ -1,10 +1,10 @@
use std::{ use std::{
collections::{BTreeMap, HashSet}, collections::{BTreeMap, HashSet},
sync::{Arc, RwLock}, sync::Arc,
}; };
use conduwuit::{ use conduwuit::{
Err, Error, Result, err, error, implement, utils, Err, Error, Result, SyncRwLock, err, error, implement, utils,
utils::{hash, string::EMPTY}, utils::{hash, string::EMPTY},
}; };
use database::{Deserialized, Json, Map}; use database::{Deserialized, Json, Map};
@ -19,7 +19,7 @@ use ruma::{
use crate::{Dep, config, globals, users}; use crate::{Dep, config, globals, users};
pub struct Service { pub struct Service {
userdevicesessionid_uiaarequest: RwLock<RequestMap>, userdevicesessionid_uiaarequest: SyncRwLock<RequestMap>,
db: Data, db: Data,
services: Services, services: Services,
} }
@ -42,7 +42,7 @@ pub const SESSION_ID_LENGTH: usize = 32;
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), userdevicesessionid_uiaarequest: SyncRwLock::new(RequestMap::new()),
db: Data { db: Data {
userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
}, },
@ -268,7 +268,6 @@ fn set_uiaa_request(
let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
self.userdevicesessionid_uiaarequest self.userdevicesessionid_uiaarequest
.write() .write()
.expect("locked for writing")
.insert(key, request.to_owned()); .insert(key, request.to_owned());
} }
@ -287,7 +286,6 @@ pub fn get_uiaa_request(
self.userdevicesessionid_uiaarequest self.userdevicesessionid_uiaarequest
.read() .read()
.expect("locked for reading")
.get(&key) .get(&key)
.cloned() .cloned()
} }