diff --git a/.envrc b/.envrc
index bad73b75..952ec2f8 100644
--- a/.envrc
+++ b/.envrc
@@ -2,6 +2,6 @@
 dotenv_if_exists
 
-# use flake ".#${DIRENV_DEVSHELL:-default}"
+use flake ".#${DIRENV_DEVSHELL:-default}"
 
 PATH_add bin
diff --git a/.forgejo/workflows/release-image.yml b/.forgejo/workflows/release-image.yml
index 04fc9de9..5ac5ddfa 100644
--- a/.forgejo/workflows/release-image.yml
+++ b/.forgejo/workflows/release-image.yml
@@ -262,7 +262,7 @@ jobs:
             type=ref,event=branch,prefix=${{ format('refs/heads/{0}', github.event.repository.default_branch) != github.ref && 'branch-' || '' }}
             type=ref,event=pr
             type=sha,format=long
-            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/v') }}
+            type=raw,value=latest,enable=${{ !startsWith(github.ref, 'refs/tags/v') }}
           images: ${{needs.define-variables.outputs.images}}
           # default labels & annotations: https://github.com/docker/metadata-action/blob/master/src/meta.ts#L509
           env:
diff --git a/Cargo.lock b/Cargo.lock
index 22c90e17..700c04f5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -963,12 +963,10 @@ dependencies = [
  "itertools 0.14.0",
  "libc",
  "libloading",
- "lock_api",
  "log",
  "maplit",
  "nix",
  "num-traits",
- "parking_lot",
  "rand 0.8.5",
  "regex",
  "reqwest",
@@ -1659,12 +1657,6 @@ dependencies = [
  "winapi",
 ]
 
-[[package]]
-name = "fixedbitset"
-version = "0.4.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
-
 [[package]]
 name = "flate2"
 version = "1.1.2"
@@ -3226,13 +3218,10 @@ version = "0.9.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
 dependencies = [
- "backtrace",
  "cfg-if",
  "libc",
- "petgraph",
  "redox_syscall",
  "smallvec",
- "thread-id",
  "windows-targets 0.52.6",
 ]
 
@@ -3282,16 +3271,6 @@ version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
-[[package]]
-name = "petgraph"
-version = "0.6.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
-dependencies = [
- "fixedbitset",
- "indexmap 2.9.0",
-]
-
 [[package]]
 name = "phf"
 version = "0.11.3"
@@ -4913,16 +4892,6 @@ dependencies = [
  "syn",
 ]
 
-[[package]]
-name = "thread-id"
-version = "4.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cfe8f25bbdd100db7e1d34acf7fd2dc59c4bf8f7483f505eaa7d4f12f76cc0ea"
-dependencies = [
- "libc",
- "winapi",
-]
-
 [[package]]
 name = "thread_local"
 version = "1.1.9"
diff --git a/Cargo.toml b/Cargo.toml
index 9cb5ff84..fb00d6d3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -515,14 +515,6 @@ version = "1.0"
 [workspace.dependencies.proc-macro2]
 version = "1.0"
 
-[workspace.dependencies.parking_lot]
-version = "0.12.4"
-features = ["hardware-lock-elision", "deadlock_detection"] # TODO: Check if deadlock_detection has a perf impact, if it does only enable with debug_assertions
-
-# Use this when extending with_lock::WithLock to parking_lot
-[workspace.dependencies.lock_api]
-version = "0.4.13"
-
 [workspace.dependencies.bytesize]
 version = "2.0"
diff --git a/conduwuit-example.toml b/conduwuit-example.toml
index 2fab9cdf..bdc2f570 100644
--- a/conduwuit-example.toml
+++ b/conduwuit-example.toml
@@ -325,15 +325,6 @@
 #
 #well_known_timeout = 10
 
-# Federation client connection timeout (seconds). You should not set this
-# to high values, as dead homeservers can significantly slow down
-# federation, specifically key retrieval, which will take roughly the
-# amount of time you configure here given that a homeserver doesn't
-# respond. This will cause most clients to time out /keys/query, causing
-# E2EE and device verification to fail.
-#
-#federation_conn_timeout = 10
-
 # Federation client request timeout (seconds). You most definitely want
 # this to be high to account for extremely large room joins, slow
 # homeservers, your own resources etc.
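The removed `federation_conn_timeout` fed directly into the federation HTTP clients (see `src/service/client/mod.rs` later in this diff). A minimal sketch of what the option used to control, assuming a plain reqwest builder; the second values shown are the former defaults, not project constants:

```rust
use std::time::Duration;

// Sketch only: with the option gone, connection establishment is bounded by
// the overall request timeout rather than a separate connect timeout.
fn build_federation_client() -> reqwest::Result<reqwest::Client> {
    reqwest::Client::builder()
        // previously: .connect_timeout(Duration::from_secs(10)) // federation_conn_timeout
        .read_timeout(Duration::from_secs(25)) // federation_timeout default
        .build()
}
```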
diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs
index f77dadab..545dcbca 100644
--- a/src/admin/federation/commands.rs
+++ b/src/admin/federation/commands.rs
@@ -26,7 +26,8 @@ pub(super) async fn incoming_federation(&self) -> Result {
 		.rooms
 		.event_handler
 		.federation_handletime
-		.read();
+		.read()
+		.expect("locked");
 
 	let mut msg = format!("Handling {} incoming pdus:\n", map.len());
 	for (r, (e, i)) in map.iter() {
diff --git a/src/admin/mod.rs b/src/admin/mod.rs
index 1d46590b..732b8ce0 100644
--- a/src/admin/mod.rs
+++ b/src/admin/mod.rs
@@ -37,7 +37,11 @@ pub use crate::admin::AdminCommand;
 
 /// Install the admin command processor
 pub async fn init(admin_service: &service::admin::Service) {
-	_ = admin_service.complete.write().insert(processor::complete);
+	_ = admin_service
+		.complete
+		.write()
+		.expect("locked for writing")
+		.insert(processor::complete);
 	_ = admin_service
 		.handle
 		.write()
@@ -48,5 +52,9 @@ pub async fn init(admin_service: &service::admin::Service) {
 /// Uninstall the admin command handler
 pub async fn fini(admin_service: &service::admin::Service) {
 	_ = admin_service.handle.write().await.take();
-	_ = admin_service.complete.write().take();
+	_ = admin_service
+		.complete
+		.write()
+		.expect("locked for writing")
+		.take();
 }
diff --git a/src/admin/processor.rs b/src/admin/processor.rs
index 2c91efe1..e80000c1 100644
--- a/src/admin/processor.rs
+++ b/src/admin/processor.rs
@@ -1,8 +1,14 @@
-use std::{fmt::Write, mem::take, panic::AssertUnwindSafe, sync::Arc, time::SystemTime};
+use std::{
+	fmt::Write,
+	mem::take,
+	panic::AssertUnwindSafe,
+	sync::{Arc, Mutex},
+	time::SystemTime,
+};
 
 use clap::{CommandFactory, Parser};
 use conduwuit::{
-	Error, Result, SyncMutex, debug, error,
+	Error, Result, debug, error,
 	log::{
 		capture,
 		capture::Capture,
@@ -117,7 +123,7 @@ async fn process(
 	let mut output = String::new();
 
 	// Prepend the logs only if any were captured
-	let logs = logs.lock();
+	let logs = logs.lock().expect("locked");
 	if logs.lines().count() > 2 {
 		writeln!(&mut output, "{logs}").expect("failed to format logs to command output");
 	}
@@ -126,7 +132,7 @@ async fn process(
 	(result, output)
 }
 
-fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<SyncMutex<String>>) {
+fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<Mutex<String>>) {
 	let env_config = &context.services.server.config.admin_log_capture;
 	let env_filter = EnvFilter::try_new(env_config).unwrap_or_else(|e| {
 		warn!("admin_log_capture filter invalid: {e:?}");
@@ -146,7 +152,7 @@ fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<SyncMutex<String>>) {
 		data.level() <= log_level && data.our_modules() && data.scope.contains(&"admin")
 	};
 
-	let logs = Arc::new(SyncMutex::new(
+	let logs = Arc::new(Mutex::new(
 		collect_stream(|s| markdown_table_head(s)).expect("markdown table header"),
 	));
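The many `.expect("locked")` calls introduced across this diff lean on std's lock poisoning semantics. A self-contained illustration of why treating a poisoned lock as fatal is a reasonable replacement for parking_lot, which never poisons:

```rust
use std::sync::Mutex;

// A std lock() only returns Err if another thread panicked while holding the
// guard ("poisoning"); expect() therefore only fires on an already-wedged
// process, restoring behavior equivalent to parking_lot's non-poisoning locks.
fn main() {
    let logs = Mutex::new(String::new());
    logs.lock().expect("locked").push_str("| level | message |\n");
    assert!(logs.lock().expect("locked").lines().count() >= 1);
}
```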
diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml
index 462b8e54..0c33c590 100644
--- a/src/core/Cargo.toml
+++ b/src/core/Cargo.toml
@@ -110,8 +110,6 @@
 tracing-core.workspace = true
 tracing-subscriber.workspace = true
 tracing.workspace = true
 url.workspace = true
-parking_lot.workspace = true
-lock_api.workspace = true
 
 [target.'cfg(unix)'.dependencies]
 nix.workspace = true
diff --git a/src/core/alloc/je.rs b/src/core/alloc/je.rs
index 77deebc5..e138233e 100644
--- a/src/core/alloc/je.rs
+++ b/src/core/alloc/je.rs
@@ -4,6 +4,7 @@ use std::{
 	cell::OnceCell,
 	ffi::{CStr, c_char, c_void},
 	fmt::Debug,
+	sync::RwLock,
 };
 
 use arrayvec::ArrayVec;
@@ -12,7 +13,7 @@
 use tikv_jemalloc_sys as ffi;
 use tikv_jemallocator as jemalloc;
 
 use crate::{
-	Result, SyncRwLock, err, is_equal_to, is_nonzero,
+	Result, err, is_equal_to, is_nonzero,
 	utils::{math, math::Tried},
 };
 
@@ -39,7 +40,7 @@ const MALLOC_CONF_PROF: &str = "";
 #[global_allocator]
 static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc;
 
-static CONTROL: SyncRwLock<()> = SyncRwLock::new(());
+static CONTROL: RwLock<()> = RwLock::new(());
 
 type Name = ArrayVec<u8, NAME_MAX>;
 type Key = ArrayVec<usize, KEY_SEGS>;
@@ -331,7 +332,7 @@ fn set<T>(key: &Key, val: T) -> Result<T>
 where
 	T: Copy + Debug,
 {
-	let _lock = CONTROL.write();
+	let _lock = CONTROL.write()?;
 	let res = xchg(key, val)?;
 	inc_epoch()?;
diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs
index 909462db..d93acd9b 100644
--- a/src/core/config/mod.rs
+++ b/src/core/config/mod.rs
@@ -412,17 +412,6 @@ pub struct Config {
 	#[serde(default = "default_well_known_timeout")]
 	pub well_known_timeout: u64,
 
-	/// Federation client connection timeout (seconds). You should not set this
-	/// to high values, as dead homeservers can significantly slow down
-	/// federation, specifically key retrieval, which will take roughly the
-	/// amount of time you configure here given that a homeserver doesn't
-	/// respond. This will cause most clients to time out /keys/query, causing
-	/// E2EE and device verification to fail.
-	///
-	/// default: 10
-	#[serde(default = "default_federation_conn_timeout")]
-	pub federation_conn_timeout: u64,
-
 	/// Federation client request timeout (seconds). You most definitely want
 	/// this to be high to account for extremely large room joins, slow
 	/// homeservers, your own resources etc.
@@ -2204,8 +2193,6 @@ fn default_well_known_conn_timeout() -> u64 { 6 }
 
 fn default_well_known_timeout() -> u64 { 10 }
 
-fn default_federation_conn_timeout() -> u64 { 10 }
-
 fn default_federation_timeout() -> u64 { 25 }
 
 fn default_federation_idle_timeout() -> u64 { 25 }
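Note that `CONTROL.write()?` in je.rs only compiles if the crate's `Error` type converts from a poisoned-lock error. That conversion is not shown in this diff; a hypothetical minimal version of what it would look like:

```rust
use std::sync::{PoisonError, RwLock};

// Assumed shape, not the project's actual Error type: any From<PoisonError<_>>
// impl lets `?` turn lock poisoning into an ordinary error value.
#[derive(Debug)]
struct Error(String);

impl<T> From<PoisonError<T>> for Error {
    fn from(e: PoisonError<T>) -> Self { Self(e.to_string()) }
}

fn with_control(lock: &RwLock<()>) -> Result<(), Error> {
    let _guard = lock.write()?; // poisoning surfaces as a regular Error
    Ok(())
}
```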
diff --git a/src/core/info/rustc.rs b/src/core/info/rustc.rs
index 60156301..048c0cd5 100644
--- a/src/core/info/rustc.rs
+++ b/src/core/info/rustc.rs
@@ -3,15 +3,18 @@
 //! several crates, lower-level information is supplied from each crate during
 //! static initialization.
 
-use std::{collections::BTreeMap, sync::OnceLock};
+use std::{
+	collections::BTreeMap,
+	sync::{Mutex, OnceLock},
+};
 
-use crate::{SyncMutex, utils::exchange};
+use crate::utils::exchange;
 
 /// Raw capture of rustc flags used to build each crate in the project. Informed
 /// by rustc_flags_capture macro (one in each crate's mod.rs). This is
 /// done during static initialization which is why it's mutex-protected and pub.
 /// Should not be written to by anything other than our macro.
-pub static FLAGS: SyncMutex<BTreeMap<&'static str, &'static [&'static str]>> = SyncMutex::new(BTreeMap::new());
+pub static FLAGS: Mutex<BTreeMap<&'static str, &'static [&'static str]>> = Mutex::new(BTreeMap::new());
 
 /// Processed list of enabled features across all project crates. This is
 /// generated from the data in FLAGS.
@@ -24,6 +27,7 @@ fn init_features() -> Vec<&'static str> {
 	let mut features = Vec::new();
 	FLAGS
 		.lock()
+		.expect("locked")
 		.iter()
 		.for_each(|(_, flags)| append_features(&mut features, flags));
diff --git a/src/core/log/capture/layer.rs b/src/core/log/capture/layer.rs
index e3fe66df..381a652f 100644
--- a/src/core/log/capture/layer.rs
+++ b/src/core/log/capture/layer.rs
@@ -40,6 +40,7 @@ where
 		self.state
 			.active
 			.read()
+			.expect("shared lock")
 			.iter()
 			.filter(|capture| filter(self, capture, event, &ctx))
 			.for_each(|capture| handle(self, capture, event, &ctx));
@@ -54,7 +55,7 @@ where
 	let mut visitor = Visitor { values: Values::new() };
 	event.record(&mut visitor);
 
-	let mut closure = capture.closure.lock();
+	let mut closure = capture.closure.lock().expect("exclusive lock");
 	closure(Data {
 		layer,
 		event,
diff --git a/src/core/log/capture/mod.rs b/src/core/log/capture/mod.rs
index b7e5d2b5..20f70091 100644
--- a/src/core/log/capture/mod.rs
+++ b/src/core/log/capture/mod.rs
@@ -4,7 +4,7 @@ pub mod layer;
 pub mod state;
 pub mod util;
 
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 pub use data::Data;
 use guard::Guard;
@@ -12,8 +12,6 @@ pub use layer::{Layer, Value};
 pub use state::State;
 pub use util::*;
 
-use crate::SyncMutex;
-
 pub type Filter = dyn Fn(Data<'_>) -> bool + Send + Sync + 'static;
 pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static;
 
@@ -21,7 +19,7 @@
 pub struct Capture {
 	state: Arc<State>,
 	filter: Option<Box<Filter>>,
-	closure: SyncMutex<Box<Closure>>,
+	closure: Mutex<Box<Closure>>,
 }
 
 impl Capture {
@@ -36,7 +34,7 @@ impl Capture {
 		Arc::new(Self {
 			state: state.clone(),
 			filter: filter.map(|p| -> Box<Filter> { Box::new(p) }),
-			closure: SyncMutex::new(Box::new(closure)),
+			closure: Mutex::new(Box::new(closure)),
 		})
 	}
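The `Mutex<Box<Closure>>` field above is a common shape: a boxed `FnMut` behind a lock so several layers can drive one stateful callback. A self-contained sketch of that pattern (names here are illustrative, not this module's API):

```rust
use std::sync::{Arc, Mutex};

type Closure = dyn FnMut(&str) + Send + Sync + 'static;

fn main() {
    let out = Arc::new(Mutex::new(String::new()));
    let sink = Arc::clone(&out);
    // The boxed closure owns its output handle; the Mutex serializes callers.
    let closure: Mutex<Box<Closure>> =
        Mutex::new(Box::new(move |line| sink.lock().expect("locked").push_str(line)));

    let mut closure = closure.lock().expect("exclusive lock");
    closure("captured line\n");
}
```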
diff --git a/src/core/log/capture/state.rs b/src/core/log/capture/state.rs
index 92a1608f..dad6c8d8 100644
--- a/src/core/log/capture/state.rs
+++ b/src/core/log/capture/state.rs
@@ -1,11 +1,10 @@
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
 use super::Capture;
-use crate::SyncRwLock;
 
 /// Capture layer state.
 pub struct State {
-	pub(super) active: SyncRwLock<Vec<Arc<Capture>>>,
+	pub(super) active: RwLock<Vec<Arc<Capture>>>,
 }
 
 impl Default for State {
@@ -14,14 +13,17 @@ impl Default for State {
 
 impl State {
 	#[must_use]
-	pub fn new() -> Self { Self { active: SyncRwLock::new(Vec::new()) } }
+	pub fn new() -> Self { Self { active: RwLock::new(Vec::new()) } }
 
 	pub(super) fn add(&self, capture: &Arc<Capture>) {
-		self.active.write().push(capture.clone());
+		self.active
+			.write()
+			.expect("locked for writing")
+			.push(capture.clone());
 	}
 
 	pub(super) fn del(&self, capture: &Arc<Capture>) {
-		let mut vec = self.active.write();
+		let mut vec = self.active.write().expect("locked for writing");
 		if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) {
 			vec.swap_remove(pos);
 		}
diff --git a/src/core/log/capture/util.rs b/src/core/log/capture/util.rs
index 21a416a9..65524be5 100644
--- a/src/core/log/capture/util.rs
+++ b/src/core/log/capture/util.rs
@@ -1,31 +1,31 @@
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use super::{
 	super::{Level, fmt},
 	Closure, Data,
 };
-use crate::{Result, SyncMutex};
+use crate::Result;
 
-pub fn fmt_html<S>(out: Arc<SyncMutex<S>>) -> Box<Closure>
+pub fn fmt_html<S>(out: Arc<Mutex<S>>) -> Box<Closure>
 where
 	S: std::fmt::Write + Send + 'static,
 {
 	fmt(fmt::html, out)
 }
 
-pub fn fmt_markdown<S>(out: Arc<SyncMutex<S>>) -> Box<Closure>
+pub fn fmt_markdown<S>(out: Arc<Mutex<S>>) -> Box<Closure>
 where
 	S: std::fmt::Write + Send + 'static,
 {
 	fmt(fmt::markdown, out)
 }
 
-pub fn fmt<F, S>(fun: F, out: Arc<SyncMutex<S>>) -> Box<Closure>
+pub fn fmt<F, S>(fun: F, out: Arc<Mutex<S>>) -> Box<Closure>
 where
 	F: Fn(&mut S, &Level, &str, &str) -> Result<()> + Send + Sync + Copy + 'static,
 	S: std::fmt::Write + Send + 'static,
 {
-	Box::new(move |data| call(fun, &mut *out.lock(), &data))
+	Box::new(move |data| call(fun, &mut *out.lock().expect("locked"), &data))
 }
 
 fn call<F, S>(fun: F, out: &mut S, data: &Data<'_>)
diff --git a/src/core/log/reload.rs b/src/core/log/reload.rs
index 356ee9f2..f72fde47 100644
--- a/src/core/log/reload.rs
+++ b/src/core/log/reload.rs
@@ -1,8 +1,11 @@
-use std::{collections::HashMap, sync::Arc};
+use std::{
+	collections::HashMap,
+	sync::{Arc, Mutex},
+};
 
 use tracing_subscriber::{EnvFilter, reload};
 
-use crate::{Result, SyncMutex, error};
+use crate::{Result, error};
 
 /// We need to store a reload::Handle value, but can't name its type explicitly
 /// because the S type parameter depends on the subscriber's previous layers. In
@@ -32,7 +35,7 @@ impl<L: Clone, S> ReloadHandle<L> for reload::Handle<L, S> {
 
 #[derive(Clone)]
 pub struct LogLevelReloadHandles {
-	handles: Arc<SyncMutex<HandleMap>>,
+	handles: Arc<Mutex<HandleMap>>,
 }
 
 type HandleMap = HashMap<String, Handle>;
@@ -40,12 +43,16 @@ type Handle = Box<dyn ReloadHandle<EnvFilter> + Send + Sync>;
 
 impl LogLevelReloadHandles {
 	pub fn add(&self, name: &str, handle: Handle) {
-		self.handles.lock().insert(name.into(), handle);
+		self.handles
+			.lock()
+			.expect("locked")
+			.insert(name.into(), handle);
 	}
 
 	pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> {
 		self.handles
 			.lock()
+			.expect("locked")
 			.iter()
 			.filter(|(name, _)| names.is_some_and(|names| names.contains(&name.as_str())))
 			.for_each(|(_, handle)| {
@@ -59,6 +66,7 @@ impl LogLevelReloadHandles {
 	pub fn current(&self, name: &str) -> Option<EnvFilter> {
 		self.handles
 			.lock()
+			.expect("locked")
 			.get(name)
 			.map(|handle| handle.current())?
 	}
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 363fece8..d99139be 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -28,7 +28,6 @@ pub use info::{
 pub use matrix::{
 	Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res,
 };
-pub use parking_lot::{Mutex as SyncMutex, RwLock as SyncRwLock};
 pub use server::Server;
 pub use utils::{ctor, dtor, implement, result, result::Result};
diff --git a/src/core/utils/mutex_map.rs b/src/core/utils/mutex_map.rs
index ddb361a4..01504ce6 100644
--- a/src/core/utils/mutex_map.rs
+++ b/src/core/utils/mutex_map.rs
@@ -1,8 +1,12 @@
-use std::{fmt::Debug, hash::Hash, sync::Arc};
+use std::{
+	fmt::Debug,
+	hash::Hash,
+	sync::{Arc, TryLockError::WouldBlock},
+};
 
 use tokio::sync::OwnedMutexGuard as Omg;
 
-use crate::{Result, SyncMutex, err};
+use crate::{Result, err};
 
 /// Map of Mutexes
 pub struct MutexMap<Key, Val> {
@@ -15,7 +19,7 @@ pub struct Guard<Key, Val> {
 }
 
 type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
-type MapMutex<Key, Val> = SyncMutex<HashMap<Key, Val>>;
+type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
 type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
 type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
 
@@ -41,6 +45,7 @@ where
 	let val = self
 		.map
 		.lock()
+		.expect("locked")
 		.entry(k.try_into().expect("failed to construct key"))
 		.or_default()
 		.clone();
@@ -61,6 +66,7 @@ where
 	let val = self
 		.map
 		.lock()
+		.expect("locked")
 		.entry(k.try_into().expect("failed to construct key"))
 		.or_default()
 		.clone();
@@ -81,7 +87,10 @@ where
 	let val = self
 		.map
 		.try_lock()
-		.ok_or_else(|| err!("would block"))?
+		.map_err(|e| match e {
+			| WouldBlock => err!("would block"),
+			| _ => panic!("{e:?}"),
+		})?
 		.entry(k.try_into().expect("failed to construct key"))
 		.or_default()
 		.clone();
@@ -93,13 +102,13 @@ where
 	}
 
 	#[must_use]
-	pub fn contains(&self, k: &Key) -> bool { self.map.lock().contains_key(k) }
+	pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) }
 
 	#[must_use]
-	pub fn is_empty(&self) -> bool { self.map.lock().is_empty() }
+	pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() }
 
 	#[must_use]
-	pub fn len(&self) -> usize { self.map.lock().len() }
+	pub fn len(&self) -> usize { self.map.lock().expect("locked").len() }
 }
 
 impl<Key, Val> Default for MutexMap<Key, Val>
@@ -114,7 +123,7 @@ impl<Key, Val> Drop for Guard<Key, Val> {
 	#[tracing::instrument(name = "unlock", level = "trace", skip_all)]
 	fn drop(&mut self) {
 		if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
-			self.map.lock().retain(|_, val| {
+			self.map.lock().expect("locked").retain(|_, val| {
 				!Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2
 			});
 		}
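The `try_lock` translation above distinguishes std's two failure modes. A self-contained sketch of the same mapping: contention becomes an ordinary error, while poisoning (the only other variant) stays a panic:

```rust
use std::sync::{Mutex, TryLockError::WouldBlock};

fn try_add(m: &Mutex<u64>, n: u64) -> Result<(), String> {
    let mut guard = m.try_lock().map_err(|e| match e {
        WouldBlock => "would block".to_owned(),
        e => panic!("{e:?}"), // poisoned: another thread panicked while holding it
    })?;
    *guard += n;
    Ok(())
}

fn main() {
    let m = Mutex::new(1);
    try_add(&m, 2).expect("uncontended");
    assert_eq!(*m.lock().expect("locked"), 3);
}
```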
diff --git a/src/core/utils/with_lock.rs b/src/core/utils/with_lock.rs
index 91e8e8d1..76f014d1 100644
--- a/src/core/utils/with_lock.rs
+++ b/src/core/utils/with_lock.rs
@@ -1,212 +1,65 @@
 //! Traits for explicitly scoping the lifetime of locks.
 
-use std::{
-	future::Future,
-	sync::{Arc, Mutex},
-};
+use std::sync::{Arc, Mutex};
 
-pub trait WithLock<T> {
-	/// Acquires a lock and executes the given closure with the locked data,
-	/// returning the result.
-	fn with_lock<F, R>(&self, f: F) -> R
+pub trait WithLock<T> {
+	/// Acquires a lock and executes the given closure with the locked data.
+	fn with_lock<F>(&self, f: F)
 	where
-		F: FnMut(&mut T) -> R;
+		F: FnMut(&mut T);
 }
 
 impl<T> WithLock<T> for Mutex<T> {
-	fn with_lock<F, R>(&self, mut f: F) -> R
+	fn with_lock<F>(&self, mut f: F)
 	where
-		F: FnMut(&mut T) -> R,
+		F: FnMut(&mut T),
 	{
 		// The locking and unlocking logic is hidden inside this function.
 		let mut data_guard = self.lock().unwrap();
-		f(&mut data_guard)
+		f(&mut data_guard);
 		// Lock is released here when `data_guard` goes out of scope.
 	}
 }
 
 impl<T> WithLock<T> for Arc<Mutex<T>> {
-	fn with_lock<F, R>(&self, mut f: F) -> R
+	fn with_lock<F>(&self, mut f: F)
 	where
-		F: FnMut(&mut T) -> R,
+		F: FnMut(&mut T),
 	{
 		// The locking and unlocking logic is hidden inside this function.
 		let mut data_guard = self.lock().unwrap();
-		f(&mut data_guard)
+		f(&mut data_guard);
 		// Lock is released here when `data_guard` goes out of scope.
 	}
 }
 
-impl<R: lock_api::RawMutex, T> WithLock<T> for lock_api::Mutex<R, T> {
-	fn with_lock<F, Ret>(&self, mut f: F) -> Ret
-	where
-		F: FnMut(&mut T) -> Ret,
-	{
-		// The locking and unlocking logic is hidden inside this function.
-		let mut data_guard = self.lock();
-		f(&mut data_guard)
-		// Lock is released here when `data_guard` goes out of scope.
-	}
-}
-
-impl<R: lock_api::RawMutex, T> WithLock<T> for Arc<lock_api::Mutex<R, T>> {
-	fn with_lock<F, Ret>(&self, mut f: F) -> Ret
-	where
-		F: FnMut(&mut T) -> Ret,
-	{
-		// The locking and unlocking logic is hidden inside this function.
-		let mut data_guard = self.lock();
-		f(&mut data_guard)
-		// Lock is released here when `data_guard` goes out of scope.
-	}
-}
-
 pub trait WithLockAsync<T> {
-	/// Acquires a lock and executes the given closure with the locked data,
-	/// returning the result.
-	fn with_lock<F, R>(&self, f: F) -> impl Future<Output = R>
+	/// Acquires a lock and executes the given closure with the locked data.
+	fn with_lock<F>(&self, f: F) -> impl Future
 	where
-		F: FnMut(&mut T) -> R;
-
-	/// Acquires a lock and executes the given async closure with the locked
-	/// data.
-	fn with_lock_async<F, R>(&self, f: F) -> impl std::future::Future<Output = R>
-	where
-		F: AsyncFnMut(&mut T) -> R;
+		F: FnMut(&mut T);
 }
 
 impl<T> WithLockAsync<T> for futures::lock::Mutex<T> {
-	async fn with_lock<F, R>(&self, mut f: F) -> R
+	async fn with_lock<F>(&self, mut f: F)
 	where
-		F: FnMut(&mut T) -> R,
+		F: FnMut(&mut T),
 	{
 		// The locking and unlocking logic is hidden inside this function.
 		let mut data_guard = self.lock().await;
-		f(&mut data_guard)
-		// Lock is released here when `data_guard` goes out of scope.
-	}
-
-	async fn with_lock_async<F, R>(&self, mut f: F) -> R
-	where
-		F: AsyncFnMut(&mut T) -> R,
-	{
-		// The locking and unlocking logic is hidden inside this function.
-		let mut data_guard = self.lock().await;
-		f(&mut data_guard).await
+		f(&mut data_guard);
 		// Lock is released here when `data_guard` goes out of scope.
 	}
 }
 
 impl<T> WithLockAsync<T> for Arc<futures::lock::Mutex<T>> {
-	async fn with_lock<F, R>(&self, mut f: F) -> R
+	async fn with_lock<F>(&self, mut f: F)
 	where
-		F: FnMut(&mut T) -> R,
+		F: FnMut(&mut T),
 	{
 		// The locking and unlocking logic is hidden inside this function.
 		let mut data_guard = self.lock().await;
-		f(&mut data_guard)
-		// Lock is released here when `data_guard` goes out of scope.
-	}
-
-	async fn with_lock_async<F, R>(&self, mut f: F) -> R
-	where
-		F: AsyncFnMut(&mut T) -> R,
-	{
-		// The locking and unlocking logic is hidden inside this function.
-		let mut data_guard = self.lock().await;
-		f(&mut data_guard).await
+		f(&mut data_guard);
 		// Lock is released here when `data_guard` goes out of scope.
 	}
 }
-
-#[cfg(test)]
-mod tests {
-	use super::*;
-
-	#[test]
-	fn test_with_lock_return_value() {
-		let mutex = Mutex::new(5);
-		let result = mutex.with_lock(|v| {
-			*v += 1;
-			*v * 2
-		});
-		assert_eq!(result, 12);
-		let value = mutex.lock().unwrap();
-		assert_eq!(*value, 6);
-	}
-
-	#[test]
-	fn test_with_lock_unit_return() {
-		let mutex = Mutex::new(10);
-		mutex.with_lock(|v| {
-			*v += 2;
-		});
-		let value = mutex.lock().unwrap();
-		assert_eq!(*value, 12);
-	}
-
-	#[test]
-	fn test_with_lock_arc_mutex() {
-		let mutex = Arc::new(Mutex::new(1));
-		let result = mutex.with_lock(|v| {
-			*v *= 10;
-			*v
-		});
-		assert_eq!(result, 10);
-		assert_eq!(*mutex.lock().unwrap(), 10);
-	}
-
-	#[tokio::test]
-	async fn test_with_lock_async_return_value() {
-		use futures::lock::Mutex as AsyncMutex;
-		let mutex = AsyncMutex::new(7);
-		let result = mutex
-			.with_lock(|v| {
-				*v += 3;
-				*v * 2
-			})
-			.await;
-		assert_eq!(result, 20);
-		let value = mutex.lock().await;
-		assert_eq!(*value, 10);
-	}
-
-	#[tokio::test]
-	async fn test_with_lock_async_unit_return() {
-		use futures::lock::Mutex as AsyncMutex;
-		let mutex = AsyncMutex::new(100);
-		mutex
-			.with_lock(|v| {
-				*v -= 50;
-			})
-			.await;
-		let value = mutex.lock().await;
-		assert_eq!(*value, 50);
-	}
-
-	#[tokio::test]
-	async fn test_with_lock_async_closure() {
-		use futures::lock::Mutex as AsyncMutex;
-		let mutex = AsyncMutex::new(1);
-		mutex
-			.with_lock_async(async |v| {
-				*v += 9;
-			})
-			.await;
-		let value = mutex.lock().await;
-		assert_eq!(*value, 10);
-	}
-
-	#[tokio::test]
-	async fn test_with_lock_async_arc_mutex() {
-		use futures::lock::Mutex as AsyncMutex;
-		let mutex = Arc::new(AsyncMutex::new(2));
-		mutex
-			.with_lock_async(async |v: &mut i32| {
-				*v *= 5;
-			})
-			.await;
-		let value = mutex.lock().await;
-		assert_eq!(*value, 10);
-	}
-}
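With the `lock_api` impls gone and the return value dropped from the closure, call sites that need a result now write through a captured binding instead. A self-contained sketch of the new shape (the trait is restated locally so the example stands alone; it matches the project's simplified signature above):

```rust
use std::sync::{Arc, Mutex};

trait WithLock<T> {
    fn with_lock<F: FnMut(&mut T)>(&self, f: F);
}

impl<T> WithLock<T> for Arc<Mutex<T>> {
    fn with_lock<F: FnMut(&mut T)>(&self, mut f: F) {
        f(&mut *self.lock().expect("locked"));
    }
}

fn main() {
    let shared = Arc::new(Mutex::new(5_u64));
    let mut doubled = 0;
    shared.with_lock(|v| {
        *v += 1;
        doubled = *v * 2; // results now leave the closure via captures
    });
    assert_eq!(doubled, 12);
    assert_eq!(*shared.lock().expect("locked"), 6);
}
```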
diff --git a/src/database/engine/backup.rs b/src/database/engine/backup.rs
index 4cdb6172..ac72e6d4 100644
--- a/src/database/engine/backup.rs
+++ b/src/database/engine/backup.rs
@@ -71,7 +71,7 @@ pub fn backup_count(&self) -> Result<usize> {
 fn backup_engine(&self) -> Result<BackupEngine> {
 	let path = self.backup_path()?;
 	let options = BackupEngineOptions::new(path).map_err(map_err)?;
-	BackupEngine::open(&options, &self.ctx.env.lock()).map_err(map_err)
+	BackupEngine::open(&options, &*self.ctx.env.lock()?).map_err(map_err)
 }
 
 #[implement(Engine)]
diff --git a/src/database/engine/cf_opts.rs b/src/database/engine/cf_opts.rs
index 58358f02..cbbd1012 100644
--- a/src/database/engine/cf_opts.rs
+++ b/src/database/engine/cf_opts.rs
@@ -232,7 +232,7 @@ fn get_cache(ctx: &Context, desc: &Descriptor) -> Option<Cache> {
 	cache_opts.set_num_shard_bits(shard_bits);
 	cache_opts.set_capacity(size);
 
-	let mut caches = ctx.col_cache.lock();
+	let mut caches = ctx.col_cache.lock().expect("locked");
 	match desc.cache_disp {
 		| CacheDisp::Unique if desc.cache_size == 0 => None,
 		| CacheDisp::Unique => {
diff --git a/src/database/engine/context.rs b/src/database/engine/context.rs
index 3b9238bd..380e37af 100644
--- a/src/database/engine/context.rs
+++ b/src/database/engine/context.rs
@@ -1,6 +1,9 @@
-use std::{collections::BTreeMap, sync::Arc};
+use std::{
+	collections::BTreeMap,
+	sync::{Arc, Mutex},
+};
 
-use conduwuit::{Result, Server, SyncMutex, debug, utils::math::usize_from_f64};
+use conduwuit::{Result, Server, debug, utils::math::usize_from_f64};
 use rocksdb::{Cache, Env, LruCacheOptions};
 
 use crate::{or_else, pool::Pool};
@@ -11,9 +14,9 @@ use crate::{or_else, pool::Pool};
 /// These assets are housed in the shared Context.
 pub(crate) struct Context {
 	pub(crate) pool: Arc<Pool>,
-	pub(crate) col_cache: SyncMutex<BTreeMap<String, Cache>>,
-	pub(crate) row_cache: SyncMutex<Cache>,
-	pub(crate) env: SyncMutex<Env>,
+	pub(crate) col_cache: Mutex<BTreeMap<String, Cache>>,
+	pub(crate) row_cache: Mutex<Cache>,
+	pub(crate) env: Mutex<Env>,
 	pub(crate) server: Arc<Server>,
 }
 
@@ -65,7 +68,7 @@ impl Drop for Context {
 		debug!("Closing frontend pool");
 		self.pool.close();
 
-		let mut env = self.env.lock();
+		let mut env = self.env.lock().expect("locked");
 
 		debug!("Shutting down background threads");
 		env.set_high_priority_background_threads(0);
diff --git a/src/database/engine/memory_usage.rs b/src/database/engine/memory_usage.rs
index 21af35c8..9bb5c535 100644
--- a/src/database/engine/memory_usage.rs
+++ b/src/database/engine/memory_usage.rs
@@ -9,7 +9,7 @@ use crate::or_else;
 #[implement(Engine)]
 pub fn memory_usage(&self) -> Result<String> {
 	let mut res = String::new();
-	let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()]))
+	let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()?]))
 		.or_else(or_else)?;
 	let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0;
 	writeln!(
@@ -19,10 +19,10 @@ pub fn memory_usage(&self) -> Result<String> {
 		mibs(stats.mem_table_total),
 		mibs(stats.mem_table_unflushed),
 		mibs(stats.mem_table_readers_total),
-		mibs(u64::try_from(self.ctx.row_cache.lock().get_usage())?),
+		mibs(u64::try_from(self.ctx.row_cache.lock()?.get_usage())?),
 	)?;
 
-	for (name, cache) in &*self.ctx.col_cache.lock() {
+	for (name, cache) in &*self.ctx.col_cache.lock()? {
 		writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?;
 	}
diff --git a/src/database/engine/open.rs b/src/database/engine/open.rs
index 7b9d93c2..84e59a6a 100644
--- a/src/database/engine/open.rs
+++ b/src/database/engine/open.rs
@@ -23,7 +23,11 @@ pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<Self>>
diff --git a/src/database/pool.rs b/src/database/pool.rs
--- a/src/database/pool.rs
+++ b/src/database/pool.rs
 struct Pool {
 	server: Arc<Server>,
 	queues: Vec<Sender<Cmd>>,
-	workers: SyncMutex<Vec<JoinHandle<()>>>,
+	workers: Mutex<Vec<JoinHandle<()>>>,
 	topology: Vec<usize>,
 	busy: AtomicUsize,
 	queued_max: AtomicUsize,
@@ -115,7 +115,7 @@ impl Drop for Pool {
 #[implement(Pool)]
 #[tracing::instrument(skip_all)]
 pub(crate) fn close(&self) {
-	let workers = take(&mut *self.workers.lock());
+	let workers = take(&mut *self.workers.lock().expect("locked"));
 
 	let senders = self.queues.iter().map(Sender::sender_count).sum::<usize>();
@@ -154,7 +154,7 @@ pub(crate) fn close(&self) {
 
 #[implement(Pool)]
 fn spawn_until(self: &Arc<Self>, recv: &[Receiver<Cmd>], count: usize) -> Result {
-	let mut workers = self.workers.lock();
+	let mut workers = self.workers.lock().expect("locked");
 	while workers.len() < count {
 		self.clone().spawn_one(&mut workers, recv)?;
 	}
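The `take(&mut *lock)` idiom in `Pool::close` swaps the worker list out from under the mutex so joining never happens while the lock is held. A self-contained sketch of that shutdown pattern:

```rust
use std::{mem::take, sync::Mutex, thread::JoinHandle};

fn close(workers: &Mutex<Vec<JoinHandle<()>>>) {
    // Swap the Vec out under the lock, then release the lock before joining,
    // so late spawners can still observe an empty (not blocked) list.
    let handles = take(&mut *workers.lock().expect("locked"));
    for handle in handles {
        handle.join().expect("worker panicked");
    }
}
```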
diff --git a/src/database/watchers.rs b/src/database/watchers.rs
index 0e911c82..efb939d7 100644
--- a/src/database/watchers.rs
+++ b/src/database/watchers.rs
@@ -2,12 +2,12 @@ use std::{
 	collections::{HashMap, hash_map},
 	future::Future,
 	pin::Pin,
+	sync::RwLock,
 };
 
-use conduwuit::SyncRwLock;
 use tokio::sync::watch;
 
-type Watcher = SyncRwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
+type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
 
 #[derive(Default)]
 pub(crate) struct Watchers {
@@ -19,7 +19,7 @@ impl Watchers {
 		&'a self,
 		prefix: &[u8],
 	) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
-		let mut rx = match self.watchers.write().entry(prefix.to_vec()) {
+		let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
 			| hash_map::Entry::Occupied(o) => o.get().1.clone(),
 			| hash_map::Entry::Vacant(v) => {
 				let (tx, rx) = watch::channel(());
@@ -35,7 +35,7 @@ impl Watchers {
 	}
 
 	pub(crate) fn wake(&self, key: &[u8]) {
-		let watchers = self.watchers.read();
+		let watchers = self.watchers.read().unwrap();
 		let mut triggered = Vec::new();
 		for length in 0..=key.len() {
 			if watchers.contains_key(&key[..length]) {
@@ -46,7 +46,7 @@ impl Watchers {
 		drop(watchers);
 
 		if !triggered.is_empty() {
-			let mut watchers = self.watchers.write();
+			let mut watchers = self.watchers.write().unwrap();
 			for prefix in triggered {
 				if let Some(tx) = watchers.remove(prefix) {
 					tx.0.send(()).expect("channel should still be open");
diff --git a/src/macros/rustc.rs b/src/macros/rustc.rs
index cf935fe5..1220c8d4 100644
--- a/src/macros/rustc.rs
+++ b/src/macros/rustc.rs
@@ -15,13 +15,13 @@ pub(super) fn flags_capture(args: TokenStream) -> TokenStream {
 
 		#[conduwuit_core::ctor]
 		fn _set_rustc_flags() {
-			conduwuit_core::info::rustc::FLAGS.lock().insert(#crate_name, &RUSTC_FLAGS);
+			conduwuit_core::info::rustc::FLAGS.lock().expect("locked").insert(#crate_name, &RUSTC_FLAGS);
 		}
 
 		// static strings have to be yanked on module unload
 		#[conduwuit_core::dtor]
 		fn _unset_rustc_flags() {
-			conduwuit_core::info::rustc::FLAGS.lock().remove(#crate_name);
+			conduwuit_core::info::rustc::FLAGS.lock().expect("locked").remove(#crate_name);
 		}
 	};
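The watchers map keeps one tokio watch channel per key prefix; waking a key fires and removes every channel registered for any prefix of it. A self-contained sketch of that scheme (simplified: it removes in one pass rather than collecting under a read lock first):

```rust
use std::{collections::HashMap, sync::RwLock};
use tokio::sync::watch;

type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;

fn wake(watchers: &Watcher, key: &[u8]) {
    let mut map = watchers.write().unwrap();
    for length in 0..=key.len() {
        if let Some((tx, _rx)) = map.remove(&key[..length]) {
            // _rx keeps the channel open until after the send
            tx.send(()).expect("channel should still be open");
        }
    }
}
```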
ending"); - self.worker_join.lock().take(); + self.worker_join.lock().expect("locked").take(); } async fn readline(self: &Arc) -> Result { @@ -128,9 +135,9 @@ impl Console { let (abort, abort_reg) = AbortHandle::new_pair(); let future = Abortable::new(future, abort_reg); - _ = self.input_abort.lock().insert(abort); + _ = self.input_abort.lock().expect("locked").insert(abort); defer! {{ - _ = self.input_abort.lock().take(); + _ = self.input_abort.lock().expect("locked").take(); }} let Ok(result) = future.await else { @@ -151,9 +158,9 @@ impl Console { let (abort, abort_reg) = AbortHandle::new_pair(); let future = Abortable::new(future, abort_reg); - _ = self.command_abort.lock().insert(abort); + _ = self.command_abort.lock().expect("locked").insert(abort); defer! {{ - _ = self.command_abort.lock().take(); + _ = self.command_abort.lock().expect("locked").take(); }} _ = future.await; @@ -177,15 +184,20 @@ impl Console { } fn set_history(&self, readline: &mut Readline) { - self.history.lock().iter().rev().for_each(|entry| { - readline - .add_history_entry(entry.clone()) - .expect("added history entry"); - }); + self.history + .lock() + .expect("locked") + .iter() + .rev() + .for_each(|entry| { + readline + .add_history_entry(entry.clone()) + .expect("added history entry"); + }); } fn add_history(&self, line: String) { - let mut history = self.history.lock(); + let mut history = self.history.lock().expect("locked"); history.push_front(line); history.truncate(HISTORY_LIMIT); } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index c052198c..f496c414 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -5,11 +5,11 @@ mod grant; use std::{ pin::Pin, - sync::{Arc, Weak}, + sync::{Arc, RwLock as StdRwLock, Weak}, }; use async_trait::async_trait; -use conduwuit::{Err, SyncRwLock, utils}; +use conduwuit::{Err, utils}; use conduwuit_core::{ Error, Event, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder, }; @@ -36,7 +36,7 @@ pub struct Service { services: Services, channel: (Sender, Receiver), pub handle: RwLock>, - pub complete: SyncRwLock>, + pub complete: StdRwLock>, #[cfg(feature = "console")] pub console: Arc, } @@ -50,7 +50,7 @@ struct Services { state_cache: Dep, state_accessor: Dep, account_data: Dep, - services: SyncRwLock>>, + services: StdRwLock>>, media: Dep, } @@ -105,7 +105,7 @@ impl crate::Service for Service { }, channel: loole::bounded(COMMAND_QUEUE_LIMIT), handle: RwLock::new(None), - complete: SyncRwLock::new(None), + complete: StdRwLock::new(None), #[cfg(feature = "console")] console: console::Console::new(&args), })) @@ -312,7 +312,10 @@ impl Service { /// Invokes the tab-completer to complete the command. When unavailable, /// None is returned. 
diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs
index c052198c..f496c414 100644
--- a/src/service/admin/mod.rs
+++ b/src/service/admin/mod.rs
@@ -5,11 +5,11 @@ mod grant;
 
 use std::{
 	pin::Pin,
-	sync::{Arc, Weak},
+	sync::{Arc, RwLock as StdRwLock, Weak},
 };
 
 use async_trait::async_trait;
-use conduwuit::{Err, SyncRwLock, utils};
+use conduwuit::{Err, utils};
 use conduwuit_core::{
 	Error, Event, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder,
 };
@@ -36,7 +36,7 @@
 pub struct Service {
 	services: Services,
 	channel: (Sender<CommandInput>, Receiver<CommandInput>),
 	pub handle: RwLock<Option<Handler>>,
-	pub complete: SyncRwLock<Option<Completer>>,
+	pub complete: StdRwLock<Option<Completer>>,
 	#[cfg(feature = "console")]
 	pub console: Arc<console::Console>,
 }
@@ -50,7 +50,7 @@ struct Services {
 	state_cache: Dep<rooms::state_cache::Service>,
 	state_accessor: Dep<rooms::state_accessor::Service>,
 	account_data: Dep<account_data::Service>,
-	services: SyncRwLock<Option<Weak<crate::Services>>>,
+	services: StdRwLock<Option<Weak<crate::Services>>>,
 	media: Dep<media::Service>,
 }
@@ -105,7 +105,7 @@ impl crate::Service for Service {
 			},
 			channel: loole::bounded(COMMAND_QUEUE_LIMIT),
 			handle: RwLock::new(None),
-			complete: SyncRwLock::new(None),
+			complete: StdRwLock::new(None),
 			#[cfg(feature = "console")]
 			console: console::Console::new(&args),
 		}))
@@ -312,7 +312,10 @@ impl Service {
 	/// Invokes the tab-completer to complete the command. When unavailable,
 	/// None is returned.
 	pub fn complete_command(&self, command: &str) -> Option<String> {
-		self.complete.read().map(|complete| complete(command))
+		self.complete
+			.read()
+			.expect("locked for reading")
+			.map(|complete| complete(command))
 	}
 
 	async fn handle_signal(&self, sig: &'static str) {
@@ -335,13 +338,17 @@ impl Service {
 	}
 
 	async fn process_command(&self, command: CommandInput) -> ProcessorResult {
-		let handle_guard = self.handle.read().await;
-		let handle = handle_guard.as_ref().expect("Admin module is not loaded");
+		let handle = &self
+			.handle
+			.read()
+			.await
+			.expect("Admin module is not loaded");
 
 		let services = self
 			.services
 			.services
 			.read()
+			.expect("locked")
 			.as_ref()
 			.and_then(Weak::upgrade)
 			.expect("Services self-reference not initialized.");
@@ -516,7 +523,7 @@ impl Service {
 	/// Sets the self-reference to crate::Services which will provide context to
 	/// the admin commands.
 	pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) {
-		let receiver = &mut *self.services.services.write();
+		let receiver = &mut *self.services.services.write().expect("locked for writing");
 		let weak = services.map(Arc::downgrade);
 		*receiver = weak;
 	}
diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs
index 239340ba..1aeeb492 100644
--- a/src/service/client/mod.rs
+++ b/src/service/client/mod.rs
@@ -66,7 +66,6 @@ impl crate::Service for Service {
 
 			federation: base(config)?
 				.dns_resolver(resolver.resolver.hooked.clone())
-				.connect_timeout(Duration::from_secs(config.federation_conn_timeout))
 				.read_timeout(Duration::from_secs(config.federation_timeout))
 				.pool_max_idle_per_host(config.federation_idle_per_host.into())
 				.pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout))
@@ -75,7 +74,6 @@ impl crate::Service for Service {
 
 			synapse: base(config)?
 				.dns_resolver(resolver.resolver.hooked.clone())
-				.connect_timeout(Duration::from_secs(config.federation_conn_timeout))
 				.read_timeout(Duration::from_secs(305))
 				.pool_max_idle_per_host(0)
 				.redirect(redirect::Policy::limited(3))
@@ -83,7 +81,6 @@ impl crate::Service for Service {
 
 			sender: base(config)?
 				.dns_resolver(resolver.resolver.hooked.clone())
-				.connect_timeout(Duration::from_secs(config.federation_conn_timeout))
 				.read_timeout(Duration::from_secs(config.sender_timeout))
 				.timeout(Duration::from_secs(config.sender_timeout))
 				.pool_max_idle_per_host(1)
diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs
index 07f1de5c..21c09252 100644
--- a/src/service/globals/data.rs
+++ b/src/service/globals/data.rs
@@ -1,11 +1,11 @@
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
-use conduwuit::{Result, SyncRwLock, utils};
+use conduwuit::{Result, utils};
 use database::{Database, Deserialized, Map};
 
 pub struct Data {
 	global: Arc<Map>,
-	counter: SyncRwLock<u64>,
+	counter: RwLock<u64>,
 	pub(super) db: Arc<Database>,
 }
 
@@ -16,21 +16,25 @@ impl Data {
 		let db = &args.db;
 		Self {
 			global: db["global"].clone(),
-			counter: SyncRwLock::new(Self::stored_count(&db["global"]).unwrap_or_default()),
+			counter: RwLock::new(
+				Self::stored_count(&db["global"]).expect("initialized global counter"),
+			),
 			db: args.db.clone(),
 		}
 	}
 
 	pub fn next_count(&self) -> Result<u64> {
 		let _cork = self.db.cork();
-		let mut lock = self.counter.write();
+		let mut lock = self.counter.write().expect("locked");
 		let counter: &mut u64 = &mut lock;
 		debug_assert!(
-			*counter == Self::stored_count(&self.global).unwrap_or_default(),
+			*counter == Self::stored_count(&self.global).expect("database failure"),
 			"counter mismatch"
 		);
 
-		*counter = counter.checked_add(1).unwrap_or(*counter);
+		*counter = counter
+			.checked_add(1)
+			.expect("counter must not overflow u64");
 
 		self.global.insert(COUNTER, counter.to_be_bytes());
 
@@ -39,10 +43,10 @@ impl Data {
 
 	#[inline]
 	pub fn current_count(&self) -> u64 {
-		let lock = self.counter.read();
+		let lock = self.counter.read().expect("locked");
 		let counter: &u64 = &lock;
 		debug_assert!(
-			*counter == Self::stored_count(&self.global).unwrap_or_default(),
+			*counter == Self::stored_count(&self.global).expect("database failure"),
 			"counter mismatch"
 		);
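The counter change above replaces silent saturation (`unwrap_or(*counter)`) with a hard panic on overflow. A self-contained sketch of the new increment logic:

```rust
use std::sync::RwLock;

// Overflow of a u64 event counter would mean database corruption far earlier,
// so treating it as fatal is safer than silently reusing the same count.
fn next_count(counter: &RwLock<u64>) -> u64 {
    let mut lock = counter.write().expect("locked");
    *lock = lock.checked_add(1).expect("counter must not overflow u64");
    *lock
}

fn main() {
    let c = RwLock::new(0);
    assert_eq!(next_count(&c), 1);
    assert_eq!(next_count(&c), 2);
}
```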
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index 12f2ec78..a23a4c21 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -1,9 +1,14 @@ mod data;
 
-use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
+use std::{
+	collections::HashMap,
+	fmt::Write,
+	sync::{Arc, RwLock},
+	time::Instant,
+};
 
 use async_trait::async_trait;
-use conduwuit::{Result, Server, SyncRwLock, error, utils::bytes::pretty};
+use conduwuit::{Result, Server, error, utils::bytes::pretty};
 use data::Data;
 use regex::RegexSet;
 use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, ServerName, UserId};
@@ -14,7 +19,7 @@
 pub struct Service {
 	pub db: Data,
 	server: Arc<Server>,
 
-	pub bad_event_ratelimiter: Arc<SyncRwLock<HashMap<OwnedEventId, RateLimitState>>>,
+	pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
 	pub server_user: OwnedUserId,
 	pub admin_alias: OwnedRoomAliasId,
 	pub turn_secret: String,
@@ -57,7 +62,7 @@ impl crate::Service for Service {
 		Ok(Arc::new(Self {
 			db,
 			server: args.server.clone(),
-			bad_event_ratelimiter: Arc::new(SyncRwLock::new(HashMap::new())),
+			bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
 			admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &args.server.name))
 				.expect("#admins:server_name is valid alias name"),
 			server_user: UserId::parse_with_server_name(
@@ -71,7 +76,7 @@ impl crate::Service for Service {
 	}
 
 	async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
-		let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read().iter().fold(
+		let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read()?.iter().fold(
 			(0_usize, 0_usize),
 			|(mut count, mut bytes), (event_id, _)| {
 				bytes = bytes.saturating_add(event_id.capacity());
@@ -86,7 +91,12 @@ impl crate::Service for Service {
 		Ok(())
 	}
 
-	async fn clear_cache(&self) { self.bad_event_ratelimiter.write().clear(); }
+	async fn clear_cache(&self) {
+		self.bad_event_ratelimiter
+			.write()
+			.expect("locked for writing")
+			.clear();
+	}
 
 	fn name(&self) -> &str { service::make_name(std::module_path!()) }
 }
diff --git a/src/service/manager.rs b/src/service/manager.rs
index 7a2e50d5..3cdf5945 100644
--- a/src/service/manager.rs
+++ b/src/service/manager.rs
@@ -58,6 +58,7 @@ impl Manager {
 		let services: Vec<Arc<dyn Service>> = self
 			.service
 			.read()
+			.expect("locked for reading")
 			.values()
 			.map(|val| val.0.upgrade())
 			.map(|arc| arc.expect("services available for manager startup"))
diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs
index e9e40979..8c3588cc 100644
--- a/src/service/rooms/auth_chain/data.rs
+++ b/src/service/rooms/auth_chain/data.rs
@@ -1,6 +1,9 @@
-use std::{mem::size_of, sync::Arc};
+use std::{
+	mem::size_of,
+	sync::{Arc, Mutex},
+};
 
-use conduwuit::{Err, Result, SyncMutex, err, utils, utils::math::usize_from_f64};
+use conduwuit::{Err, Result, err, utils, utils::math::usize_from_f64};
 use database::Map;
 use lru_cache::LruCache;
 
@@ -8,7 +11,7 @@ use crate::rooms::short::ShortEventId;
 
 pub(super) struct Data {
 	shorteventid_authchain: Arc<Map>,
-	pub(super) auth_chain_cache: SyncMutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
+	pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
 }
 
 impl Data {
@@ -20,7 +23,7 @@ impl Data {
 			.expect("valid cache size");
 		Self {
 			shorteventid_authchain: db["shorteventid_authchain"].clone(),
-			auth_chain_cache: SyncMutex::new(LruCache::new(cache_size)),
+			auth_chain_cache: Mutex::new(LruCache::new(cache_size)),
 		}
 	}
 
@@ -31,7 +34,12 @@ impl Data {
 		debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
 
 		// Check RAM cache
-		if let Some(result) = self.auth_chain_cache.lock().get_mut(key) {
+		if let Some(result) = self
+			.auth_chain_cache
+			.lock()
+			.expect("cache locked")
+			.get_mut(key)
+		{
 			return Ok(Arc::clone(result));
 		}
 
@@ -55,6 +63,7 @@ impl Data {
 		// Cache in RAM
 		self.auth_chain_cache
 			.lock()
+			.expect("cache locked")
 			.insert(vec![key[0]], Arc::clone(&chain));
 
 		Ok(chain)
@@ -75,6 +84,9 @@ impl Data {
 		}
 
 		// Cache in RAM
-		self.auth_chain_cache.lock().insert(key, auth_chain);
+		self.auth_chain_cache
+			.lock()
+			.expect("cache locked")
+			.insert(key, auth_chain);
 	}
 }
diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs
index 79d4d070..0903ea75 100644
--- a/src/service/rooms/auth_chain/mod.rs
+++ b/src/service/rooms/auth_chain/mod.rs
@@ -248,10 +248,10 @@ pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[ShortEventId]) {
 
 #[implement(Service)]
 pub fn get_cache_usage(&self) -> (usize, usize) {
-	let cache = self.db.auth_chain_cache.lock();
+	let cache = self.db.auth_chain_cache.lock().expect("locked");
 
 	(cache.len(), cache.capacity())
 }
 
 #[implement(Service)]
-pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().clear(); }
+pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); }
+ .expect("locked") .entry(id) { | hash_map::Entry::Vacant(e) => { @@ -75,6 +76,7 @@ where .globals .bad_event_ratelimiter .read() + .expect("locked") .get(&*next_id) { // Exponential backoff @@ -185,6 +187,7 @@ where .globals .bad_event_ratelimiter .read() + .expect("locked") .get(&*next_id) { // Exponential backoff diff --git a/src/service/rooms/event_handler/handle_incoming_pdu.rs b/src/service/rooms/event_handler/handle_incoming_pdu.rs index 5299e8d4..86a05e0a 100644 --- a/src/service/rooms/event_handler/handle_incoming_pdu.rs +++ b/src/service/rooms/event_handler/handle_incoming_pdu.rs @@ -160,6 +160,7 @@ pub async fn handle_incoming_pdu<'a>( .globals .bad_event_ratelimiter .write() + .expect("locked") .entry(prev_id.into()) { | hash_map::Entry::Vacant(e) => { @@ -180,11 +181,13 @@ pub async fn handle_incoming_pdu<'a>( let start_time = Instant::now(); self.federation_handletime .write() + .expect("locked") .insert(room_id.into(), (event_id.to_owned(), start_time)); defer! {{ self.federation_handletime .write() + .expect("locked") .remove(room_id); }}; diff --git a/src/service/rooms/event_handler/handle_prev_pdu.rs b/src/service/rooms/event_handler/handle_prev_pdu.rs index cb4978d9..cd46310a 100644 --- a/src/service/rooms/event_handler/handle_prev_pdu.rs +++ b/src/service/rooms/event_handler/handle_prev_pdu.rs @@ -42,6 +42,7 @@ where .globals .bad_event_ratelimiter .read() + .expect("locked") .get(prev_id) { // Exponential backoff @@ -69,11 +70,13 @@ where let start_time = Instant::now(); self.federation_handletime .write() + .expect("locked") .insert(room_id.into(), ((*prev_id).to_owned(), start_time)); defer! {{ self.federation_handletime .write() + .expect("locked") .remove(room_id); }}; diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index a0a1b20b..4e948e95 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -11,10 +11,15 @@ mod resolve_state; mod state_at_incoming; mod upgrade_outlier_pdu; -use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant}; +use std::{ + collections::HashMap, + fmt::Write, + sync::{Arc, RwLock as StdRwLock}, + time::Instant, +}; use async_trait::async_trait; -use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, SyncRwLock, utils::MutexMap}; +use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, utils::MutexMap}; use ruma::{ OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, events::room::create::RoomCreateEventContent, @@ -24,7 +29,7 @@ use crate::{Dep, globals, rooms, sending, server_keys}; pub struct Service { pub mutex_federation: RoomMutexMap, - pub federation_handletime: SyncRwLock, + pub federation_handletime: StdRwLock, services: Services, } @@ -79,7 +84,11 @@ impl crate::Service for Service { let mutex_federation = self.mutex_federation.len(); writeln!(out, "federation_mutex: {mutex_federation}")?; - let federation_handletime = self.federation_handletime.read().len(); + let federation_handletime = self + .federation_handletime + .read() + .expect("locked for reading") + .len(); writeln!(out, "federation_handletime: {federation_handletime}")?; Ok(()) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index e9845fbf..9429be79 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,10 +1,13 @@ mod update; mod via; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; use conduwuit::{ - 
diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs
index e9845fbf..9429be79 100644
--- a/src/service/rooms/state_cache/mod.rs
+++ b/src/service/rooms/state_cache/mod.rs
@@ -1,10 +1,13 @@
 mod update;
 mod via;
 
-use std::{collections::HashMap, sync::Arc};
+use std::{
+	collections::HashMap,
+	sync::{Arc, RwLock},
+};
 
 use conduwuit::{
-	Result, SyncRwLock, implement,
+	Result, implement,
 	result::LogErr,
 	utils::{ReadyExt, stream::TryIgnore},
 	warn,
@@ -51,14 +54,14 @@ struct Data {
 	userroomid_knockedstate: Arc<Map>,
 }
 
-type AppServiceInRoomCache = SyncRwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
+type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
 type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>);
 type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>);
 
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
-			appservice_in_room_cache: SyncRwLock::new(HashMap::new()),
+			appservice_in_room_cache: RwLock::new(HashMap::new()),
 			services: Services {
 				account_data: args.depend::<account_data::Service>("account_data"),
 				config: args.depend::<config::Service>("config"),
@@ -96,6 +99,7 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
 	if let Some(cached) = self
 		.appservice_in_room_cache
 		.read()
+		.expect("locked")
 		.get(room_id)
 		.and_then(|map| map.get(&appservice.registration.id))
 		.copied()
@@ -120,6 +124,7 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
 
 	self.appservice_in_room_cache
 		.write()
+		.expect("locked")
 		.entry(room_id.into())
 		.or_default()
 		.insert(appservice.registration.id.clone(), in_room);
@@ -129,14 +134,19 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
 
 #[implement(Service)]
 pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) {
-	let cache = self.appservice_in_room_cache.read();
+	let cache = self.appservice_in_room_cache.read().expect("locked");
 
 	(cache.len(), cache.capacity())
 }
 
 #[implement(Service)]
 #[tracing::instrument(level = "debug", skip_all)]
-pub fn clear_appservice_in_room_cache(&self) { self.appservice_in_room_cache.write().clear(); }
+pub fn clear_appservice_in_room_cache(&self) {
+	self.appservice_in_room_cache
+		.write()
+		.expect("locked")
+		.clear();
+}
 
 /// Returns an iterator of all servers participating in this room.
diff --git a/src/service/rooms/state_cache/update.rs b/src/service/rooms/state_cache/update.rs
index 32c67947..02c6bec6 100644
--- a/src/service/rooms/state_cache/update.rs
+++ b/src/service/rooms/state_cache/update.rs
@@ -211,7 +211,10 @@ pub async fn update_joined_count(&self, room_id: &RoomId) {
 		self.db.serverroomids.put_raw(serverroom_id, []);
 	}
 
-	self.appservice_in_room_cache.write().remove(room_id);
+	self.appservice_in_room_cache
+		.write()
+		.expect("locked")
+		.remove(room_id);
 }
 
 /// Direct DB function to directly mark a user as joined. It is not
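The `appservice_in_room` flow above is a read-then-write cache: an optimistic lookup under the shared lock, with a fallback computation and write-lock insert on a miss. A self-contained sketch of that pattern:

```rust
use std::{collections::HashMap, sync::RwLock};

fn cached_or_compute(
    cache: &RwLock<HashMap<String, bool>>,
    key: &str,
    compute: impl FnOnce() -> bool,
) -> bool {
    // Fast path: shared lock only, released before compute() runs.
    if let Some(hit) = cache.read().expect("locked").get(key).copied() {
        return hit;
    }
    let value = compute();
    cache.write().expect("locked").insert(key.to_owned(), value);
    value
}
```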
diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs
index f7f7d043..a33fb342 100644
--- a/src/service/rooms/state_compressor/mod.rs
+++ b/src/service/rooms/state_compressor/mod.rs
@@ -2,12 +2,12 @@ use std::{
 	collections::{BTreeSet, HashMap},
 	fmt::{Debug, Write},
 	mem::size_of,
-	sync::Arc,
+	sync::{Arc, Mutex},
 };
 
 use async_trait::async_trait;
 use conduwuit::{
-	Result, SyncMutex,
+	Result,
 	arrayvec::ArrayVec,
 	at, checked, err, expected, implement, utils,
 	utils::{bytes, math::usize_from_f64, stream::IterStream},
@@ -23,7 +23,7 @@ use crate::{
 };
 
 pub struct Service {
-	pub stateinfo_cache: SyncMutex<StateInfoLruCache>,
+	pub stateinfo_cache: Mutex<StateInfoLruCache>,
 	db: Data,
 	services: Services,
 }
@@ -86,7 +86,7 @@ impl crate::Service for Service {
 
 	async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
 		let (cache_len, ents) = {
-			let cache = self.stateinfo_cache.lock();
+			let cache = self.stateinfo_cache.lock().expect("locked");
 			let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
 				HashMap::new(),
 				|mut ents, ssi| {
@@ -110,7 +110,7 @@ impl crate::Service for Service {
 		Ok(())
 	}
 
-	async fn clear_cache(&self) { self.stateinfo_cache.lock().clear(); }
+	async fn clear_cache(&self) { self.stateinfo_cache.lock().expect("locked").clear(); }
 
 	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }
@@ -123,7 +123,7 @@ pub async fn load_shortstatehash_info(
 	&self,
 	shortstatehash: ShortStateHash,
 ) -> Result<ShortStateInfoVec> {
-	if let Some(r) = self.stateinfo_cache.lock().get_mut(&shortstatehash) {
+	if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) {
 		return Ok(r.clone());
 	}
 
@@ -152,7 +152,7 @@ async fn cache_shortstatehash_info(
 	shortstatehash: ShortStateHash,
 	stack: ShortStateInfoVec,
 ) -> Result {
-	self.stateinfo_cache.lock().insert(shortstatehash, stack);
+	self.stateinfo_cache.lock()?.insert(shortstatehash, stack);
 
 	Ok(())
 }
diff --git a/src/service/service.rs b/src/service/service.rs
index 3bc61aeb..574efd8f 100644
--- a/src/service/service.rs
+++ b/src/service/service.rs
@@ -3,13 +3,11 @@ use std::{
 	collections::BTreeMap,
 	fmt::Write,
 	ops::Deref,
-	sync::{Arc, OnceLock, Weak},
+	sync::{Arc, OnceLock, RwLock, Weak},
 };
 
 use async_trait::async_trait;
-use conduwuit::{
-	Err, Result, Server, SyncRwLock, err, error::inspect_log, utils::string::SplitInfallible,
-};
+use conduwuit::{Err, Result, Server, err, error::inspect_log, utils::string::SplitInfallible};
 use database::Database;
 
 /// Abstract interface for a Service
@@ -64,7 +62,7 @@ pub(crate) struct Dep<T> {
 	name: &'static str,
 }
 
-pub(crate) type Map = SyncRwLock<MapType>;
+pub(crate) type Map = RwLock<MapType>;
 pub(crate) type MapType = BTreeMap<MapKey, MapVal>;
 pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>);
 pub(crate) type MapKey = String;
@@ -145,12 +143,15 @@ pub(crate) fn get<T>(map: &Map, name: &str) -> Option<Arc<T>>
 where
 	T: Any + Send + Sync + Sized,
 {
-	map.read().get(name).map(|(_, s)| {
-		s.upgrade().map(|s| {
-			s.downcast::<T>()
-				.expect("Service must be correctly downcast.")
-		})
-	})?
+	map.read()
+		.expect("locked for reading")
+		.get(name)
+		.map(|(_, s)| {
+			s.upgrade().map(|s| {
+				s.downcast::<T>()
+					.expect("Service must be correctly downcast.")
+			})
+		})?
 }
 
 /// Reference a Service by name. Returns Err if the Service does not exist or
@@ -159,18 +160,21 @@ pub(crate) fn try_get<T>(map: &Map, name: &str) -> Result<Arc<T>>
 where
 	T: Any + Send + Sync + Sized,
 {
-	map.read().get(name).map_or_else(
-		|| Err!("Service {name:?} does not exist or has not been built yet."),
-		|(_, s)| {
-			s.upgrade().map_or_else(
-				|| Err!("Service {name:?} no longer exists."),
-				|s| {
-					s.downcast::<T>()
-						.map_err(|_| err!("Service {name:?} must be correctly downcast."))
-				},
-			)
-		},
-	)
+	map.read()
+		.expect("locked for reading")
+		.get(name)
+		.map_or_else(
+			|| Err!("Service {name:?} does not exist or has not been built yet."),
+			|(_, s)| {
+				s.upgrade().map_or_else(
+					|| Err!("Service {name:?} no longer exists."),
+					|s| {
+						s.downcast::<T>()
+							.map_err(|_| err!("Service {name:?} must be correctly downcast."))
+					},
+				)
+			},
+		)
 }
 
 /// Utility for service implementations; see Service::name() in the trait.
diff --git a/src/service/services.rs b/src/service/services.rs
index 642f61c7..daece245 100644
--- a/src/service/services.rs
+++ b/src/service/services.rs
@@ -1,8 +1,10 @@
-use std::{any::Any, collections::BTreeMap, sync::Arc};
-
-use conduwuit::{
-	Result, Server, SyncRwLock, debug, debug_info, info, trace, utils::stream::IterStream,
+use std::{
+	any::Any,
+	collections::BTreeMap,
+	sync::{Arc, RwLock},
 };
+
+use conduwuit::{Result, Server, debug, debug_info, info, trace, utils::stream::IterStream};
 use database::Database;
 use futures::{Stream, StreamExt, TryStreamExt};
 use tokio::sync::Mutex;
@@ -50,7 +52,7 @@ impl Services {
 	#[allow(clippy::cognitive_complexity)]
 	pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
 		let db = Database::open(&server).await?;
-		let service: Arc<Map> = Arc::new(SyncRwLock::new(BTreeMap::new()));
+		let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new()));
 		macro_rules! build {
 			($tyname:ty) => {{
 				let built = <$tyname>::build(Args {
@@ -191,7 +193,7 @@ impl Services {
 	fn interrupt(&self) {
 		debug!("Interrupting services...");
 
-		for (name, (service, ..)) in self.service.read().iter() {
+		for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() {
 			if let Some(service) = service.upgrade() {
 				trace!("Interrupting {name}");
 				service.interrupt();
@@ -203,6 +205,7 @@ impl Services {
 	fn services(&self) -> impl Stream<Item = Arc<dyn Service>> + Send {
 		self.service
 			.read()
+			.expect("locked for reading")
 			.values()
 			.filter_map(|val| val.0.upgrade())
 			.collect::<Vec<_>>()
@@ -230,9 +233,10 @@ impl Services {
 #[allow(clippy::needless_pass_by_value)]
 fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) {
 	let name = s.name();
-	let len = map.read().len();
+	let len = map.read().expect("locked for reading").len();
 
 	trace!("built service #{len}: {name:?}");
 	map.write()
+		.expect("locked for writing")
 		.insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a)));
}
diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs
index 6ac579f4..b095d2c1 100644
--- a/src/service/sync/mod.rs
+++ b/src/service/sync/mod.rs
@@ -2,10 +2,10 @@ mod watch;
 
 use std::{
 	collections::{BTreeMap, BTreeSet},
-	sync::Arc,
+	sync::{Arc, Mutex, Mutex as StdMutex},
 };
 
-use conduwuit::{Result, Server, SyncMutex};
+use conduwuit::{Result, Server};
 use database::Map;
 use ruma::{
 	OwnedDeviceId, OwnedRoomId, OwnedUserId,
@@ -62,11 +62,11 @@ struct SnakeSyncCache {
 	extensions: v5::request::Extensions,
 }
 
-type DbConnections<K, V> = SyncMutex<BTreeMap<K, V>>;
+type DbConnections<K, V> = Mutex<BTreeMap<K, V>>;
 type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String);
-type DbConnectionsVal = Arc<SyncMutex<SlidingSyncCache>>;
+type DbConnectionsVal = Arc<Mutex<SlidingSyncCache>>;
 type SnakeConnectionsKey = (OwnedUserId, OwnedDeviceId, Option<String>);
-type SnakeConnectionsVal = Arc<SyncMutex<SnakeSyncCache>>;
+type SnakeConnectionsVal = Arc<Mutex<SnakeSyncCache>>;
 
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@@ -90,8 +90,8 @@ impl crate::Service for Service {
 				state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
 				typing: args.depend::<rooms::typing::Service>("rooms::typing"),
 			},
-			connections: SyncMutex::new(BTreeMap::new()),
-			snake_connections: SyncMutex::new(BTreeMap::new()),
+			connections: StdMutex::new(BTreeMap::new()),
+			snake_connections: StdMutex::new(BTreeMap::new()),
 		}))
 	}
 
@@ -100,19 +100,22 @@ impl crate::Service for Service {
 
 impl Service {
 	pub fn snake_connection_cached(&self, key: &SnakeConnectionsKey) -> bool {
-		self.snake_connections.lock().contains_key(key)
+		self.snake_connections
+			.lock()
+			.expect("locked")
+			.contains_key(key)
 	}
 
 	pub fn forget_snake_sync_connection(&self, key: &SnakeConnectionsKey) {
-		self.snake_connections.lock().remove(key);
+		self.snake_connections.lock().expect("locked").remove(key);
 	}
 
 	pub fn remembered(&self, key: &DbConnectionsKey) -> bool {
-		self.connections.lock().contains_key(key)
+		self.connections.lock().expect("locked").contains_key(key)
 	}
 
 	pub fn forget_sync_request_connection(&self, key: &DbConnectionsKey) {
-		self.connections.lock().remove(key);
+		self.connections.lock().expect("locked").remove(key);
 	}
 
 	pub fn update_snake_sync_request_with_cache(
@@ -120,13 +123,13 @@ impl Service {
 		snake_key: &SnakeConnectionsKey,
 		request: &mut v5::Request,
 	) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
-		let mut cache = self.snake_connections.lock();
+		let mut cache = self.snake_connections.lock().expect("locked");
 		let cached = Arc::clone(
 			cache
 				.entry(snake_key.clone())
-				.or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
+				.or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))),
 		);
-		let cached = &mut cached.lock();
+		let cached = &mut cached.lock().expect("locked");
 		drop(cache);
 
 		//v5::Request::try_from_http_request(req, path_args);
@@ -229,16 +232,16 @@ impl Service {
 		};
 
 		let key = into_db_key(key.0.clone(), key.1.clone(), conn_id);
-		let mut cache = self.connections.lock();
+		let mut cache = self.connections.lock().expect("locked");
 		let cached = Arc::clone(cache.entry(key).or_insert_with(|| {
-			Arc::new(SyncMutex::new(SlidingSyncCache {
+			Arc::new(Mutex::new(SlidingSyncCache {
 				lists: BTreeMap::new(),
 				subscriptions: BTreeMap::new(),
 				known_rooms: BTreeMap::new(),
 				extensions: ExtensionsConfig::default(),
 			}))
 		}));
-		let cached = &mut cached.lock();
+		let cached = &mut cached.lock().expect("locked");
 		drop(cache);
 
 		for (list_id, list) in &mut request.lists {
@@ -325,16 +328,16 @@ impl Service {
 		key: &DbConnectionsKey,
 		subscriptions: BTreeMap<OwnedRoomId, v4::RoomSubscription>,
 	) {
-		let mut cache = self.connections.lock();
+		let mut cache = self.connections.lock().expect("locked");
 		let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| {
-			Arc::new(SyncMutex::new(SlidingSyncCache {
+			Arc::new(Mutex::new(SlidingSyncCache {
 				lists: BTreeMap::new(),
 				subscriptions: BTreeMap::new(),
 				known_rooms: BTreeMap::new(),
 				extensions: ExtensionsConfig::default(),
 			}))
 		}));
-		let cached = &mut cached.lock();
+		let cached = &mut cached.lock().expect("locked");
 		drop(cache);
 
 		cached.subscriptions = subscriptions;
@@ -347,16 +350,16 @@ impl Service {
 		new_cached_rooms: BTreeSet<OwnedRoomId>,
 		globalsince: u64,
 	) {
-		let mut cache = self.connections.lock();
+		let mut cache = self.connections.lock().expect("locked");
 		let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| {
-			Arc::new(SyncMutex::new(SlidingSyncCache {
+			Arc::new(Mutex::new(SlidingSyncCache {
 				lists: BTreeMap::new(),
 				subscriptions: BTreeMap::new(),
 				known_rooms: BTreeMap::new(),
 				extensions: ExtensionsConfig::default(),
 			}))
 		}));
-		let cached = &mut cached.lock();
+		let cached = &mut cached.lock().expect("locked");
 		drop(cache);
 
 		for (room_id, lastsince) in cached
@@ -383,13 +386,13 @@ impl Service {
 		globalsince: u64,
 	) {
 		assert!(key.2.is_some(), "Some(conn_id) required for this call");
-		let mut cache = self.snake_connections.lock();
+		let mut cache = self.snake_connections.lock().expect("locked");
 		let cached = Arc::clone(
 			cache
 				.entry(key.clone())
-				.or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
+				.or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))),
 		);
-		let cached = &mut cached.lock();
+		let cached = &mut cached.lock().expect("locked");
 		drop(cache);
 
 		for (room_id, lastsince) in cached
@@ -413,13 +416,13 @@ impl Service {
 		key: &SnakeConnectionsKey,
 		subscriptions: BTreeMap<OwnedRoomId, v5::request::RoomSubscription>,
 	) {
-		let mut cache = self.snake_connections.lock();
+		let mut cache = self.snake_connections.lock().expect("locked");
 		let cached = Arc::clone(
 			cache
 				.entry(key.clone())
-				.or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
+				.or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))),
 		);
-		let cached = &mut cached.lock();
+		let cached = &mut cached.lock().expect("locked");
 		drop(cache);
 
 		cached.subscriptions = subscriptions;
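Every method in this service repeats the same two-level locking discipline: clone the per-connection `Arc<Mutex<_>>` out of the outer map, then drop the map lock before touching the entry. A self-contained sketch of that pattern:

```rust
use std::{
    collections::BTreeMap,
    sync::{Arc, Mutex},
};

type Cache<K, V> = Mutex<BTreeMap<K, Arc<Mutex<V>>>>;

fn with_entry<K: Ord + Clone, V: Default>(cache: &Cache<K, V>, key: &K, f: impl FnOnce(&mut V)) {
    let entry = Arc::clone(
        cache
            .lock()
            .expect("locked")
            .entry(key.clone())
            .or_insert_with(|| Arc::new(Mutex::new(V::default()))),
    );
    // The outer map lock is released here; long-running work below only
    // holds this one connection's mutex, so other connections stay unblocked.
    f(&mut *entry.lock().expect("locked"));
}
```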
diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs
index acd3dd86..7735c87f 100644
--- a/src/service/uiaa/mod.rs
+++ b/src/service/uiaa/mod.rs
@@ -1,10 +1,10 @@
 use std::{
 	collections::{BTreeMap, HashSet},
-	sync::Arc,
+	sync::{Arc, RwLock},
 };
 
 use conduwuit::{
-	Err, Error, Result, SyncRwLock, err, error, implement, utils,
+	Err, Error, Result, err, error, implement, utils,
 	utils::{hash, string::EMPTY},
 };
 use database::{Deserialized, Json, Map};
@@ -19,7 +19,7 @@ use ruma::{
 use crate::{Dep, config, globals, users};
 
 pub struct Service {
-	userdevicesessionid_uiaarequest: SyncRwLock<RequestMap>,
+	userdevicesessionid_uiaarequest: RwLock<RequestMap>,
 	db: Data,
 	services: Services,
 }
@@ -42,7 +42,7 @@ pub const SESSION_ID_LENGTH: usize = 32;
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
-			userdevicesessionid_uiaarequest: SyncRwLock::new(RequestMap::new()),
+			userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
 			db: Data {
 				userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
 			},
@@ -268,6 +268,7 @@ fn set_uiaa_request(
 	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
 	self.userdevicesessionid_uiaarequest
 		.write()
+		.expect("locked for writing")
 		.insert(key, request.to_owned());
 }
 
@@ -286,6 +287,7 @@ pub fn get_uiaa_request(
 
 	self.userdevicesessionid_uiaarequest
 		.read()
+		.expect("locked for reading")
 		.get(&key)
 		.cloned()
 }