feat: added ACL commands

This commit is contained in:
NinekoTheCat 2023-12-24 17:46:24 +01:00
parent 13e497936f
commit 90232b894d
No known key found for this signature in database
GPG key ID: 700DB3F678A4AB66
3 changed files with 71 additions and 3 deletions

View file

@ -1,5 +1,6 @@
use std::collections::HashSet; use std::collections::HashSet;
use clap::ValueEnum;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use url::Host; use url::Host;
@ -17,11 +18,20 @@ pub trait Data: Send + Sync {
fn get_all_acls(&self) -> HashSet<AclDatabaseEntry>; fn get_all_acls(&self) -> HashSet<AclDatabaseEntry>;
} }
#[derive(Serialize,Deserialize, Debug, Clone, Copy, Hash, Eq, PartialEq)] #[derive(Serialize,Deserialize, Debug, Clone, Copy, Hash, Eq, PartialEq, ValueEnum)]
pub enum AclMode{ pub enum AclMode{
Block, Block,
Allow Allow
} }
impl AclMode {
pub fn to_emoji(&self)-> char {
match self {
AclMode::Block => '❎',
AclMode::Allow => '✅',
}
}
}
#[derive(Serialize,Deserialize, Debug, Clone, Hash, Eq,PartialEq)] #[derive(Serialize,Deserialize, Debug, Clone, Hash, Eq,PartialEq)]
pub struct AclDatabaseEntry { pub struct AclDatabaseEntry {

View file

@ -1,6 +1,6 @@
use std::{sync::Arc, collections::HashSet}; use std::sync::Arc;
use ruma::ServerName; use ruma::ServerName;
use tracing::{warn, debug, error}; use tracing::{warn, debug, error};

View file

@ -30,6 +30,8 @@ use ruma::{
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use tracing::{warn, error};
use url::Host;
use crate::{ use crate::{
api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
@ -38,7 +40,7 @@ use crate::{
Error, PduEvent, Result, Error, PduEvent, Result,
}; };
use super::pdu::PduBuilder; use super::{pdu::PduBuilder, acl::AclMode};
const PAGE_SIZE: usize = 100; const PAGE_SIZE: usize = 100;
@ -72,6 +74,23 @@ enum AdminCommand {
// this is more like a "miscellaneous" category than a debug one // this is more like a "miscellaneous" category than a debug one
/// Commands for debugging things /// Commands for debugging things
Debug(DebugCommand), Debug(DebugCommand),
/// commands for manging ACL
#[command(subcommand)]
Acl(AclCommand),
}
#[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)]
enum AclCommand {
Add {
mode: AclMode,
hostname: String
},
Remove {
hostname: String
},
List{
filter: Option<AclMode>
}
} }
#[cfg_attr(test, derive(Debug))] #[cfg_attr(test, derive(Debug))]
@ -1257,6 +1276,45 @@ impl Service {
) )
} }
}, },
AdminCommand::Acl(AclCommand::Add { mode, hostname }) => {
let host = match Host::parse(&hostname) {
Ok(host) => host,
Err(error) => return Ok(RoomMessageEventContent::text_plain(format!("failed to parse hostname with error {}",error))),
};
if let Err(error) = services().acl.add_acl(host.clone(), mode) {
error!("encountered {} while trying to add acl with host {} and mode {:?}",error,host,mode);
RoomMessageEventContent::text_plain("error, couldn't add acl")
} else {
RoomMessageEventContent::text_plain("successfully added ACL")
}
},
AdminCommand::Acl(AclCommand::Remove { hostname }) => {
let host = match Host::parse(&hostname) {
Ok(host) => host,
Err(error) => return Ok(RoomMessageEventContent::text_plain(format!("failed to parse hostname with error {}",error))),
};
if let Err(error) = services().acl.remove_acl(host.clone()) {
error!("encountered {} while trying to remove acl with host {}",error,host);
RoomMessageEventContent::text_plain("error, couldn't remove acl")
} else {
RoomMessageEventContent::text_plain("successfully removed ACL")
}
},
AdminCommand::Acl(AclCommand::List { filter}) => {
let results = services().acl.list_acls(filter);
let mut results_html = String::new();
results.iter().for_each(|it| {
results_html.push_str(&format!("* {} | {}\n",it.hostname,it.mode.to_emoji()));
});
RoomMessageEventContent::text_plain(format!("
List of services: \n
= blocked\n
= allowed\n
{}
",results_html))
},
}; };
Ok(reply_message_content) Ok(reply_message_content)