fmt: ran cargo format

This commit is contained in:
NinekoTheCat 2023-12-24 19:25:51 +01:00
parent cadc36700f
commit 0add1c7c52
No known key found for this signature in database
GPG key ID: 700DB3F678A4AB66
12 changed files with 189 additions and 140 deletions

View file

@ -156,7 +156,10 @@ where
debug!("Checking acl allowance for {}", destination); debug!("Checking acl allowance for {}", destination);
if !services().acl.is_federation_with_allowed_fedi_dest(&actual_destination) { if !services()
.acl
.is_federation_with_allowed_fedi_dest(&actual_destination)
{
debug!("blocked sending federation to {:?}", actual_destination); debug!("blocked sending federation to {:?}", actual_destination);
return Err(Error::ACLBlock(destination.to_owned())); return Err(Error::ACLBlock(destination.to_owned()));

View file

@ -1,11 +1,11 @@
use std::collections::HashSet;
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashSet;
use url::Host; use url::Host;
#[derive(Deserialize,Debug, Default, Clone)] #[derive(Deserialize, Debug, Default, Clone)]
pub struct AccessControlListConfig { pub struct AccessControlListConfig {
/// setting this explicitly enables allowlists /// setting this explicitly enables allowlists
pub(crate)allow_list: Option<HashSet<Host<String>>>, pub(crate) allow_list: Option<HashSet<Host<String>>>,
#[serde(default)] #[serde(default)]
pub(crate)block_list: HashSet<Host<String>> pub(crate) block_list: HashSet<Host<String>>,
} }

View file

@ -2,7 +2,8 @@ use std::{
collections::BTreeMap, collections::BTreeMap,
fmt, fmt,
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
path::PathBuf, sync::Arc, path::PathBuf,
sync::Arc,
}; };
use figment::Figment; use figment::Figment;
@ -13,7 +14,7 @@ use tracing::{error, warn};
pub(crate) mod acl; pub(crate) mod acl;
mod proxy; mod proxy;
use self::{proxy::ProxyConfig, acl::AccessControlListConfig}; use self::{acl::AccessControlListConfig, proxy::ProxyConfig};
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Config { pub struct Config {

View file

@ -3,34 +3,43 @@ use std::collections::HashSet;
use tracing::warn; use tracing::warn;
use url::Host; use url::Host;
use crate::{service::acl::{Data, AclDatabaseEntry, AclMode}, KeyValueDatabase}; use crate::{
service::acl::{AclDatabaseEntry, AclMode, Data},
KeyValueDatabase,
};
impl Data for KeyValueDatabase { impl Data for KeyValueDatabase {
fn check_acl(&self,host: &Host<String> ) -> crate::Result<Option<AclMode>> { fn check_acl(&self, host: &Host<String>) -> crate::Result<Option<AclMode>> {
let thing = self.acl_list.get(host.to_string().as_bytes())?; let thing = self.acl_list.get(host.to_string().as_bytes())?;
if let Some(thing) = thing { if let Some(thing) = thing {
match thing.first() { match thing.first() {
Some(0x1) => Ok(Some(AclMode::Allow)), Some(0x1) => Ok(Some(AclMode::Allow)),
Some(0x0) => Ok(Some(AclMode::Block)), Some(0x0) => Ok(Some(AclMode::Block)),
Some(invalid) => { Some(invalid) => {
warn!("found invalid value for mode byte in value {}, probably db corruption", invalid); warn!(
"found invalid value for mode byte in value {}, probably db corruption",
invalid
);
Ok(None) Ok(None)
} }
None => Ok(None), None => Ok(None),
} }
}else { } else {
Ok(None) Ok(None)
} }
} }
fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()> { fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()> {
self.acl_list.insert(acl.hostname.to_string().as_bytes(), match acl.mode { self.acl_list.insert(
AclMode::Block => &[0x0], acl.hostname.to_string().as_bytes(),
AclMode::Allow => &[0x1], match acl.mode {
}) AclMode::Block => &[0x0],
AclMode::Allow => &[0x1],
},
)
} }
fn remove_acl(&self,host: Host<String>) -> crate::Result<()> { fn remove_acl(&self, host: Host<String>) -> crate::Result<()> {
self.acl_list.remove(host.to_string().as_bytes()) self.acl_list.remove(host.to_string().as_bytes())
} }
@ -49,12 +58,18 @@ impl Data for KeyValueDatabase {
Some(0x1) => AclMode::Allow, Some(0x1) => AclMode::Allow,
Some(0x0) => AclMode::Block, Some(0x0) => AclMode::Block,
Some(invalid) => { Some(invalid) => {
warn!("found invalid value for mode byte in value {}, probably db corruption", invalid); warn!(
"found invalid value for mode byte in value {}, probably db corruption",
invalid
);
return; return;
} }
None => return, None => return,
}; };
set.insert(AclDatabaseEntry { mode: mode, hostname: parsed_host }); set.insert(AclDatabaseEntry {
mode: mode,
hostname: parsed_host,
});
}); });
set set
} }

View file

@ -173,7 +173,7 @@ pub struct KeyValueDatabase {
pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>, pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
pub(super) presence_timer_sender: Arc<mpsc::UnboundedSender<(OwnedUserId, Duration)>>, pub(super) presence_timer_sender: Arc<mpsc::UnboundedSender<(OwnedUserId, Duration)>>,
pub(super) acl_list: Arc<dyn KvTree> pub(super) acl_list: Arc<dyn KvTree>,
} }
impl KeyValueDatabase { impl KeyValueDatabase {

View file

@ -1,40 +1,39 @@
use std::collections::HashSet; use std::collections::HashSet;
use clap::ValueEnum; use clap::ValueEnum;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use url::Host; use url::Host;
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
/// check if given host exists in Acls, if so return it /// check if given host exists in Acls, if so return it
fn check_acl(&self,host: &Host<String> ) -> crate::Result<Option<AclMode>>; fn check_acl(&self, host: &Host<String>) -> crate::Result<Option<AclMode>>;
/// add a given Acl entry to the database /// add a given Acl entry to the database
fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()>; fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()>;
/// remove a given Acl entry from the database /// remove a given Acl entry from the database
fn remove_acl(&self,host: Host<String>) -> crate::Result<()>; fn remove_acl(&self, host: Host<String>) -> crate::Result<()>;
/// list all acls /// list all acls
fn get_all_acls(&self) -> HashSet<AclDatabaseEntry>; fn get_all_acls(&self) -> HashSet<AclDatabaseEntry>;
} }
#[derive(Serialize,Deserialize, Debug, Clone, Copy, Hash, Eq, PartialEq, ValueEnum)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, Eq, PartialEq, ValueEnum)]
pub enum AclMode{ pub enum AclMode {
Block, Block,
Allow Allow,
} }
impl AclMode { impl AclMode {
pub fn to_emoji(&self)-> char { pub fn to_emoji(&self) -> char {
match self { match self {
AclMode::Block => '❎', AclMode::Block => '❎',
AclMode::Allow => '✅', AclMode::Allow => '✅',
} }
} }
} }
#[derive(Serialize,Deserialize, Debug, Clone, Hash, Eq,PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, Hash, Eq, PartialEq)]
pub struct AclDatabaseEntry { pub struct AclDatabaseEntry {
pub(crate) mode: AclMode, pub(crate) mode: AclMode,
pub(crate) hostname: Host pub(crate) hostname: Host,
} }

View file

@ -1,93 +1,107 @@
use std::sync::Arc; use std::sync::Arc;
use ruma::ServerName; use ruma::ServerName;
use tracing::{warn, debug, error}; use tracing::{debug, error, warn};
use url::Host; use url::Host;
use crate::{config::acl::AccessControlListConfig, api::server_server::FedDest}; use crate::{api::server_server::FedDest, config::acl::AccessControlListConfig};
pub use self::data::*; pub use self::data::*;
mod data; mod data;
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub acl_config: Arc<AccessControlListConfig> pub acl_config: Arc<AccessControlListConfig>,
} }
impl Service { impl Service {
pub fn list_acls(&self, filter: Option<AclMode>) -> Vec<AclDatabaseEntry> { pub fn list_acls(&self, filter: Option<AclMode>) -> Vec<AclDatabaseEntry> {
let mut set = self.db.get_all_acls(); let mut set = self.db.get_all_acls();
self.acl_config.allow_list.clone().unwrap_or_default().iter().for_each(|it| { self.acl_config
set.insert(AclDatabaseEntry { mode: AclMode::Allow, hostname: it.to_owned() }); .allow_list
.clone()
.unwrap_or_default()
.iter()
.for_each(|it| {
set.insert(AclDatabaseEntry {
mode: AclMode::Allow,
hostname: it.to_owned(),
});
}); });
self.acl_config.block_list.clone().iter().for_each(|it| { self.acl_config.block_list.clone().iter().for_each(|it| {
set.insert(AclDatabaseEntry { mode: AclMode::Block, hostname: it.to_owned() }); set.insert(AclDatabaseEntry {
mode: AclMode::Block,
hostname: it.to_owned(),
}); });
match filter { });
Some(filter) => set.into_iter().filter(|it| it.mode == filter).collect(), match filter {
None => set.into_iter().collect(), Some(filter) => set.into_iter().filter(|it| it.mode == filter).collect(),
} None => set.into_iter().collect(),
} }
pub fn remove_acl(&self, host: Host) -> crate::Result<()> { }
self.db.remove_acl(host) pub fn remove_acl(&self, host: Host) -> crate::Result<()> {
self.db.remove_acl(host)
}
pub fn add_acl(&self, host: Host, mode: AclMode) -> crate::Result<()> {
self.db.add_acl(AclDatabaseEntry {
mode: mode,
hostname: host,
})
}
/// same as federation_with_allowed however it can work with the fedi_dest type
pub fn is_federation_with_allowed_fedi_dest(&self, fedi_dest: &FedDest) -> bool {
let hostname = if let Ok(name) = Host::parse(&fedi_dest.hostname()) {
name
} else {
warn!(
"cannot deserialise hostname for server with name {:?}",
fedi_dest
);
return false;
};
return self.is_federation_with_allowed(hostname);
}
/// same as federation_with_allowed however it can work with the fedi_dest type
pub fn is_federation_with_allowed_server_name(&self, srv: &ServerName) -> bool {
let hostname = if let Ok(name) = Host::parse(srv.host()) {
name
} else {
warn!("cannot deserialise hostname for server with name {:?}", srv);
return false;
};
return self.is_federation_with_allowed(hostname);
}
/// is federation allowed with this particular server?
pub fn is_federation_with_allowed(&self, server_host_name: Host<String>) -> bool {
debug!("checking federation allowance for {}", server_host_name);
// check blocklist first
if self.acl_config.block_list.contains(&server_host_name) {
return false;
}
let mut allow_list_enabled = false;
// check allowlist
if let Some(list) = &self.acl_config.allow_list {
if list.contains(&server_host_name) {
return true;
}
allow_list_enabled = true;
} }
pub fn add_acl(&self, host: Host, mode: AclMode) -> crate::Result<()> { //check database
self.db.add_acl(AclDatabaseEntry { mode: mode, hostname: host }) match self.db.check_acl(&server_host_name) {
} Err(error) => {
/// same as federation_with_allowed however it can work with the fedi_dest type error!("database failed with {}", error);
pub fn is_federation_with_allowed_fedi_dest(&self,fedi_dest: &FedDest) -> bool { false
let hostname = if let Ok(name) = Host::parse(&fedi_dest.hostname()) {
name
} else {
warn!("cannot deserialise hostname for server with name {:?}",fedi_dest);
return false;
};
return self.is_federation_with_allowed(hostname);
}
/// same as federation_with_allowed however it can work with the fedi_dest type
pub fn is_federation_with_allowed_server_name(&self,srv: &ServerName) -> bool {
let hostname = if let Ok(name) = Host::parse(srv.host()) {
name
} else {
warn!("cannot deserialise hostname for server with name {:?}",srv);
return false;
};
return self.is_federation_with_allowed(hostname);
}
/// is federation allowed with this particular server?
pub fn is_federation_with_allowed(&self,server_host_name: Host<String>) -> bool {
debug!("checking federation allowance for {}", server_host_name);
// check blocklist first
if self.acl_config.block_list.contains(&server_host_name) {
return false;
} }
let mut allow_list_enabled = false; Ok(None) if allow_list_enabled => false,
// check allowlist Ok(None) => true,
if let Some(list) = &self.acl_config.allow_list { Ok(Some(data::AclMode::Block)) => false,
if list.contains(&server_host_name) { Ok(Some(data::AclMode::Allow)) if allow_list_enabled => true,
return true; Ok(Some(data::AclMode::Allow)) => {
} warn!("allowlist value found in database for {} but allow list is not enabled, denied request", server_host_name);
allow_list_enabled = true; false
}
//check database
match self.db.check_acl(&server_host_name) {
Err(error) => {
error!("database failed with {}",error);
false
}
Ok(None) if allow_list_enabled => false,
Ok(None) => true,
Ok(Some(data::AclMode::Block)) => false,
Ok(Some(data::AclMode::Allow)) if allow_list_enabled => true,
Ok(Some(data::AclMode::Allow)) => {
warn!("allowlist value found in database for {} but allow list is not enabled, denied request", server_host_name);
false
}
} }
} }
}
} }

View file

@ -40,7 +40,7 @@ use crate::{
Error, PduEvent, Result, Error, PduEvent, Result,
}; };
use super::{pdu::PduBuilder, acl::AclMode}; use super::{acl::AclMode, pdu::PduBuilder};
const PAGE_SIZE: usize = 100; const PAGE_SIZE: usize = 100;
@ -81,16 +81,9 @@ enum AdminCommand {
#[cfg_attr(test, derive(Debug))] #[cfg_attr(test, derive(Debug))]
#[derive(Subcommand)] #[derive(Subcommand)]
enum AclCommand { enum AclCommand {
Add { Add { mode: AclMode, hostname: String },
mode: AclMode, Remove { hostname: String },
hostname: String List { filter: Option<AclMode> },
},
Remove {
hostname: String
},
List{
filter: Option<AclMode>
}
} }
#[cfg_attr(test, derive(Debug))] #[cfg_attr(test, derive(Debug))]
@ -1279,42 +1272,60 @@ impl Service {
AdminCommand::Acl(AclCommand::Add { mode, hostname }) => { AdminCommand::Acl(AclCommand::Add { mode, hostname }) => {
let host = match Host::parse(&hostname) { let host = match Host::parse(&hostname) {
Ok(host) => host, Ok(host) => host,
Err(error) => return Ok(RoomMessageEventContent::text_plain(format!("failed to parse hostname with error {}",error))), Err(error) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"failed to parse hostname with error {}",
error
)))
}
}; };
if let Err(error) = services().acl.add_acl(host.clone(), mode) { if let Err(error) = services().acl.add_acl(host.clone(), mode) {
error!("encountered {} while trying to add acl with host {} and mode {:?}",error,host,mode); error!(
"encountered {} while trying to add acl with host {} and mode {:?}",
error, host, mode
);
RoomMessageEventContent::text_plain("error, couldn't add acl") RoomMessageEventContent::text_plain("error, couldn't add acl")
} else { } else {
RoomMessageEventContent::text_plain("successfully added ACL") RoomMessageEventContent::text_plain("successfully added ACL")
} }
}
}, AdminCommand::Acl(AclCommand::Remove { hostname }) => {
AdminCommand::Acl(AclCommand::Remove { hostname }) => {
let host = match Host::parse(&hostname) { let host = match Host::parse(&hostname) {
Ok(host) => host, Ok(host) => host,
Err(error) => return Ok(RoomMessageEventContent::text_plain(format!("failed to parse hostname with error {}",error))), Err(error) => {
return Ok(RoomMessageEventContent::text_plain(format!(
"failed to parse hostname with error {}",
error
)))
}
}; };
if let Err(error) = services().acl.remove_acl(host.clone()) { if let Err(error) = services().acl.remove_acl(host.clone()) {
error!("encountered {} while trying to remove acl with host {}",error,host); error!(
"encountered {} while trying to remove acl with host {}",
error, host
);
RoomMessageEventContent::text_plain("error, couldn't remove acl") RoomMessageEventContent::text_plain("error, couldn't remove acl")
} else { } else {
RoomMessageEventContent::text_plain("successfully removed ACL") RoomMessageEventContent::text_plain("successfully removed ACL")
} }
}, }
AdminCommand::Acl(AclCommand::List { filter}) => { AdminCommand::Acl(AclCommand::List { filter }) => {
let results = services().acl.list_acls(filter); let results = services().acl.list_acls(filter);
let mut results_html = String::new(); let mut results_html = String::new();
results.iter().for_each(|it| { results.iter().for_each(|it| {
results_html.push_str(&format!("* {} | {}\n",it.hostname,it.mode.to_emoji())); results_html.push_str(&format!("* {} | {}\n", it.hostname, it.mode.to_emoji()));
}); });
RoomMessageEventContent::text_plain(format!(" RoomMessageEventContent::text_plain(format!(
"
List of services: \n List of services: \n
= blocked\n = blocked\n
= allowed\n = allowed\n
{} {}
",results_html)) ",
}, results_html
))
}
}; };
Ok(reply_message_content) Ok(reply_message_content)

View file

@ -6,8 +6,8 @@ use std::{
use lru_cache::LruCache; use lru_cache::LruCache;
use crate::{Config, Result}; use crate::{Config, Result};
pub mod acl;
pub mod account_data; pub mod account_data;
pub mod acl;
pub mod admin; pub mod admin;
pub mod appservice; pub mod appservice;
pub mod globals; pub mod globals;
@ -34,7 +34,7 @@ pub struct Services {
pub key_backups: key_backups::Service, pub key_backups: key_backups::Service,
pub media: media::Service, pub media: media::Service,
pub sending: Arc<sending::Service>, pub sending: Arc<sending::Service>,
pub acl: acl::Service pub acl: acl::Service,
} }
impl Services { impl Services {
@ -121,7 +121,10 @@ impl Services {
sending: sending::Service::build(db, &config), sending: sending::Service::build(db, &config),
globals: globals::Service::load(db, config)?, globals: globals::Service::load(db, config)?,
acl: acl::Service { db: db, acl_config: acl_conf }, acl: acl::Service {
db: db,
acl_config: acl_conf,
},
}) })
} }
fn memory_usage(&self) -> String { fn memory_usage(&self) -> String {

View file

@ -1645,7 +1645,10 @@ impl Service {
/// Returns Ok if the acl allows the server /// Returns Ok if the acl allows the server
pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
if !services().acl.is_federation_with_allowed_server_name(server_name) { if !services()
.acl
.is_federation_with_allowed_server_name(server_name)
{
info!( info!(
"Server {} was denied by server ACL in {}", "Server {} was denied by server ACL in {}",
server_name, room_id server_name, room_id

View file

@ -85,7 +85,7 @@ pub enum Error {
#[error("{0} in {1}")] #[error("{0} in {1}")]
InconsistentRoomState(&'static str, ruma::OwnedRoomId), InconsistentRoomState(&'static str, ruma::OwnedRoomId),
#[error("blocked {0}")] #[error("blocked {0}")]
ACLBlock(OwnedServerName) ACLBlock(OwnedServerName),
} }
impl Error { impl Error {