diff --git a/src/database/key_value/acl.rs b/src/database/key_value/acl.rs index 6316522c..17c560f3 100644 --- a/src/database/key_value/acl.rs +++ b/src/database/key_value/acl.rs @@ -1,9 +1,12 @@ +use std::collections::HashSet; + use tracing::warn; +use url::Host; use crate::{service::acl::{Data, AclDatabaseEntry, AclMode}, KeyValueDatabase}; impl Data for KeyValueDatabase { - fn check_acl(&self,host: &url::Host ) -> crate::Result> { + fn check_acl(&self,host: &Host ) -> crate::Result> { let thing = self.acl_list.get(host.to_string().as_bytes())?; if let Some(thing) = thing { match thing.first() { @@ -27,7 +30,32 @@ impl Data for KeyValueDatabase { }) } - fn remove_acl(&self,host: url::Host) -> crate::Result<()> { + fn remove_acl(&self,host: Host) -> crate::Result<()> { self.acl_list.remove(host.to_string().as_bytes()) } + + fn get_all_acls(&self) -> HashSet { + let mut set = HashSet::new(); + + self.acl_list.iter().for_each(|it| { + let Ok(key) = String::from_utf8(it.0) else { + return; + }; + let Ok(parsed_host) = Host::parse(&key) else { + warn!("failed to parse host {}", key); + return; + }; + let mode = match it.1.first() { + Some(0x1) => AclMode::Allow, + Some(0x0) => AclMode::Block, + Some(invalid) => { + warn!("found invalid value for mode byte in value {}, probably db corruption", invalid); + return; + } + None => return, + }; + set.insert(AclDatabaseEntry { mode: mode, hostname: parsed_host }); + }); + set + } } \ No newline at end of file diff --git a/src/service/acl/data.rs b/src/service/acl/data.rs index c14575b7..30fd5442 100644 --- a/src/service/acl/data.rs +++ b/src/service/acl/data.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use serde::{Serialize, Deserialize}; use url::Host; @@ -10,14 +12,17 @@ pub trait Data: Send + Sync { fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()>; /// remove a given Acl entry from the database fn remove_acl(&self,host: Host) -> crate::Result<()>; + + /// list all acls + fn get_all_acls(&self) -> HashSet; } -#[derive(Serialize,Deserialize, Debug, Clone, Copy)] +#[derive(Serialize,Deserialize, Debug, Clone, Copy, Hash, Eq, PartialEq)] pub enum AclMode{ Block, Allow } -#[derive(Serialize,Deserialize, Debug, Clone)] +#[derive(Serialize,Deserialize, Debug, Clone, Hash, Eq,PartialEq)] pub struct AclDatabaseEntry { pub(crate) mode: AclMode, diff --git a/src/service/acl/mod.rs b/src/service/acl/mod.rs index b9352e2e..4178181c 100644 --- a/src/service/acl/mod.rs +++ b/src/service/acl/mod.rs @@ -1,6 +1,6 @@ -use std::sync::Arc; +use std::{sync::Arc, collections::HashSet}; use ruma::ServerName; use tracing::{warn, debug, error}; @@ -16,6 +16,20 @@ pub struct Service { } impl Service { + pub fn list_acls(&self, filter: Option) -> Vec { + let set = self.db.get_all_acls(); + match filter { + Some(filter) => set.into_iter().filter(|it| it.mode == filter).collect(), + None => set.into_iter().collect(), + } + } + pub fn remove_acl(&self, host: Host) -> crate::Result<()> { + self.db.remove_acl(host) + } + + pub fn add_acl(&self, host: Host, mode: AclMode) -> crate::Result<()> { + self.db.add_acl(AclDatabaseEntry { mode: mode, hostname: host }) + } /// same as federation_with_allowed however it can work with the fedi_dest type pub fn is_federation_with_allowed_fedi_dest(&self,fedi_dest: &FedDest) -> bool { let hostname = if let Ok(name) = Host::parse(&fedi_dest.hostname()) {