diff --git a/.envrc b/.envrc
index 952ec2f8..bad73b75 100644
--- a/.envrc
+++ b/.envrc
@@ -2,6 +2,6 @@
 dotenv_if_exists
 
-use flake ".#${DIRENV_DEVSHELL:-default}"
+# use flake ".#${DIRENV_DEVSHELL:-default}"
 
 PATH_add bin
diff --git a/Cargo.lock b/Cargo.lock
index 6f711007..5dce9c59 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -963,10 +963,12 @@ dependencies = [
  "itertools 0.14.0",
  "libc",
  "libloading",
+ "lock_api",
  "log",
  "maplit",
  "nix",
  "num-traits",
+ "parking_lot",
  "rand 0.8.5",
  "regex",
  "reqwest",
@@ -1657,6 +1659,12 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "fixedbitset"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
+
 [[package]]
 name = "flate2"
 version = "1.1.2"
@@ -3218,10 +3226,13 @@
 version = "0.9.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
 dependencies = [
+ "backtrace",
  "cfg-if",
  "libc",
+ "petgraph",
  "redox_syscall",
  "smallvec",
+ "thread-id",
  "windows-targets 0.52.6",
 ]
 
@@ -3271,6 +3282,16 @@
 version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
+[[package]]
+name = "petgraph"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
+dependencies = [
+ "fixedbitset",
+ "indexmap 2.9.0",
+]
+
 [[package]]
 name = "phf"
 version = "0.11.3"
@@ -4892,6 +4913,16 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "thread-id"
+version = "4.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfe8f25bbdd100db7e1d34acf7fd2dc59c4bf8f7483f505eaa7d4f12f76cc0ea"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
 [[package]]
 name = "thread_local"
 version = "1.1.9"
diff --git a/Cargo.toml b/Cargo.toml
index ef917332..ab6a9e8a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -515,6 +515,14 @@ version = "1.0"
 [workspace.dependencies.proc-macro2]
 version = "1.0"
 
+[workspace.dependencies.parking_lot]
+version = "0.12.4"
+features = ["hardware-lock-elision", "deadlock_detection"] # TODO: Check if deadlock_detection has a perf impact; if it does, only enable it with debug_assertions
+
+# Use this when extending with_lock::WithLock to parking_lot
+[workspace.dependencies.lock_api]
+version = "0.4.13"
+
 [workspace.dependencies.bytesize]
 version = "2.0"
diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs
index 545dcbca..f77dadab 100644
--- a/src/admin/federation/commands.rs
+++ b/src/admin/federation/commands.rs
@@ -26,8 +26,7 @@ pub(super) async fn incoming_federation(&self) -> Result {
        .rooms
        .event_handler
        .federation_handletime
-       .read()
-       .expect("locked");
+       .read();
 
    let mut msg = format!("Handling {} incoming pdus:\n", map.len());
    for (r, (e, i)) in map.iter() {
diff --git a/src/admin/mod.rs b/src/admin/mod.rs
index 732b8ce0..1d46590b 100644
--- a/src/admin/mod.rs
+++ b/src/admin/mod.rs
@@ -37,11 +37,7 @@ pub use crate::admin::AdminCommand;
 
 /// Install the admin command processor
 pub async fn init(admin_service: &service::admin::Service) {
-   _ = admin_service
-       .complete
-       .write()
-       .expect("locked for writing")
-       .insert(processor::complete);
+   _ = admin_service.complete.write().insert(processor::complete);
    _ = admin_service
        .handle
        .write()
@@ -52,9 +48,5 @@ pub async fn init(admin_service: &service::admin::Service) {
 /// Uninstall the admin command handler
 pub async fn fini(admin_service: &service::admin::Service) {
    _ = admin_service.handle.write().await.take();
-   _ = admin_service
-       .complete
-       .write()
-       .expect("locked for writing")
-       .take();
+   _ = admin_service.complete.write().take();
 }
diff --git a/src/admin/processor.rs b/src/admin/processor.rs
index e80000c1..2c91efe1 100644
--- a/src/admin/processor.rs
+++ b/src/admin/processor.rs
@@ -1,14 +1,8 @@
-use std::{
-   fmt::Write,
-   mem::take,
-   panic::AssertUnwindSafe,
-   sync::{Arc, Mutex},
-   time::SystemTime,
-};
+use std::{fmt::Write, mem::take, panic::AssertUnwindSafe, sync::Arc, time::SystemTime};
 
 use clap::{CommandFactory, Parser};
 use conduwuit::{
-   Error, Result, debug, error,
+   Error, Result, SyncMutex, debug, error,
    log::{
        capture,
        capture::Capture,
@@ -123,7 +117,7 @@ async fn process(
    let mut output = String::new();
 
    // Prepend the logs only if any were captured
-   let logs = logs.lock().expect("locked");
+   let logs = logs.lock();
    if logs.lines().count() > 2 {
        writeln!(&mut output, "{logs}").expect("failed to format logs to command output");
    }
@@ -132,7 +126,7 @@ async fn process(
    (result, output)
 }
 
-fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<Mutex<String>>) {
+fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<SyncMutex<String>>) {
    let env_config = &context.services.server.config.admin_log_capture;
    let env_filter = EnvFilter::try_new(env_config).unwrap_or_else(|e| {
        warn!("admin_log_capture filter invalid: {e:?}");
@@ -152,7 +146,7 @@ fn capture_create(context: &Context<'_>) -> (Arc<Capture>, Arc<Mutex<String>>) {
        data.level() <= log_level && data.our_modules() && data.scope.contains(&"admin")
    };
 
-   let logs = Arc::new(Mutex::new(
+   let logs = Arc::new(SyncMutex::new(
        collect_stream(|s| markdown_table_head(s)).expect("markdown table header"),
    ));
diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml
index 0c33c590..462b8e54 100644
--- a/src/core/Cargo.toml
+++ b/src/core/Cargo.toml
@@ -110,6 +110,8 @@ tracing-core.workspace = true
 tracing-subscriber.workspace = true
 tracing.workspace = true
 url.workspace = true
+parking_lot.workspace = true
+lock_api.workspace = true
 
 [target.'cfg(unix)'.dependencies]
 nix.workspace = true
diff --git a/src/core/alloc/je.rs b/src/core/alloc/je.rs
index e138233e..77deebc5 100644
--- a/src/core/alloc/je.rs
+++ b/src/core/alloc/je.rs
@@ -4,7 +4,6 @@ use std::{
    cell::OnceCell,
    ffi::{CStr, c_char, c_void},
    fmt::Debug,
-   sync::RwLock,
 };
 
 use arrayvec::ArrayVec;
@@ -13,7 +12,7 @@ use tikv_jemalloc_sys as ffi;
 use tikv_jemallocator as jemalloc;
 
 use crate::{
-   Result, err, is_equal_to, is_nonzero,
+   Result, SyncRwLock, err, is_equal_to, is_nonzero,
    utils::{math, math::Tried},
 };
 
@@ -40,7 +39,7 @@ const MALLOC_CONF_PROF: &str = "";
 #[global_allocator]
 static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc;
 
-static CONTROL: RwLock<()> = RwLock::new(());
+static CONTROL: SyncRwLock<()> = SyncRwLock::new(());
 
 type Name = ArrayVec<u8, NAME_MAX>;
 type Key = ArrayVec<usize, KEY_SEGS>;
@@ -332,7 +331,7 @@ fn set<T>(key: &Key, val: T) -> Result<T>
 where
    T: Copy + Debug,
 {
-   let _lock = CONTROL.write()?;
+   let _lock = CONTROL.write();
 
    let res = xchg(key, val)?;
    inc_epoch()?;
diff --git a/src/core/info/rustc.rs b/src/core/info/rustc.rs
index 048c0cd5..60156301 100644
--- a/src/core/info/rustc.rs
+++ b/src/core/info/rustc.rs
@@ -3,18 +3,15 @@
 //! several crates, lower-level information is supplied from each crate during
 //! static initialization.
-use std::{
-   collections::BTreeMap,
-   sync::{Mutex, OnceLock},
-};
+use std::{collections::BTreeMap, sync::OnceLock};
 
-use crate::utils::exchange;
+use crate::{SyncMutex, utils::exchange};
 
 /// Raw capture of rustc flags used to build each crate in the project. Informed
 /// by rustc_flags_capture macro (one in each crate's mod.rs). This is
 /// done during static initialization which is why it's mutex-protected and pub.
 /// Should not be written to by anything other than our macro.
-pub static FLAGS: Mutex<BTreeMap<&'static str, &'static [&'static str]>> = Mutex::new(BTreeMap::new());
+pub static FLAGS: SyncMutex<BTreeMap<&'static str, &'static [&'static str]>> = SyncMutex::new(BTreeMap::new());
 
 /// Processed list of enabled features across all project crates. This is
 /// generated from the data in FLAGS.
@@ -27,7 +24,6 @@ fn init_features() -> Vec<&'static str> {
    let mut features = Vec::new();
    FLAGS
        .lock()
-       .expect("locked")
        .iter()
        .for_each(|(_, flags)| append_features(&mut features, flags));
diff --git a/src/core/log/capture/layer.rs b/src/core/log/capture/layer.rs
index 381a652f..e3fe66df 100644
--- a/src/core/log/capture/layer.rs
+++ b/src/core/log/capture/layer.rs
@@ -40,7 +40,6 @@ where
        self.state
            .active
            .read()
-           .expect("shared lock")
            .iter()
            .filter(|capture| filter(self, capture, event, &ctx))
            .for_each(|capture| handle(self, capture, event, &ctx));
@@ -55,7 +54,7 @@ where
    let mut visitor = Visitor { values: Values::new() };
    event.record(&mut visitor);
 
-   let mut closure = capture.closure.lock().expect("exclusive lock");
+   let mut closure = capture.closure.lock();
    closure(Data {
        layer,
        event,
diff --git a/src/core/log/capture/mod.rs b/src/core/log/capture/mod.rs
index 20f70091..b7e5d2b5 100644
--- a/src/core/log/capture/mod.rs
+++ b/src/core/log/capture/mod.rs
@@ -4,7 +4,7 @@ pub mod layer;
 pub mod state;
 pub mod util;
 
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 
 pub use data::Data;
 use guard::Guard;
@@ -12,6 +12,8 @@ pub use layer::{Layer, Value};
 pub use state::State;
 pub use util::*;
 
+use crate::SyncMutex;
+
 pub type Filter = dyn Fn(Data<'_>) -> bool + Send + Sync + 'static;
 pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static;
 
@@ -19,7 +21,7 @@ pub type Closure = dyn FnMut(Data<'_>) + Send + Sync + 'static;
 pub struct Capture {
    state: Arc<State>,
    filter: Option<Box<Filter>>,
-   closure: Mutex<Box<Closure>>,
+   closure: SyncMutex<Box<Closure>>,
 }
 
 impl Capture {
@@ -34,7 +36,7 @@ impl Capture {
        Arc::new(Self {
            state: state.clone(),
            filter: filter.map(|p| -> Box<Filter> { Box::new(p) }),
-           closure: Mutex::new(Box::new(closure)),
+           closure: SyncMutex::new(Box::new(closure)),
        })
    }
diff --git a/src/core/log/capture/state.rs b/src/core/log/capture/state.rs
index dad6c8d8..92a1608f 100644
--- a/src/core/log/capture/state.rs
+++ b/src/core/log/capture/state.rs
@@ -1,10 +1,11 @@
-use std::sync::{Arc, RwLock};
+use std::sync::Arc;
 
 use super::Capture;
+use crate::SyncRwLock;
 
 /// Capture layer state.
 pub struct State {
-   pub(super) active: RwLock<Vec<Arc<Capture>>>,
+   pub(super) active: SyncRwLock<Vec<Arc<Capture>>>,
 }
 
 impl Default for State {
@@ -13,17 +14,14 @@ impl State {
    #[must_use]
-   pub fn new() -> Self { Self { active: RwLock::new(Vec::new()) } }
+   pub fn new() -> Self { Self { active: SyncRwLock::new(Vec::new()) } }
 
    pub(super) fn add(&self, capture: &Arc<Capture>) {
-       self.active
-           .write()
-           .expect("locked for writing")
-           .push(capture.clone());
+       self.active.write().push(capture.clone());
    }
 
    pub(super) fn del(&self, capture: &Arc<Capture>) {
-       let mut vec = self.active.write().expect("locked for writing");
+       let mut vec = self.active.write();
        if let Some(pos) = vec.iter().position(|v| Arc::ptr_eq(v, capture)) {
            vec.swap_remove(pos);
        }
diff --git a/src/core/log/capture/util.rs b/src/core/log/capture/util.rs
index 65524be5..21a416a9 100644
--- a/src/core/log/capture/util.rs
+++ b/src/core/log/capture/util.rs
@@ -1,31 +1,31 @@
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 
 use super::{
    super::{Level, fmt},
    Closure, Data,
 };
-use crate::Result;
+use crate::{Result, SyncMutex};
 
-pub fn fmt_html<S>(out: Arc<Mutex<S>>) -> Box<Closure>
+pub fn fmt_html<S>(out: Arc<SyncMutex<S>>) -> Box<Closure>
 where
    S: std::fmt::Write + Send + 'static,
 {
    fmt(fmt::html, out)
 }
 
-pub fn fmt_markdown<S>(out: Arc<Mutex<S>>) -> Box<Closure>
+pub fn fmt_markdown<S>(out: Arc<SyncMutex<S>>) -> Box<Closure>
 where
    S: std::fmt::Write + Send + 'static,
 {
    fmt(fmt::markdown, out)
 }
 
-pub fn fmt<F, S>(fun: F, out: Arc<Mutex<S>>) -> Box<Closure>
+pub fn fmt<F, S>(fun: F, out: Arc<SyncMutex<S>>) -> Box<Closure>
 where
    F: Fn(&mut S, &Level, &str, &str) -> Result<()> + Send + Sync + Copy + 'static,
    S: std::fmt::Write + Send + 'static,
 {
-   Box::new(move |data| call(fun, &mut *out.lock().expect("locked"), &data))
+   Box::new(move |data| call(fun, &mut *out.lock(), &data))
 }
 
 fn call<F, S>(fun: F, out: &mut S, data: &Data<'_>)
diff --git a/src/core/log/reload.rs b/src/core/log/reload.rs
index f72fde47..356ee9f2 100644
--- a/src/core/log/reload.rs
+++ b/src/core/log/reload.rs
@@ -1,11 +1,8 @@
-use std::{
-   collections::HashMap,
-   sync::{Arc, Mutex},
-};
+use std::{collections::HashMap, sync::Arc};
 
 use tracing_subscriber::{EnvFilter, reload};
 
-use crate::{Result, error};
+use crate::{Result, SyncMutex, error};
 
 /// We need to store a reload::Handle value, but can't name its type explicitly
 /// because the S type parameter depends on the subscriber's previous layers. In
@@ -35,7 +32,7 @@ impl<L: Clone, S> ReloadHandle<L> for reload::Handle<L, S> {
 
 #[derive(Clone)]
 pub struct LogLevelReloadHandles {
-   handles: Arc<Mutex<HandleMap>>,
+   handles: Arc<SyncMutex<HandleMap>>,
 }
 
 type HandleMap = HashMap<String, Handle>;
@@ -43,16 +40,12 @@ type Handle = Box<dyn ReloadHandle<EnvFilter> + Send + Sync>;
 
 impl LogLevelReloadHandles {
    pub fn add(&self, name: &str, handle: Handle) {
-       self.handles
-           .lock()
-           .expect("locked")
-           .insert(name.into(), handle);
+       self.handles.lock().insert(name.into(), handle);
    }
 
    pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> {
        self.handles
            .lock()
-           .expect("locked")
            .iter()
            .filter(|(name, _)| names.is_some_and(|names| names.contains(&name.as_str())))
            .for_each(|(_, handle)| {
@@ -66,7 +59,6 @@ impl LogLevelReloadHandles {
    pub fn current(&self, name: &str) -> Option<EnvFilter> {
        self.handles
            .lock()
-           .expect("locked")
            .get(name)
            .map(|handle| handle.current())?
    }
 }
diff --git a/src/core/mod.rs b/src/core/mod.rs
index d99139be..363fece8 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -28,6 +28,7 @@ pub use info::{
 pub use matrix::{
    Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, RoomVersion, pdu, state_res,
 };
+pub use parking_lot::{Mutex as SyncMutex, RwLock as SyncRwLock};
 pub use server::Server;
 pub use utils::{ctor, dtor, implement, result, result::Result};
diff --git a/src/core/utils/mutex_map.rs b/src/core/utils/mutex_map.rs
index 01504ce6..ddb361a4 100644
--- a/src/core/utils/mutex_map.rs
+++ b/src/core/utils/mutex_map.rs
@@ -1,12 +1,8 @@
-use std::{
-   fmt::Debug,
-   hash::Hash,
-   sync::{Arc, TryLockError::WouldBlock},
-};
+use std::{fmt::Debug, hash::Hash, sync::Arc};
 
 use tokio::sync::OwnedMutexGuard as Omg;
 
-use crate::{Result, err};
+use crate::{Result, SyncMutex, err};
 
 /// Map of Mutexes
 pub struct MutexMap<Key, Val> {
@@ -19,7 +15,7 @@ pub struct Guard<Key, Val> {
 }
 
 type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
-type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
+type MapMutex<Key, Val> = SyncMutex<HashMap<Key, Val>>;
 type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
 type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
 
@@ -45,7 +41,6 @@ where
        let val = self
            .map
            .lock()
-           .expect("locked")
            .entry(k.try_into().expect("failed to construct key"))
            .or_default()
            .clone();
@@ -66,7 +61,6 @@ where
        let val = self
            .map
            .lock()
-           .expect("locked")
            .entry(k.try_into().expect("failed to construct key"))
            .or_default()
            .clone();
@@ -87,10 +81,7 @@ where
        let val = self
            .map
            .try_lock()
-           .map_err(|e| match e {
-               | WouldBlock => err!("would block"),
-               | _ => panic!("{e:?}"),
-           })?
+           .ok_or_else(|| err!("would block"))?
            .entry(k.try_into().expect("failed to construct key"))
            .or_default()
            .clone();
@@ -102,13 +93,13 @@ where
    }
 
    #[must_use]
-   pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) }
+   pub fn contains(&self, k: &Key) -> bool { self.map.lock().contains_key(k) }
 
    #[must_use]
-   pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() }
+   pub fn is_empty(&self) -> bool { self.map.lock().is_empty() }
 
    #[must_use]
-   pub fn len(&self) -> usize { self.map.lock().expect("locked").len() }
+   pub fn len(&self) -> usize { self.map.lock().len() }
 }
 
 impl<Key, Val> Default for MutexMap<Key, Val>
@@ -123,7 +114,7 @@ impl<Key, Val> Drop for Guard<Key, Val> {
    #[tracing::instrument(name = "unlock", level = "trace", skip_all)]
    fn drop(&mut self) {
        if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
-           self.map.lock().expect("locked").retain(|_, val| {
+           self.map.lock().retain(|_, val| {
                !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2
            });
        }
diff --git a/src/core/utils/with_lock.rs b/src/core/utils/with_lock.rs
index 76f014d1..91e8e8d1 100644
--- a/src/core/utils/with_lock.rs
+++ b/src/core/utils/with_lock.rs
@@ -1,65 +1,212 @@
 //! Traits for explicitly scoping the lifetime of locks.
 
-use std::sync::{Arc, Mutex};
+use std::{
+   future::Future,
+   sync::{Arc, Mutex},
+};
 
-pub trait WithLock<T> {
-   /// Acquires a lock and executes the given closure with the locked data.
-   fn with_lock<F>(&self, f: F)
+pub trait WithLock<T> {
+   /// Acquires a lock and executes the given closure with the locked data,
+   /// returning the result.
+   fn with_lock<R, F>(&self, f: F) -> R
    where
-       F: FnMut(&mut T);
+       F: FnMut(&mut T) -> R;
 }
 
 impl<T> WithLock<T> for Mutex<T> {
-   fn with_lock<F>(&self, mut f: F)
+   fn with_lock<R, F>(&self, mut f: F) -> R
    where
-       F: FnMut(&mut T),
+       F: FnMut(&mut T) -> R,
    {
        // The locking and unlocking logic is hidden inside this function.
        let mut data_guard = self.lock().unwrap();
-       f(&mut data_guard);
+       f(&mut data_guard)
        // Lock is released here when `data_guard` goes out of scope.
    }
 }
 
 impl<T> WithLock<T> for Arc<Mutex<T>> {
-   fn with_lock<F>(&self, mut f: F)
+   fn with_lock<R, F>(&self, mut f: F) -> R
    where
-       F: FnMut(&mut T),
+       F: FnMut(&mut T) -> R,
    {
        // The locking and unlocking logic is hidden inside this function.
        let mut data_guard = self.lock().unwrap();
-       f(&mut data_guard);
+       f(&mut data_guard)
+       // Lock is released here when `data_guard` goes out of scope.
+   }
+}
+
+impl<R: lock_api::RawMutex, T> WithLock<T> for lock_api::Mutex<R, T> {
+   fn with_lock<Ret, F>(&self, mut f: F) -> Ret
+   where
+       F: FnMut(&mut T) -> Ret,
+   {
+       // The locking and unlocking logic is hidden inside this function.
+       let mut data_guard = self.lock();
+       f(&mut data_guard)
+       // Lock is released here when `data_guard` goes out of scope.
+   }
+}
+
+impl<R: lock_api::RawMutex, T> WithLock<T> for Arc<lock_api::Mutex<R, T>> {
+   fn with_lock<Ret, F>(&self, mut f: F) -> Ret
+   where
+       F: FnMut(&mut T) -> Ret,
+   {
+       // The locking and unlocking logic is hidden inside this function.
+       let mut data_guard = self.lock();
+       f(&mut data_guard)
        // Lock is released here when `data_guard` goes out of scope.
    }
 }
 
 pub trait WithLockAsync<T> {
-   /// Acquires a lock and executes the given closure with the locked data.
-   fn with_lock<F>(&self, f: F) -> impl Future<Output = ()>
+   /// Acquires a lock and executes the given closure with the locked data,
+   /// returning the result.
+   fn with_lock<R, F>(&self, f: F) -> impl Future<Output = R>
    where
-       F: FnMut(&mut T);
+       F: FnMut(&mut T) -> R;
+
+   /// Acquires a lock and executes the given async closure with the locked
+   /// data.
+   fn with_lock_async<R, F>(&self, f: F) -> impl std::future::Future<Output = R>
+   where
+       F: AsyncFnMut(&mut T) -> R;
 }
 
 impl<T> WithLockAsync<T> for futures::lock::Mutex<T> {
-   async fn with_lock<F>(&self, mut f: F)
+   async fn with_lock<R, F>(&self, mut f: F) -> R
    where
-       F: FnMut(&mut T),
+       F: FnMut(&mut T) -> R,
    {
        // The locking and unlocking logic is hidden inside this function.
        let mut data_guard = self.lock().await;
-       f(&mut data_guard);
+       f(&mut data_guard)
+       // Lock is released here when `data_guard` goes out of scope.
+   }
+
+   async fn with_lock_async<R, F>(&self, mut f: F) -> R
+   where
+       F: AsyncFnMut(&mut T) -> R,
+   {
+       // The locking and unlocking logic is hidden inside this function.
+       let mut data_guard = self.lock().await;
+       f(&mut data_guard).await
        // Lock is released here when `data_guard` goes out of scope.
    }
 }
 
 impl<T> WithLockAsync<T> for Arc<futures::lock::Mutex<T>> {
-   async fn with_lock<F>(&self, mut f: F)
+   async fn with_lock<R, F>(&self, mut f: F) -> R
    where
-       F: FnMut(&mut T),
+       F: FnMut(&mut T) -> R,
    {
        // The locking and unlocking logic is hidden inside this function.
        let mut data_guard = self.lock().await;
-       f(&mut data_guard);
+       f(&mut data_guard)
+       // Lock is released here when `data_guard` goes out of scope.
+   }
+
+   async fn with_lock_async<R, F>(&self, mut f: F) -> R
+   where
+       F: AsyncFnMut(&mut T) -> R,
+   {
+       // The locking and unlocking logic is hidden inside this function.
+       let mut data_guard = self.lock().await;
+       f(&mut data_guard).await
        // Lock is released here when `data_guard` goes out of scope.
    }
 }
+
+#[cfg(test)]
+mod tests {
+   use super::*;
+
+   #[test]
+   fn test_with_lock_return_value() {
+       let mutex = Mutex::new(5);
+       let result = mutex.with_lock(|v| {
+           *v += 1;
+           *v * 2
+       });
+       assert_eq!(result, 12);
+       let value = mutex.lock().unwrap();
+       assert_eq!(*value, 6);
+   }
+
+   #[test]
+   fn test_with_lock_unit_return() {
+       let mutex = Mutex::new(10);
+       mutex.with_lock(|v| {
+           *v += 2;
+       });
+       let value = mutex.lock().unwrap();
+       assert_eq!(*value, 12);
+   }
+
+   #[test]
+   fn test_with_lock_arc_mutex() {
+       let mutex = Arc::new(Mutex::new(1));
+       let result = mutex.with_lock(|v| {
+           *v *= 10;
+           *v
+       });
+       assert_eq!(result, 10);
+       assert_eq!(*mutex.lock().unwrap(), 10);
+   }
+
+   #[tokio::test]
+   async fn test_with_lock_async_return_value() {
+       use futures::lock::Mutex as AsyncMutex;
+       let mutex = AsyncMutex::new(7);
+       let result = mutex
+           .with_lock(|v| {
+               *v += 3;
+               *v * 2
+           })
+           .await;
+       assert_eq!(result, 20);
+       let value = mutex.lock().await;
+       assert_eq!(*value, 10);
+   }
+
+   #[tokio::test]
+   async fn test_with_lock_async_unit_return() {
+       use futures::lock::Mutex as AsyncMutex;
+       let mutex = AsyncMutex::new(100);
+       mutex
+           .with_lock(|v| {
+               *v -= 50;
+           })
+           .await;
+       let value = mutex.lock().await;
+       assert_eq!(*value, 50);
+   }
+
+   #[tokio::test]
+   async fn test_with_lock_async_closure() {
+       use futures::lock::Mutex as AsyncMutex;
+       let mutex = AsyncMutex::new(1);
+       mutex
+           .with_lock_async(async |v| {
+               *v += 9;
+           })
+           .await;
+       let value = mutex.lock().await;
+       assert_eq!(*value, 10);
+   }
+
+   #[tokio::test]
+   async fn test_with_lock_async_arc_mutex() {
+       use futures::lock::Mutex as AsyncMutex;
+       let mutex = Arc::new(AsyncMutex::new(2));
+       mutex
+           .with_lock_async(async |v: &mut i32| {
+               *v *= 5;
+           })
+           .await;
+       let value = mutex.lock().await;
+       assert_eq!(*value, 10);
+   }
+}
diff --git a/src/database/engine/backup.rs b/src/database/engine/backup.rs
index ac72e6d4..4cdb6172 100644
--- a/src/database/engine/backup.rs
+++ b/src/database/engine/backup.rs
@@ -71,7 +71,7 @@ pub fn backup_count(&self) -> Result<usize> {
 fn backup_engine(&self) -> Result<BackupEngine> {
    let path = self.backup_path()?;
    let options = BackupEngineOptions::new(path).map_err(map_err)?;
-   BackupEngine::open(&options, &*self.ctx.env.lock()?).map_err(map_err)
+   BackupEngine::open(&options, &self.ctx.env.lock()).map_err(map_err)
 }
 
 #[implement(Engine)]
diff --git a/src/database/engine/cf_opts.rs b/src/database/engine/cf_opts.rs
index cbbd1012..58358f02 100644
--- a/src/database/engine/cf_opts.rs
+++ b/src/database/engine/cf_opts.rs
@@ -232,7 +232,7 @@ fn get_cache(ctx: &Context, desc: &Descriptor) -> Option<Cache> {
    cache_opts.set_num_shard_bits(shard_bits);
    cache_opts.set_capacity(size);
 
-   let mut caches = ctx.col_cache.lock().expect("locked");
+   let mut caches = ctx.col_cache.lock();
    match desc.cache_disp {
        | CacheDisp::Unique if desc.cache_size == 0 => None,
        | CacheDisp::Unique => {
diff --git a/src/database/engine/context.rs b/src/database/engine/context.rs
index 380e37af..3b9238bd 100644
--- a/src/database/engine/context.rs
+++ b/src/database/engine/context.rs
@@ -1,9 +1,6 @@
-use std::{
-   collections::BTreeMap,
-   sync::{Arc, Mutex},
-};
+use std::{collections::BTreeMap, sync::Arc};
 
-use conduwuit::{Result, Server, debug, utils::math::usize_from_f64};
+use conduwuit::{Result, Server, SyncMutex, debug, utils::math::usize_from_f64};
 use rocksdb::{Cache, Env, LruCacheOptions};
 
 use crate::{or_else, pool::Pool};
@@ -14,9 +11,9 @@ use crate::{or_else, pool::Pool};
 /// These assets are housed in the shared Context.
 pub(crate) struct Context {
    pub(crate) pool: Arc<Pool>,
-   pub(crate) col_cache: Mutex<BTreeMap<String, Cache>>,
-   pub(crate) row_cache: Mutex<Cache>,
-   pub(crate) env: Mutex<Env>,
+   pub(crate) col_cache: SyncMutex<BTreeMap<String, Cache>>,
+   pub(crate) row_cache: SyncMutex<Cache>,
+   pub(crate) env: SyncMutex<Env>,
    pub(crate) server: Arc<Server>,
 }
 
@@ -68,7 +65,7 @@ impl Drop for Context {
        debug!("Closing frontend pool");
        self.pool.close();
 
-       let mut env = self.env.lock().expect("locked");
+       let mut env = self.env.lock();
 
        debug!("Shutting down background threads");
        env.set_high_priority_background_threads(0);
diff --git a/src/database/engine/memory_usage.rs b/src/database/engine/memory_usage.rs
index 9bb5c535..21af35c8 100644
--- a/src/database/engine/memory_usage.rs
+++ b/src/database/engine/memory_usage.rs
@@ -9,7 +9,7 @@ use crate::or_else;
 #[implement(Engine)]
 pub fn memory_usage(&self) -> Result<String> {
    let mut res = String::new();
-   let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()?]))
+   let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()]))
        .or_else(or_else)?;
    let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0;
    writeln!(
@@ -19,10 +19,10 @@ pub fn memory_usage(&self) -> Result<String> {
        mibs(stats.mem_table_total),
        mibs(stats.mem_table_unflushed),
        mibs(stats.mem_table_readers_total),
-       mibs(u64::try_from(self.ctx.row_cache.lock()?.get_usage())?),
+       mibs(u64::try_from(self.ctx.row_cache.lock().get_usage())?),
    )?;
 
-   for (name, cache) in &*self.ctx.col_cache.lock()? {
+   for (name, cache) in &*self.ctx.col_cache.lock() {
        writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?;
    }
diff --git a/src/database/engine/open.rs b/src/database/engine/open.rs
index 84e59a6a..7b9d93c2 100644
--- a/src/database/engine/open.rs
+++ b/src/database/engine/open.rs
@@ -23,11 +23,7 @@ pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<Self>> {
diff --git a/src/database/pool.rs b/src/database/pool.rs
--- a/src/database/pool.rs
+++ b/src/database/pool.rs
    queues: Vec<Sender<Cmd>>,
-   workers: Mutex<Vec<JoinHandle<()>>>,
+   workers: SyncMutex<Vec<JoinHandle<()>>>,
    topology: Vec<usize>,
    busy: AtomicUsize,
    queued_max: AtomicUsize,
@@ -115,7 +115,7 @@ impl Drop for Pool {
 #[implement(Pool)]
 #[tracing::instrument(skip_all)]
 pub(crate) fn close(&self) {
-   let workers = take(&mut *self.workers.lock().expect("locked"));
+   let workers = take(&mut *self.workers.lock());
 
    let senders = self.queues.iter().map(Sender::sender_count).sum::<usize>();
@@ -154,7 +154,7 @@ pub(crate) fn close(&self) {
 
 #[implement(Pool)]
 fn spawn_until(self: &Arc<Self>, recv: &[Receiver<Cmd>], count: usize) -> Result {
-   let mut workers = self.workers.lock().expect("locked");
+   let mut workers = self.workers.lock();
    while workers.len() < count {
        self.clone().spawn_one(&mut workers, recv)?;
    }
diff --git a/src/database/watchers.rs b/src/database/watchers.rs
index efb939d7..0e911c82 100644
--- a/src/database/watchers.rs
+++ b/src/database/watchers.rs
@@ -2,12 +2,12 @@ use std::{
    collections::{HashMap, hash_map},
    future::Future,
    pin::Pin,
-   sync::RwLock,
 };
 
+use conduwuit::SyncRwLock;
 use tokio::sync::watch;
 
-type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
+type Watcher = SyncRwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
 
 #[derive(Default)]
 pub(crate) struct Watchers {
@@ -19,7 +19,7 @@ impl Watchers {
        &'a self,
        prefix: &[u8],
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
-       let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
+       let mut rx = match self.watchers.write().entry(prefix.to_vec()) {
            | hash_map::Entry::Occupied(o) => o.get().1.clone(),
            | hash_map::Entry::Vacant(v) => {
                let (tx, rx) = watch::channel(());
@@ -35,7 +35,7 @@ impl Watchers {
    }
 
    pub(crate) fn wake(&self, key: &[u8]) {
-       let watchers = self.watchers.read().unwrap();
+       let watchers = self.watchers.read();
        let mut triggered = Vec::new();
 
        for length in 0..=key.len() {
            if watchers.contains_key(&key[..length]) {
@@ -46,7 +46,7 @@
        drop(watchers);
 
        if !triggered.is_empty() {
-           let mut watchers = self.watchers.write().unwrap();
+           let mut watchers = self.watchers.write();
            for prefix in triggered {
                if let Some(tx) = watchers.remove(prefix) {
                    tx.0.send(()).expect("channel should still be open");
diff --git a/src/macros/rustc.rs b/src/macros/rustc.rs
index 1220c8d4..cf935fe5 100644
--- a/src/macros/rustc.rs
+++ b/src/macros/rustc.rs
@@ -15,13 +15,13 @@ pub(super) fn flags_capture(args: TokenStream) -> TokenStream {
 
        #[conduwuit_core::ctor]
        fn _set_rustc_flags() {
-           conduwuit_core::info::rustc::FLAGS.lock().expect("locked").insert(#crate_name, &RUSTC_FLAGS);
+           conduwuit_core::info::rustc::FLAGS.lock().insert(#crate_name, &RUSTC_FLAGS);
        }
 
        // static strings have to be yanked on module unload
        #[conduwuit_core::dtor]
        fn _unset_rustc_flags() {
-           conduwuit_core::info::rustc::FLAGS.lock().expect("locked").remove(#crate_name);
+           conduwuit_core::info::rustc::FLAGS.lock().remove(#crate_name);
        }
    };
diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs
index 02f41303..931bb719 100644
--- a/src/service/admin/console.rs
+++ b/src/service/admin/console.rs
@@ -1,11 +1,8 @@
 #![cfg(feature = "console")]
 
-use std::{
-   collections::VecDeque,
-   sync::{Arc, Mutex},
-};
+use std::{collections::VecDeque, sync::Arc};
 
-use conduwuit::{Server, debug, defer, error, log, log::is_systemd_mode};
+use conduwuit::{Server, SyncMutex, debug, defer, error, log, log::is_systemd_mode};
 use futures::future::{AbortHandle, Abortable};
 use ruma::events::room::message::RoomMessageEventContent;
 use rustyline_async::{Readline, ReadlineError, ReadlineEvent};
@@ -17,10 +14,10 @@ use crate::{Dep, admin};
 pub struct Console {
    server: Arc<Server>,
    admin: Dep<admin::Service>,
-   worker_join: Mutex<Option<JoinHandle<()>>>,
-   input_abort: Mutex<Option<AbortHandle>>,
-   command_abort: Mutex<Option<AbortHandle>>,
-   history: Mutex<VecDeque<String>>,
+   worker_join: SyncMutex<Option<JoinHandle<()>>>,
+   input_abort: SyncMutex<Option<AbortHandle>>,
+   command_abort: SyncMutex<Option<AbortHandle>>,
+   history: SyncMutex<VecDeque<String>>,
    output: MadSkin,
 }
 
@@ -50,7 +47,7 @@ impl Console {
    }
 
    pub async fn start(self: &Arc<Self>) {
-       let mut worker_join = self.worker_join.lock().expect("locked");
+       let mut worker_join = self.worker_join.lock();
        if worker_join.is_none() {
            let self_ = Arc::clone(self);
            _ = worker_join.insert(self.server.runtime().spawn(self_.worker()));
@@ -60,7 +57,7 @@ impl Console {
    pub async fn close(self: &Arc<Self>) {
        self.interrupt();
 
-       let Some(worker_join) = self.worker_join.lock().expect("locked").take() else {
+       let Some(worker_join) = self.worker_join.lock().take() else {
            return;
        };
 
@@ -70,22 +67,18 @@ impl Console {
    pub fn interrupt(self: &Arc<Self>) {
        self.interrupt_command();
        self.interrupt_readline();
-       self.worker_join
-           .lock()
-           .expect("locked")
-           .as_ref()
-           .map(JoinHandle::abort);
+       self.worker_join.lock().as_ref().map(JoinHandle::abort);
    }
 
    pub fn interrupt_readline(self: &Arc<Self>) {
-       if let Some(input_abort) = self.input_abort.lock().expect("locked").take() {
+       if let Some(input_abort) = self.input_abort.lock().take() {
            debug!("Interrupting console readline...");
            input_abort.abort();
        }
    }
 
    pub fn interrupt_command(self: &Arc<Self>) {
-       if let Some(command_abort) = self.command_abort.lock().expect("locked").take() {
+       if let Some(command_abort) = self.command_abort.lock().take() {
            debug!("Interrupting console command...");
            command_abort.abort();
        }
@@ -120,7 +113,7 @@ impl Console {
        }
 
        debug!("session ending");
ending"); - self.worker_join.lock().expect("locked").take(); + self.worker_join.lock().take(); } async fn readline(self: &Arc) -> Result { @@ -135,9 +128,9 @@ impl Console { let (abort, abort_reg) = AbortHandle::new_pair(); let future = Abortable::new(future, abort_reg); - _ = self.input_abort.lock().expect("locked").insert(abort); + _ = self.input_abort.lock().insert(abort); defer! {{ - _ = self.input_abort.lock().expect("locked").take(); + _ = self.input_abort.lock().take(); }} let Ok(result) = future.await else { @@ -158,9 +151,9 @@ impl Console { let (abort, abort_reg) = AbortHandle::new_pair(); let future = Abortable::new(future, abort_reg); - _ = self.command_abort.lock().expect("locked").insert(abort); + _ = self.command_abort.lock().insert(abort); defer! {{ - _ = self.command_abort.lock().expect("locked").take(); + _ = self.command_abort.lock().take(); }} _ = future.await; @@ -184,20 +177,15 @@ impl Console { } fn set_history(&self, readline: &mut Readline) { - self.history - .lock() - .expect("locked") - .iter() - .rev() - .for_each(|entry| { - readline - .add_history_entry(entry.clone()) - .expect("added history entry"); - }); + self.history.lock().iter().rev().for_each(|entry| { + readline + .add_history_entry(entry.clone()) + .expect("added history entry"); + }); } fn add_history(&self, line: String) { - let mut history = self.history.lock().expect("locked"); + let mut history = self.history.lock(); history.push_front(line); history.truncate(HISTORY_LIMIT); } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index f496c414..c052198c 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -5,11 +5,11 @@ mod grant; use std::{ pin::Pin, - sync::{Arc, RwLock as StdRwLock, Weak}, + sync::{Arc, Weak}, }; use async_trait::async_trait; -use conduwuit::{Err, utils}; +use conduwuit::{Err, SyncRwLock, utils}; use conduwuit_core::{ Error, Event, Result, Server, debug, err, error, error::default_log, pdu::PduBuilder, }; @@ -36,7 +36,7 @@ pub struct Service { services: Services, channel: (Sender, Receiver), pub handle: RwLock>, - pub complete: StdRwLock>, + pub complete: SyncRwLock>, #[cfg(feature = "console")] pub console: Arc, } @@ -50,7 +50,7 @@ struct Services { state_cache: Dep, state_accessor: Dep, account_data: Dep, - services: StdRwLock>>, + services: SyncRwLock>>, media: Dep, } @@ -105,7 +105,7 @@ impl crate::Service for Service { }, channel: loole::bounded(COMMAND_QUEUE_LIMIT), handle: RwLock::new(None), - complete: StdRwLock::new(None), + complete: SyncRwLock::new(None), #[cfg(feature = "console")] console: console::Console::new(&args), })) @@ -312,10 +312,7 @@ impl Service { /// Invokes the tab-completer to complete the command. When unavailable, /// None is returned. 
    pub fn complete_command(&self, command: &str) -> Option<String> {
-       self.complete
-           .read()
-           .expect("locked for reading")
-           .map(|complete| complete(command))
+       self.complete.read().map(|complete| complete(command))
    }
 
    async fn handle_signal(&self, sig: &'static str) {
@@ -338,17 +335,13 @@ impl Service {
    }
 
    async fn process_command(&self, command: CommandInput) -> ProcessorResult {
-       let handle = &self
-           .handle
-           .read()
-           .await
-           .expect("Admin module is not loaded");
+       let handle_guard = self.handle.read().await;
+       let handle = handle_guard.as_ref().expect("Admin module is not loaded");
 
        let services = self
            .services
            .services
            .read()
-           .expect("locked")
            .as_ref()
            .and_then(Weak::upgrade)
            .expect("Services self-reference not initialized.");
@@ -523,7 +516,7 @@ impl Service {
    /// Sets the self-reference to crate::Services which will provide context to
    /// the admin commands.
    pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) {
-       let receiver = &mut *self.services.services.write().expect("locked for writing");
+       let receiver = &mut *self.services.services.write();
        let weak = services.map(Arc::downgrade);
        *receiver = weak;
    }
diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs
index 21c09252..07f1de5c 100644
--- a/src/service/globals/data.rs
+++ b/src/service/globals/data.rs
@@ -1,11 +1,11 @@
-use std::sync::{Arc, RwLock};
+use std::sync::Arc;
 
-use conduwuit::{Result, utils};
+use conduwuit::{Result, SyncRwLock, utils};
 use database::{Database, Deserialized, Map};
 
 pub struct Data {
    global: Arc<Map>,
-   counter: RwLock<u64>,
+   counter: SyncRwLock<u64>,
    pub(super) db: Arc<Database>,
 }
 
@@ -16,25 +16,21 @@ impl Data {
        let db = &args.db;
        Self {
            global: db["global"].clone(),
-           counter: RwLock::new(
-               Self::stored_count(&db["global"]).expect("initialized global counter"),
-           ),
+           counter: SyncRwLock::new(Self::stored_count(&db["global"]).unwrap_or_default()),
            db: args.db.clone(),
        }
    }
 
    pub fn next_count(&self) -> Result<u64> {
        let _cork = self.db.cork();
-       let mut lock = self.counter.write().expect("locked");
+       let mut lock = self.counter.write();
        let counter: &mut u64 = &mut lock;
        debug_assert!(
-           *counter == Self::stored_count(&self.global).expect("database failure"),
+           *counter == Self::stored_count(&self.global).unwrap_or_default(),
            "counter mismatch"
        );
 
-       *counter = counter
-           .checked_add(1)
-           .expect("counter must not overflow u64");
+       *counter = counter.checked_add(1).unwrap_or(*counter);
 
        self.global.insert(COUNTER, counter.to_be_bytes());
 
@@ -43,10 +39,10 @@ impl Data {
 
    #[inline]
    pub fn current_count(&self) -> u64 {
-       let lock = self.counter.read().expect("locked");
+       let lock = self.counter.read();
        let counter: &u64 = &lock;
        debug_assert!(
-           *counter == Self::stored_count(&self.global).expect("database failure"),
+           *counter == Self::stored_count(&self.global).unwrap_or_default(),
            "counter mismatch"
        );
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index a23a4c21..12f2ec78 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -1,14 +1,9 @@
 mod data;
 
-use std::{
-   collections::HashMap,
-   fmt::Write,
-   sync::{Arc, RwLock},
-   time::Instant,
-};
+use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
 
 use async_trait::async_trait;
-use conduwuit::{Result, Server, error, utils::bytes::pretty};
+use conduwuit::{Result, Server, SyncRwLock, error, utils::bytes::pretty};
 use data::Data;
 use regex::RegexSet;
 use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, ServerName, UserId};
@@ -19,7 +14,7 @@
 pub struct Service {
    pub db: Data,
    server: Arc<Server>,
 
-   pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
+   pub bad_event_ratelimiter: Arc<SyncRwLock<HashMap<OwnedEventId, RateLimitState>>>,
    pub server_user: OwnedUserId,
    pub admin_alias: OwnedRoomAliasId,
    pub turn_secret: String,
@@ -62,7 +57,7 @@ impl crate::Service for Service {
        Ok(Arc::new(Self {
            db,
            server: args.server.clone(),
-           bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
+           bad_event_ratelimiter: Arc::new(SyncRwLock::new(HashMap::new())),
            admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &args.server.name))
                .expect("#admins:server_name is valid alias name"),
            server_user: UserId::parse_with_server_name(
@@ -76,7 +71,7 @@ impl crate::Service for Service {
    }
 
    async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
-       let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read()?.iter().fold(
+       let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read().iter().fold(
            (0_usize, 0_usize),
            |(mut count, mut bytes), (event_id, _)| {
                bytes = bytes.saturating_add(event_id.capacity());
@@ -91,12 +86,7 @@ impl crate::Service for Service {
        Ok(())
    }
 
-   async fn clear_cache(&self) {
-       self.bad_event_ratelimiter
-           .write()
-           .expect("locked for writing")
-           .clear();
-   }
+   async fn clear_cache(&self) { self.bad_event_ratelimiter.write().clear(); }
 
    fn name(&self) -> &str { service::make_name(std::module_path!()) }
 }
diff --git a/src/service/manager.rs b/src/service/manager.rs
index 3cdf5945..7a2e50d5 100644
--- a/src/service/manager.rs
+++ b/src/service/manager.rs
@@ -58,7 +58,6 @@ impl Manager {
        let services: Vec<Arc<dyn Service>> = self
            .service
            .read()
-           .expect("locked for reading")
            .values()
            .map(|val| val.0.upgrade())
            .map(|arc| arc.expect("services available for manager startup"))
diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs
index 8c3588cc..e9e40979 100644
--- a/src/service/rooms/auth_chain/data.rs
+++ b/src/service/rooms/auth_chain/data.rs
@@ -1,9 +1,6 @@
-use std::{
-   mem::size_of,
-   sync::{Arc, Mutex},
-};
+use std::{mem::size_of, sync::Arc};
 
-use conduwuit::{Err, Result, err, utils, utils::math::usize_from_f64};
+use conduwuit::{Err, Result, SyncMutex, err, utils, utils::math::usize_from_f64};
 use database::Map;
 use lru_cache::LruCache;
 
@@ -11,7 +8,7 @@ use crate::rooms::short::ShortEventId;
 
 pub(super) struct Data {
    shorteventid_authchain: Arc<Map>,
-   pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
+   pub(super) auth_chain_cache: SyncMutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
 }
 
 impl Data {
@@ -23,7 +20,7 @@ impl Data {
            .expect("valid cache size");
        Self {
            shorteventid_authchain: db["shorteventid_authchain"].clone(),
-           auth_chain_cache: Mutex::new(LruCache::new(cache_size)),
+           auth_chain_cache: SyncMutex::new(LruCache::new(cache_size)),
        }
    }
 
@@ -34,12 +31,7 @@ impl Data {
        debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
 
        // Check RAM cache
-       if let Some(result) = self
-           .auth_chain_cache
-           .lock()
-           .expect("cache locked")
-           .get_mut(key)
-       {
+       if let Some(result) = self.auth_chain_cache.lock().get_mut(key) {
            return Ok(Arc::clone(result));
        }
 
@@ -63,7 +55,6 @@ impl Data {
        // Cache in RAM
        self.auth_chain_cache
            .lock()
-           .expect("cache locked")
            .insert(vec![key[0]], Arc::clone(&chain));
 
        Ok(chain)
@@ -84,9 +75,6 @@ impl Data {
    }
 
    // Cache in RAM
-   self.auth_chain_cache
-       .lock()
-       .expect("cache locked")
-       .insert(key, auth_chain);
+   self.auth_chain_cache.lock().insert(key, auth_chain);
 }
diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs
index 0903ea75..79d4d070 100644
--- a/src/service/rooms/auth_chain/mod.rs
+++ b/src/service/rooms/auth_chain/mod.rs
@@ -248,10 +248,10 @@ pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[ShortEventId]) {
 
 #[implement(Service)]
 pub fn get_cache_usage(&self) -> (usize, usize) {
-   let cache = self.db.auth_chain_cache.lock().expect("locked");
+   let cache = self.db.auth_chain_cache.lock();
 
    (cache.len(), cache.capacity())
 }
 
 #[implement(Service)]
-pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); }
+pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().clear(); }
diff --git a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs
index 44027e04..59b768f2 100644
--- a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs
+++ b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs
@@ -41,7 +41,6 @@ where
            .globals
            .bad_event_ratelimiter
            .write()
-           .expect("locked")
            .entry(id)
        {
            | hash_map::Entry::Vacant(e) => {
@@ -76,7 +75,6 @@ where
                .globals
                .bad_event_ratelimiter
                .read()
-               .expect("locked")
                .get(&*next_id)
            {
                // Exponential backoff
@@ -187,7 +185,6 @@ where
                .globals
                .bad_event_ratelimiter
                .read()
-               .expect("locked")
                .get(&*next_id)
            {
                // Exponential backoff
diff --git a/src/service/rooms/event_handler/handle_incoming_pdu.rs b/src/service/rooms/event_handler/handle_incoming_pdu.rs
index 86a05e0a..5299e8d4 100644
--- a/src/service/rooms/event_handler/handle_incoming_pdu.rs
+++ b/src/service/rooms/event_handler/handle_incoming_pdu.rs
@@ -160,7 +160,6 @@ pub async fn handle_incoming_pdu<'a>(
            .globals
            .bad_event_ratelimiter
            .write()
-           .expect("locked")
            .entry(prev_id.into())
        {
            | hash_map::Entry::Vacant(e) => {
@@ -181,13 +180,11 @@ pub async fn handle_incoming_pdu<'a>(
    let start_time = Instant::now();
    self.federation_handletime
        .write()
-       .expect("locked")
        .insert(room_id.into(), (event_id.to_owned(), start_time));
 
    defer! {{
        self.federation_handletime
            .write()
-           .expect("locked")
            .remove(room_id);
    }};
diff --git a/src/service/rooms/event_handler/handle_prev_pdu.rs b/src/service/rooms/event_handler/handle_prev_pdu.rs
index cd46310a..cb4978d9 100644
--- a/src/service/rooms/event_handler/handle_prev_pdu.rs
+++ b/src/service/rooms/event_handler/handle_prev_pdu.rs
@@ -42,7 +42,6 @@ where
            .globals
            .bad_event_ratelimiter
            .read()
-           .expect("locked")
            .get(prev_id)
        {
            // Exponential backoff
@@ -70,13 +69,11 @@ where
    let start_time = Instant::now();
    self.federation_handletime
        .write()
-       .expect("locked")
        .insert(room_id.into(), ((*prev_id).to_owned(), start_time));
 
    defer! {{
        self.federation_handletime
            .write()
-           .expect("locked")
            .remove(room_id);
    }};
diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs
index aed38e1e..4e59c207 100644
--- a/src/service/rooms/event_handler/mod.rs
+++ b/src/service/rooms/event_handler/mod.rs
@@ -10,15 +10,10 @@ mod resolve_state;
 mod state_at_incoming;
 mod upgrade_outlier_pdu;
 
-use std::{
-   collections::HashMap,
-   fmt::Write,
-   sync::{Arc, RwLock as StdRwLock},
-   time::Instant,
-};
+use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
 
 use async_trait::async_trait;
-use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, utils::MutexMap};
+use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, SyncRwLock, utils::MutexMap};
 use ruma::{
    OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
    events::room::create::RoomCreateEventContent,
@@ -28,7 +23,7 @@ use crate::{Dep, globals, rooms, sending, server_keys};
 
 pub struct Service {
    pub mutex_federation: RoomMutexMap,
-   pub federation_handletime: StdRwLock<HandleTimeMap>,
+   pub federation_handletime: SyncRwLock<HandleTimeMap>,
    services: Services,
 }
 
@@ -81,11 +76,7 @@ impl crate::Service for Service {
        let mutex_federation = self.mutex_federation.len();
        writeln!(out, "federation_mutex: {mutex_federation}")?;
 
-       let federation_handletime = self
-           .federation_handletime
-           .read()
-           .expect("locked for reading")
-           .len();
+       let federation_handletime = self.federation_handletime.read().len();
        writeln!(out, "federation_handletime: {federation_handletime}")?;
 
        Ok(())
diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs
index 9429be79..e9845fbf 100644
--- a/src/service/rooms/state_cache/mod.rs
+++ b/src/service/rooms/state_cache/mod.rs
@@ -1,13 +1,10 @@
 mod update;
 mod via;
 
-use std::{
-   collections::HashMap,
-   sync::{Arc, RwLock},
-};
+use std::{collections::HashMap, sync::Arc};
 
 use conduwuit::{
-   Result, implement,
+   Result, SyncRwLock, implement,
    result::LogErr,
    utils::{ReadyExt, stream::TryIgnore},
    warn,
@@ -54,14 +51,14 @@ struct Data {
    userroomid_knockedstate: Arc<Map>,
 }
 
-type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
+type AppServiceInRoomCache = SyncRwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
 type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>);
 type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>);
 
 impl crate::Service for Service {
    fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
        Ok(Arc::new(Self {
-           appservice_in_room_cache: RwLock::new(HashMap::new()),
+           appservice_in_room_cache: SyncRwLock::new(HashMap::new()),
            services: Services {
                account_data: args.depend::<account_data::Service>("account_data"),
                config: args.depend::<config::Service>("config"),
@@ -99,7 +96,6 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
    if let Some(cached) = self
        .appservice_in_room_cache
        .read()
-       .expect("locked")
        .get(room_id)
        .and_then(|map| map.get(&appservice.registration.id))
        .copied()
@@ -124,7 +120,6 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
 
    self.appservice_in_room_cache
        .write()
-       .expect("locked")
        .entry(room_id.into())
        .or_default()
        .insert(appservice.registration.id.clone(), in_room);
@@ -134,19 +129,14 @@ pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrati
 
 #[implement(Service)]
 pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) {
-   let cache = self.appservice_in_room_cache.read().expect("locked");
+   let cache = self.appservice_in_room_cache.read();
 
    (cache.len(), cache.capacity())
 }
 
 #[implement(Service)]
 #[tracing::instrument(level = "debug", skip_all)]
-pub fn clear_appservice_in_room_cache(&self) {
-   self.appservice_in_room_cache
-       .write()
-       .expect("locked")
-       .clear();
-}
+pub fn clear_appservice_in_room_cache(&self) { self.appservice_in_room_cache.write().clear(); }
 
 /// Returns an iterator of all servers participating in this room.
 #[implement(Service)]
diff --git a/src/service/rooms/state_cache/update.rs b/src/service/rooms/state_cache/update.rs
index 02c6bec6..32c67947 100644
--- a/src/service/rooms/state_cache/update.rs
+++ b/src/service/rooms/state_cache/update.rs
@@ -211,10 +211,7 @@ pub async fn update_joined_count(&self, room_id: &RoomId) {
        self.db.serverroomids.put_raw(serverroom_id, []);
    }
 
-   self.appservice_in_room_cache
-       .write()
-       .expect("locked")
-       .remove(room_id);
+   self.appservice_in_room_cache.write().remove(room_id);
 }
 
 /// Direct DB function to directly mark a user as joined. It is not
diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs
index a33fb342..f7f7d043 100644
--- a/src/service/rooms/state_compressor/mod.rs
+++ b/src/service/rooms/state_compressor/mod.rs
@@ -2,12 +2,12 @@ use std::{
    collections::{BTreeSet, HashMap},
    fmt::{Debug, Write},
    mem::size_of,
-   sync::{Arc, Mutex},
+   sync::Arc,
 };
 
 use async_trait::async_trait;
 use conduwuit::{
-   Result,
+   Result, SyncMutex,
    arrayvec::ArrayVec,
    at, checked, err, expected, implement, utils,
    utils::{bytes, math::usize_from_f64, stream::IterStream},
@@ -23,7 +23,7 @@ use crate::{
 };
 
 pub struct Service {
-   pub stateinfo_cache: Mutex<StateInfoLruCache>,
+   pub stateinfo_cache: SyncMutex<StateInfoLruCache>,
    db: Data,
    services: Services,
 }
@@ -86,7 +86,7 @@ impl crate::Service for Service {
 
    async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
        let (cache_len, ents) = {
-           let cache = self.stateinfo_cache.lock().expect("locked");
+           let cache = self.stateinfo_cache.lock();
            let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
                HashMap::new(),
                |mut ents, ssi| {
@@ -110,7 +110,7 @@ impl crate::Service for Service {
        Ok(())
    }
 
-   async fn clear_cache(&self) { self.stateinfo_cache.lock().expect("locked").clear(); }
+   async fn clear_cache(&self) { self.stateinfo_cache.lock().clear(); }
 
    fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }
@@ -123,7 +123,7 @@ pub async fn load_shortstatehash_info(
    &self,
    shortstatehash: ShortStateHash,
 ) -> Result<ShortStateInfoVec> {
-   if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) {
+   if let Some(r) = self.stateinfo_cache.lock().get_mut(&shortstatehash) {
        return Ok(r.clone());
    }
 
@@ -152,7 +152,7 @@ async fn cache_shortstatehash_info(
    shortstatehash: ShortStateHash,
    stack: ShortStateInfoVec,
 ) -> Result {
-   self.stateinfo_cache.lock()?.insert(shortstatehash, stack);
+   self.stateinfo_cache.lock().insert(shortstatehash, stack);
 
    Ok(())
 }
diff --git a/src/service/service.rs b/src/service/service.rs
index 574efd8f..3bc61aeb 100644
--- a/src/service/service.rs
+++ b/src/service/service.rs
@@ -3,11 +3,13 @@ use std::{
    collections::BTreeMap,
    fmt::Write,
    ops::Deref,
-   sync::{Arc, OnceLock, RwLock, Weak},
+   sync::{Arc, OnceLock, Weak},
 };
 
 use async_trait::async_trait;
-use conduwuit::{Err, Result, Server, err, error::inspect_log, utils::string::SplitInfallible};
+use conduwuit::{
+   Err, Result, Server, SyncRwLock, err, error::inspect_log, utils::string::SplitInfallible,
+};
 use database::Database;
 
 /// Abstract interface for a Service
@@ -62,7 +64,7 @@ pub(crate) struct Dep<T: Service + Send + Sync> {
    name: &'static str,
 }
 
-pub(crate) type Map = RwLock<MapType>;
+pub(crate) type Map = SyncRwLock<MapType>;
 pub(crate) type MapType = BTreeMap<MapKey, MapVal>;
 pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>);
 pub(crate) type MapKey = String;
 
@@ -143,15 +145,12 @@ pub(crate) fn get<T>(map: &Map, name: &str) -> Option<Arc<T>>
 where
    T: Any + Send + Sync + Sized,
 {
-   map.read()
-       .expect("locked for reading")
-       .get(name)
-       .map(|(_, s)| {
-           s.upgrade().map(|s| {
-               s.downcast::<T>()
-                   .expect("Service must be correctly downcast.")
-           })
-       })?
+   map.read().get(name).map(|(_, s)| {
+       s.upgrade().map(|s| {
+           s.downcast::<T>()
+               .expect("Service must be correctly downcast.")
+       })
+   })?
 }
 
 /// Reference a Service by name. Returns Err if the Service does not exist or
@@ -160,21 +159,18 @@ pub(crate) fn try_get<T>(map: &Map, name: &str) -> Result<Arc<T>>
 where
    T: Any + Send + Sync + Sized,
 {
-   map.read()
-       .expect("locked for reading")
-       .get(name)
-       .map_or_else(
-           || Err!("Service {name:?} does not exist or has not been built yet."),
-           |(_, s)| {
-               s.upgrade().map_or_else(
-                   || Err!("Service {name:?} no longer exists."),
-                   |s| {
-                       s.downcast::<T>()
-                           .map_err(|_| err!("Service {name:?} must be correctly downcast."))
-                   },
-               )
-           },
-       )
+   map.read().get(name).map_or_else(
+       || Err!("Service {name:?} does not exist or has not been built yet."),
+       |(_, s)| {
+           s.upgrade().map_or_else(
+               || Err!("Service {name:?} no longer exists."),
+               |s| {
+                   s.downcast::<T>()
+                       .map_err(|_| err!("Service {name:?} must be correctly downcast."))
+               },
+           )
+       },
+   )
 }
 
 /// Utility for service implementations; see Service::name() in the trait.
diff --git a/src/service/services.rs b/src/service/services.rs
index daece245..642f61c7 100644
--- a/src/service/services.rs
+++ b/src/service/services.rs
@@ -1,10 +1,8 @@
-use std::{
-   any::Any,
-   collections::BTreeMap,
-   sync::{Arc, RwLock},
-};
+use std::{any::Any, collections::BTreeMap, sync::Arc};
 
-use conduwuit::{Result, Server, debug, debug_info, info, trace, utils::stream::IterStream};
+use conduwuit::{
+   Result, Server, SyncRwLock, debug, debug_info, info, trace, utils::stream::IterStream,
+};
 use database::Database;
 use futures::{Stream, StreamExt, TryStreamExt};
 use tokio::sync::Mutex;
@@ -52,7 +50,7 @@ impl Services {
    #[allow(clippy::cognitive_complexity)]
    pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
        let db = Database::open(&server).await?;
-       let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new()));
+       let service: Arc<Map> = Arc::new(SyncRwLock::new(BTreeMap::new()));
        macro_rules! build {
            ($tyname:ty) => {{
                let built = <$tyname>::build(Args {
@@ -193,7 +191,7 @@ impl Services {
    fn interrupt(&self) {
        debug!("Interrupting services...");
 
-       for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() {
+       for (name, (service, ..)) in self.service.read().iter() {
            if let Some(service) = service.upgrade() {
                trace!("Interrupting {name}");
                service.interrupt();
@@ -205,7 +203,6 @@ impl Services {
    fn services(&self) -> impl Stream<Item = Arc<dyn Service>> + Send {
        self.service
            .read()
-           .expect("locked for reading")
            .values()
            .filter_map(|val| val.0.upgrade())
            .collect::<Vec<_>>()
@@ -233,10 +230,9 @@ impl Services {
 #[allow(clippy::needless_pass_by_value)]
 fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) {
    let name = s.name();
-   let len = map.read().expect("locked for reading").len();
+   let len = map.read().len();
 
    trace!("built service #{len}: {name:?}");
    map.write()
-       .expect("locked for writing")
        .insert(name.to_owned(), (Arc::downgrade(&s), Arc::downgrade(&a)));
 }
diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs
index b095d2c1..6ac579f4 100644
--- a/src/service/sync/mod.rs
+++ b/src/service/sync/mod.rs
@@ -2,10 +2,10 @@ mod watch;
 
 use std::{
    collections::{BTreeMap, BTreeSet},
-   sync::{Arc, Mutex, Mutex as StdMutex},
+   sync::Arc,
 };
 
-use conduwuit::{Result, Server};
+use conduwuit::{Result, Server, SyncMutex};
 use database::Map;
 use ruma::{
    OwnedDeviceId, OwnedRoomId, OwnedUserId,
@@ -62,11 +62,11 @@ struct SnakeSyncCache {
    extensions: v5::request::Extensions,
 }
 
-type DbConnections = Mutex<BTreeMap<DbConnectionsKey, DbConnectionsVal>>;
+type DbConnections = SyncMutex<BTreeMap<DbConnectionsKey, DbConnectionsVal>>;
 type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String);
-type DbConnectionsVal = Arc<Mutex<SlidingSyncCache>>;
+type DbConnectionsVal = Arc<SyncMutex<SlidingSyncCache>>;
 type SnakeConnectionsKey = (OwnedUserId, OwnedDeviceId, Option<String>);
-type SnakeConnectionsVal = Arc<Mutex<SnakeSyncCache>>;
+type SnakeConnectionsVal = Arc<SyncMutex<SnakeSyncCache>>;
 
 impl crate::Service for Service {
    fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@@ -90,8 +90,8 @@ impl crate::Service for Service {
            state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
            typing: args.depend::<rooms::typing::Service>("rooms::typing"),
        },
-       connections: StdMutex::new(BTreeMap::new()),
-       snake_connections: StdMutex::new(BTreeMap::new()),
+       connections: SyncMutex::new(BTreeMap::new()),
+       snake_connections: SyncMutex::new(BTreeMap::new()),
    }))
 }
 
@@ -100,22 +97,19 @@ impl Service {
    pub fn snake_connection_cached(&self, key: &SnakeConnectionsKey) -> bool {
-       self.snake_connections
-           .lock()
-           .expect("locked")
-           .contains_key(key)
+       self.snake_connections.lock().contains_key(key)
    }
 
    pub fn forget_snake_sync_connection(&self, key: &SnakeConnectionsKey) {
-       self.snake_connections.lock().expect("locked").remove(key);
+       self.snake_connections.lock().remove(key);
    }
 
    pub fn remembered(&self, key: &DbConnectionsKey) -> bool {
-       self.connections.lock().expect("locked").contains_key(key)
+       self.connections.lock().contains_key(key)
    }
 
    pub fn forget_sync_request_connection(&self, key: &DbConnectionsKey) {
-       self.connections.lock().expect("locked").remove(key);
+       self.connections.lock().remove(key);
    }
 
    pub fn update_snake_sync_request_with_cache(
        &self,
        snake_key: &SnakeConnectionsKey,
        request: &mut v5::Request,
    ) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
-       let mut cache = self.snake_connections.lock().expect("locked");
+       let mut cache = self.snake_connections.lock();
        let cached = Arc::clone(
            cache
                .entry(snake_key.clone())
-               .or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))),
+               .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))),
        );
-       let cached = &mut cached.lock().expect("locked");
cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); //v5::Request::try_from_http_request(req, path_args); @@ -232,16 +229,16 @@ impl Service { }; let key = into_db_key(key.0.clone(), key.1.clone(), conn_id); - let mut cache = self.connections.lock().expect("locked"); + let mut cache = self.connections.lock(); let cached = Arc::clone(cache.entry(key).or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { + Arc::new(SyncMutex::new(SlidingSyncCache { lists: BTreeMap::new(), subscriptions: BTreeMap::new(), known_rooms: BTreeMap::new(), extensions: ExtensionsConfig::default(), })) })); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); for (list_id, list) in &mut request.lists { @@ -328,16 +325,16 @@ impl Service { key: &DbConnectionsKey, subscriptions: BTreeMap, ) { - let mut cache = self.connections.lock().expect("locked"); + let mut cache = self.connections.lock(); let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { + Arc::new(SyncMutex::new(SlidingSyncCache { lists: BTreeMap::new(), subscriptions: BTreeMap::new(), known_rooms: BTreeMap::new(), extensions: ExtensionsConfig::default(), })) })); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); cached.subscriptions = subscriptions; @@ -350,16 +347,16 @@ impl Service { new_cached_rooms: BTreeSet, globalsince: u64, ) { - let mut cache = self.connections.lock().expect("locked"); + let mut cache = self.connections.lock(); let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { + Arc::new(SyncMutex::new(SlidingSyncCache { lists: BTreeMap::new(), subscriptions: BTreeMap::new(), known_rooms: BTreeMap::new(), extensions: ExtensionsConfig::default(), })) })); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); for (room_id, lastsince) in cached @@ -386,13 +383,13 @@ impl Service { globalsince: u64, ) { assert!(key.2.is_some(), "Some(conn_id) required for this call"); - let mut cache = self.snake_connections.lock().expect("locked"); + let mut cache = self.snake_connections.lock(); let cached = Arc::clone( cache .entry(key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), + .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))), ); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); for (room_id, lastsince) in cached @@ -416,13 +413,13 @@ impl Service { key: &SnakeConnectionsKey, subscriptions: BTreeMap, ) { - let mut cache = self.snake_connections.lock().expect("locked"); + let mut cache = self.snake_connections.lock(); let cached = Arc::clone( cache .entry(key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(SnakeSyncCache::default()))), + .or_insert_with(|| Arc::new(SyncMutex::new(SnakeSyncCache::default()))), ); - let cached = &mut cached.lock().expect("locked"); + let cached = &mut cached.lock(); drop(cache); cached.subscriptions = subscriptions; diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 7735c87f..acd3dd86 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,10 +1,10 @@ use std::{ collections::{BTreeMap, HashSet}, - sync::{Arc, RwLock}, + sync::Arc, }; use conduwuit::{ - Err, Error, Result, err, error, implement, utils, + Err, Error, Result, SyncRwLock, err, error, implement, utils, utils::{hash, string::EMPTY}, }; use 
@@ -19,7 +19,7 @@ use ruma::{
 
 use crate::{Dep, config, globals, users};
 
 pub struct Service {
-   userdevicesessionid_uiaarequest: RwLock<RequestMap>,
+   userdevicesessionid_uiaarequest: SyncRwLock<RequestMap>,
    db: Data,
    services: Services,
 }
@@ -42,7 +42,7 @@ pub const SESSION_ID_LENGTH: usize = 32;
 
 impl crate::Service for Service {
    fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
        Ok(Arc::new(Self {
-           userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
+           userdevicesessionid_uiaarequest: SyncRwLock::new(RequestMap::new()),
            db: Data {
                userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
            },
@@ -268,7 +268,7 @@ fn set_uiaa_request(
    let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
    self.userdevicesessionid_uiaarequest
        .write()
-       .expect("locked for writing")
        .insert(key, request.to_owned());
 }
 
@@ -287,7 +286,6 @@ pub fn get_uiaa_request(
    self.userdevicesessionid_uiaarequest
        .read()
-       .expect("locked for reading")
        .get(&key)
        .cloned()
 }
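
The patch above makes `with_lock` generic over the closure's return value and adds blanket impls over `lock_api::Mutex`, which cover `parking_lot::Mutex` (re-exported as `SyncMutex`), since that type is just `lock_api::Mutex<parking_lot::RawMutex, T>`. A minimal standalone sketch of the caller-facing pattern, assuming only `parking_lot` 0.12 and `lock_api` 0.4 as dependencies; the trait body is condensed from the patch and `counter` is an illustrative value, not code from this repository:

```rust
use std::sync::Arc;

use parking_lot::Mutex as SyncMutex; // same alias the patch re-exports from core

pub trait WithLock<T> {
    /// Acquires a lock and executes the given closure with the locked data,
    /// returning the result.
    fn with_lock<Ret, F>(&self, f: F) -> Ret
    where
        F: FnMut(&mut T) -> Ret;
}

// One impl over lock_api::Mutex covers every raw-mutex backend, including
// parking_lot's; there is no PoisonError, so no .expect("locked") is needed.
impl<R: lock_api::RawMutex, T> WithLock<T> for lock_api::Mutex<R, T> {
    fn with_lock<Ret, F>(&self, mut f: F) -> Ret
    where
        F: FnMut(&mut T) -> Ret,
    {
        let mut guard = self.lock();
        f(&mut guard)
        // Guard drops here, releasing the lock before the caller continues.
    }
}

fn main() {
    let counter = Arc::new(SyncMutex::new(0_u64));
    // Read-modify-return in one tightly scoped critical section:
    let next = counter.with_lock(|c| {
        *c += 1;
        *c
    });
    assert_eq!(next, 1);
    assert!(!counter.is_locked()); // the guard was released at scope end
}
```

Returning the closure's value is what lets call sites like `self.complete.read().map(|complete| complete(command))` collapse from guard-and-expect chains into single expressions.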
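The workspace also enables parking_lot's `deadlock_detection` feature, which is why `backtrace`, `petgraph`, and `thread-id` appear in Cargo.lock. That feature only records lock ownership; surfacing a deadlock still requires something to poll the checker. A minimal sketch of such a watchdog thread, assuming the feature is enabled as above; the 10-second interval and `eprintln!` reporting are illustrative choices, not part of this patch:

```rust
use std::{thread, time::Duration};

/// Spawn a background thread that periodically asks parking_lot whether any
/// lock-cycle exists among the instrumented mutexes/rwlocks.
fn spawn_deadlock_watchdog() {
    thread::spawn(|| loop {
        thread::sleep(Duration::from_secs(10));
        // Only available with the `deadlock_detection` feature.
        let deadlocks = parking_lot::deadlock::check_deadlock();
        if deadlocks.is_empty() {
            continue;
        }
        eprintln!("{} deadlock(s) detected", deadlocks.len());
        for (i, threads) in deadlocks.iter().enumerate() {
            for t in threads {
                eprintln!("deadlock #{i}: thread {:?}", t.thread_id());
                eprintln!("{:?}", t.backtrace());
            }
        }
    });
}

fn main() {
    spawn_deadlock_watchdog();
    // ... application runs here; any detected cycle is reported every 10s.
}
```

This is also the likely subject of the Cargo.toml TODO: the cycle-detection bookkeeping runs on every lock acquisition, so if it measurably hurts throughput it could be gated behind `debug_assertions` instead of being enabled unconditionally.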