use std::{ cell::RefCell, future::Future, path::{Path, PathBuf}, pin::Pin, sync::Arc, }; use conduit::{Config, Result}; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use thread_local::ThreadLocal; use tracing::debug; use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; thread_local! { static READ_CONNECTION: RefCell> = const { RefCell::new(None) }; static READ_CONNECTION_ITERATOR: RefCell> = const { RefCell::new(None) }; } struct PreparedStatementIterator<'a> { iterator: Box + 'a>, _statement_ref: AliasableBox>, } impl Iterator for PreparedStatementIterator<'_> { type Item = TupleOfBytes; fn next(&mut self) -> Option { self.iterator.next() } } struct AliasableBox(*mut T); impl Drop for AliasableBox { fn drop(&mut self) { // SAFETY: This is cursed and relies on non-local reasoning. // // In order for this to be safe: // // * All aliased references to this value must have been dropped first, for // example by coming after its referrers in struct fields, because struct // fields are automatically dropped in order from top to bottom in the absence // of an explicit Drop impl. Otherwise, the referrers may read into // deallocated memory. // * This type must not be copyable or cloneable. Otherwise, double-free can // occur. // // These conditions are met, but again, note that changing safe code in // this module can result in unsoundness if any of these constraints are // violated. unsafe { drop(Box::from_raw(self.0)) } } } pub(crate) struct Engine { writer: Mutex, read_conn_tls: ThreadLocal, read_iterator_conn_tls: ThreadLocal, path: PathBuf, cache_size_per_thread: u32, } impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { let conn = Connection::open(path)?; conn.pragma_update(Some(Main), "page_size", 2048)?; conn.pragma_update(Some(Main), "journal_mode", "WAL")?; conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?; conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; Ok(conn) } fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } fn read_lock(&self) -> &Connection { self.read_conn_tls .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) } fn read_lock_iterator(&self) -> &Connection { self.read_iterator_conn_tls .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) } fn flush_wal(self: &Arc) -> Result<()> { self.write_lock() .pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; Ok(()) } } impl KeyValueDatabaseEngine for Arc { fn open(config: &Config) -> Result { let path = Path::new(&config.database_path).join("conduit.db"); // calculates cache-size per permanent connection // 1. convert MB to KiB // 2. divide by permanent connections + permanent iter connections + write // connection // 3. round down to nearest integer #[allow( clippy::as_conversions, clippy::cast_possible_truncation, clippy::cast_precision_loss, clippy::cast_sign_loss )] let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0) / (conduit::utils::available_parallelism() as f64).mul_add(2.0, 1.0)) as u32; let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); let arc = Self::new(Engine { writer, read_conn_tls: ThreadLocal::new(), read_iterator_conn_tls: ThreadLocal::new(), path, cache_size_per_thread, }); Ok(arc) } fn open_tree(&self, name: &str) -> Result> { self.write_lock().execute( &format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [], )?; Ok(Arc::new(SqliteTable { engine: Self::clone(self), name: name.to_owned(), watchers: Watchers::default(), })) } fn flush(&self) -> Result<()> { // we enabled PRAGMA synchronous=normal, so this should not be necessary Ok(()) } fn cleanup(&self) -> Result<()> { self.flush_wal() } } struct SqliteTable { engine: Arc, name: String, watchers: Watchers, } type TupleOfBytes = (Vec, Vec); impl SqliteTable { fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { Ok(guard .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? .query_row([key], |row| row.get(0)) .optional()?) } fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { guard.execute( format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(), [key, value], )?; Ok(()) } fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box + 'a> { let statement = Box::leak(Box::new( guard .prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)) .unwrap(), )); let statement_ref = AliasableBox(statement); //let name = self.name.clone(); let iterator = Box::new( statement .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(Result::unwrap), ); Box::new(PreparedStatementIterator { iterator, _statement_ref: statement_ref, }) } } impl KvTree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { self.get_with_guard(self.engine.read_lock(), key) } fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); self.insert_with_guard(&guard, key, value)?; drop(guard); self.watchers.wake(key); Ok(()) } fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; for (key, value) in iter { self.insert_with_guard(&guard, &key, &value)?; } guard.execute("COMMIT", [])?; drop(guard); Ok(()) } fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; for key in iter { let old = self.get_with_guard(&guard, &key)?; let new = conduit::utils::increment(old.as_deref()); self.insert_with_guard(&guard, &key, &new)?; } guard.execute("COMMIT", [])?; drop(guard); Ok(()) } fn remove(&self, key: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?; Ok(()) } fn iter<'a>(&'a self) -> Box + 'a> { let guard = self.engine.read_lock_iterator(); self.iter_with_guard(guard) } fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box + 'a> { let guard = self.engine.read_lock_iterator(); let from = from.to_vec(); // TODO change interface? //let name = self.name.clone(); if backwards { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", &self.name )) .unwrap(), )); let statement_ref = AliasableBox(statement); let iterator = Box::new( statement .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(Result::unwrap), ); Box::new(PreparedStatementIterator { iterator, _statement_ref: statement_ref, }) } else { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", &self.name )) .unwrap(), )); let statement_ref = AliasableBox(statement); let iterator = Box::new( statement .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(Result::unwrap), ); Box::new(PreparedStatementIterator { iterator, _statement_ref: statement_ref, }) } } fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); let old = self.get_with_guard(&guard, key)?; let new = conduit::utils::increment(old.as_deref()); self.insert_with_guard(&guard, key, &new)?; Ok(new) } fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), ) } fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { self.watchers.watch(prefix) } fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine .write_lock() .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; debug!("clear: ran"); Ok(()) } }