feat: added functions to remove, add and lsit acls (with filtering)

This commit is contained in:
NinekoTheCat 2023-12-24 16:39:25 +01:00
parent f1d725c842
commit 13e497936f
No known key found for this signature in database
GPG key ID: 700DB3F678A4AB66
3 changed files with 52 additions and 5 deletions

View file

@ -1,9 +1,12 @@
use std::collections::HashSet;
use tracing::warn; use tracing::warn;
use url::Host;
use crate::{service::acl::{Data, AclDatabaseEntry, AclMode}, KeyValueDatabase}; use crate::{service::acl::{Data, AclDatabaseEntry, AclMode}, KeyValueDatabase};
impl Data for KeyValueDatabase { impl Data for KeyValueDatabase {
fn check_acl(&self,host: &url::Host<String> ) -> crate::Result<Option<AclMode>> { fn check_acl(&self,host: &Host<String> ) -> crate::Result<Option<AclMode>> {
let thing = self.acl_list.get(host.to_string().as_bytes())?; let thing = self.acl_list.get(host.to_string().as_bytes())?;
if let Some(thing) = thing { if let Some(thing) = thing {
match thing.first() { match thing.first() {
@ -27,7 +30,32 @@ impl Data for KeyValueDatabase {
}) })
} }
fn remove_acl(&self,host: url::Host<String>) -> crate::Result<()> { fn remove_acl(&self,host: Host<String>) -> crate::Result<()> {
self.acl_list.remove(host.to_string().as_bytes()) self.acl_list.remove(host.to_string().as_bytes())
} }
fn get_all_acls(&self) -> HashSet<AclDatabaseEntry> {
let mut set = HashSet::new();
self.acl_list.iter().for_each(|it| {
let Ok(key) = String::from_utf8(it.0) else {
return;
};
let Ok(parsed_host) = Host::parse(&key) else {
warn!("failed to parse host {}", key);
return;
};
let mode = match it.1.first() {
Some(0x1) => AclMode::Allow,
Some(0x0) => AclMode::Block,
Some(invalid) => {
warn!("found invalid value for mode byte in value {}, probably db corruption", invalid);
return;
}
None => return,
};
set.insert(AclDatabaseEntry { mode: mode, hostname: parsed_host });
});
set
}
} }

View file

@ -1,3 +1,5 @@
use std::collections::HashSet;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use url::Host; use url::Host;
@ -10,14 +12,17 @@ pub trait Data: Send + Sync {
fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()>; fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()>;
/// remove a given Acl entry from the database /// remove a given Acl entry from the database
fn remove_acl(&self,host: Host<String>) -> crate::Result<()>; fn remove_acl(&self,host: Host<String>) -> crate::Result<()>;
/// list all acls
fn get_all_acls(&self) -> HashSet<AclDatabaseEntry>;
} }
#[derive(Serialize,Deserialize, Debug, Clone, Copy)] #[derive(Serialize,Deserialize, Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum AclMode{ pub enum AclMode{
Block, Block,
Allow Allow
} }
#[derive(Serialize,Deserialize, Debug, Clone)] #[derive(Serialize,Deserialize, Debug, Clone, Hash, Eq,PartialEq)]
pub struct AclDatabaseEntry { pub struct AclDatabaseEntry {
pub(crate) mode: AclMode, pub(crate) mode: AclMode,

View file

@ -1,6 +1,6 @@
use std::sync::Arc; use std::{sync::Arc, collections::HashSet};
use ruma::ServerName; use ruma::ServerName;
use tracing::{warn, debug, error}; use tracing::{warn, debug, error};
@ -16,6 +16,20 @@ pub struct Service {
} }
impl Service { impl Service {
pub fn list_acls(&self, filter: Option<AclMode>) -> Vec<AclDatabaseEntry> {
let set = self.db.get_all_acls();
match filter {
Some(filter) => set.into_iter().filter(|it| it.mode == filter).collect(),
None => set.into_iter().collect(),
}
}
pub fn remove_acl(&self, host: Host) -> crate::Result<()> {
self.db.remove_acl(host)
}
pub fn add_acl(&self, host: Host, mode: AclMode) -> crate::Result<()> {
self.db.add_acl(AclDatabaseEntry { mode: mode, hostname: host })
}
/// same as federation_with_allowed however it can work with the fedi_dest type /// same as federation_with_allowed however it can work with the fedi_dest type
pub fn is_federation_with_allowed_fedi_dest(&self,fedi_dest: &FedDest) -> bool { pub fn is_federation_with_allowed_fedi_dest(&self,fedi_dest: &FedDest) -> bool {
let hostname = if let Ok(name) = Host::parse(&fedi_dest.hostname()) { let hostname = if let Ok(name) = Host::parse(&fedi_dest.hostname()) {