diff --git a/Cargo.lock b/Cargo.lock index f332321c..ff11e4c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,7 @@ dependencies = [ "tracing-opentelemetry", "tracing-subscriber", "trust-dns-resolver", + "url", ] [[package]] @@ -3274,6 +3275,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index e9dc01cd..ee72700f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,8 @@ rand = "0.8.5" # Used to hash passwords rust-argon2 = { git = "https://github.com/sru-systems/rust-argon2", rev = "e6cb5bf99643e565f4f0d103960d655dac9f3097" } reqwest = { version = "0.11.22", default-features = false, features = ["rustls-tls-native-roots", "socks"] } +# Used to validate hostnames, already included in reqwest however we need access to it +url = {version = "^2", features = ["serde"]} # Used for conduit::Error type thiserror = "1.0.51" # Used to generate thumbnails for images diff --git a/src/api/server_server.rs b/src/api/server_server.rs index ee71c2b1..969ed223 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -99,7 +99,7 @@ impl FedDest { } } - fn hostname(&self) -> String { + pub(crate) fn hostname(&self) -> String { match &self { Self::Literal(addr) => addr.ip().to_string(), Self::Named(host, _) => host.clone(), @@ -154,6 +154,14 @@ where (result.0, result.1.into_uri_string()) }; + debug!("Checking acl allowance for {}", destination); + + if !services().acl.is_federation_with_allowed_fedi_dest(&actual_destination) { + debug!("blocked sending federation to {:?}", actual_destination); + + return Err(Error::ACLBlock(destination.to_owned())); + } + let actual_destination_str = actual_destination.clone().into_https_string(); let mut http_request = request diff --git a/src/config/acl.rs b/src/config/acl.rs new file mode 100644 index 00000000..d580d5a5 --- /dev/null +++ b/src/config/acl.rs @@ -0,0 +1,11 @@ +use std::collections::HashSet; +use serde::Deserialize; +use url::Host; +#[derive(Deserialize,Debug, Default, Clone)] +pub struct AccessControlListConfig { + /// setting this explicitly enables allowlists + pub(crate)allow_list: Option>>, + + #[serde(default)] + pub(crate)block_list: HashSet> +} \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index 21cf6bf9..af0638d4 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,7 +2,7 @@ use std::{ collections::BTreeMap, fmt, net::{IpAddr, Ipv4Addr}, - path::PathBuf, + path::PathBuf, sync::Arc, }; use figment::Figment; @@ -10,9 +10,10 @@ use ruma::{OwnedServerName, RoomVersionId}; use serde::{de::IgnoredAny, Deserialize}; use tracing::{error, warn}; +pub(crate) mod acl; mod proxy; -use self::proxy::ProxyConfig; +use self::{proxy::ProxyConfig, acl::AccessControlListConfig}; #[derive(Clone, Debug, Deserialize)] pub struct Config { @@ -122,6 +123,9 @@ pub struct Config { #[serde(default = "false_fn")] pub allow_guest_registration: bool, + #[serde(default)] + pub acl: Arc, + #[serde(flatten)] pub catchall: BTreeMap, } diff --git a/src/database/key_value/acl.rs b/src/database/key_value/acl.rs new file mode 100644 index 00000000..6316522c --- /dev/null +++ b/src/database/key_value/acl.rs @@ -0,0 +1,33 @@ +use tracing::warn; + +use crate::{service::acl::{Data, AclDatabaseEntry, AclMode}, KeyValueDatabase}; + +impl Data for KeyValueDatabase { + fn check_acl(&self,host: &url::Host ) -> crate::Result> { + let thing = self.acl_list.get(host.to_string().as_bytes())?; + if let Some(thing) = thing { + match thing.first() { + Some(0x1) => Ok(Some(AclMode::Allow)), + Some(0x0) => Ok(Some(AclMode::Block)), + Some(invalid) => { + warn!("found invalid value for mode byte in value {}, probably db corruption", invalid); + Ok(None) + } + None => Ok(None), + } + }else { + Ok(None) + } + } + + fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()> { + self.acl_list.insert(acl.hostname.to_string().as_bytes(), match acl.mode { + AclMode::Block => &[0x0], + AclMode::Allow => &[0x1], + }) + } + + fn remove_acl(&self,host: url::Host) -> crate::Result<()> { + self.acl_list.remove(host.to_string().as_bytes()) + } +} \ No newline at end of file diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index c4496af8..d8581e4a 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -11,3 +11,5 @@ mod sending; mod transaction_ids; mod uiaa; mod users; + +mod acl; \ No newline at end of file diff --git a/src/database/mod.rs b/src/database/mod.rs index 79503f1b..8ce53f8d 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -172,6 +172,8 @@ pub struct KeyValueDatabase { pub(super) appservice_in_room_cache: RwLock>>, pub(super) lasttimelinecount_cache: Mutex>, pub(super) presence_timer_sender: Arc>, + + pub(super) acl_list: Arc } impl KeyValueDatabase { @@ -281,6 +283,7 @@ impl KeyValueDatabase { let db_raw = Box::new(Self { _db: builder.clone(), + acl_list: builder.open_tree("acl")?, userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, userid_avatarurl: builder.open_tree("userid_avatarurl")?, diff --git a/src/service/acl/data.rs b/src/service/acl/data.rs new file mode 100644 index 00000000..c14575b7 --- /dev/null +++ b/src/service/acl/data.rs @@ -0,0 +1,25 @@ +use serde::{Serialize, Deserialize}; +use url::Host; + + +pub trait Data: Send + Sync { + /// check if given host exists in Acls, if so return it + fn check_acl(&self,host: &Host ) -> crate::Result>; + + /// add a given Acl entry to the database + fn add_acl(&self, acl: AclDatabaseEntry) -> crate::Result<()>; + /// remove a given Acl entry from the database + fn remove_acl(&self,host: Host) -> crate::Result<()>; +} + +#[derive(Serialize,Deserialize, Debug, Clone, Copy)] +pub enum AclMode{ + Block, + Allow +} +#[derive(Serialize,Deserialize, Debug, Clone)] + +pub struct AclDatabaseEntry { + pub(crate) mode: AclMode, + pub(crate) hostname: Host +} \ No newline at end of file diff --git a/src/service/acl/mod.rs b/src/service/acl/mod.rs new file mode 100644 index 00000000..c3a72b3e --- /dev/null +++ b/src/service/acl/mod.rs @@ -0,0 +1,72 @@ + + +use std::sync::Arc; + +use ruma::ServerName; +use tracing::{warn, debug, error}; +use url::Host; + +use crate::{config::acl::AccessControlListConfig, api::server_server::FedDest}; + +pub use self::data::*; +mod data; +pub struct Service { + pub db: &'static dyn Data, + pub acl_config: Arc +} + +impl Service { + /// same as federation_with_allowed however it can work with the fedi_dest type + pub fn is_federation_with_allowed_fedi_dest(&self,fedi_dest: &FedDest) -> bool { + let hostname = if let Ok(name) = Host::parse(&fedi_dest.hostname()) { + name + } else { + warn!("cannot deserialise hostname for server with name {:?}",fedi_dest); + return false; + }; + return self.is_federation_with_allowed(hostname); + } + + /// same as federation_with_allowed however it can work with the fedi_dest type + pub fn is_federation_with_allowed_server_name(&self,srv: &ServerName) -> bool { + let hostname = if let Ok(name) = Host::parse(srv.host()) { + name + } else { + warn!("cannot deserialise hostname for server with name {:?}",srv); + return false; + }; + return self.is_federation_with_allowed(hostname); + } + /// is federation allowed with this particular server? + pub fn is_federation_with_allowed(&self,server_host_name: Host) -> bool { + debug!("checking federation allowance for {}", server_host_name); + // check blocklist first + if self.acl_config.block_list.contains(&server_host_name) { + return false; + } + let mut allow_list_enabled = false; + // check allowlist + if let Some(list) = &self.acl_config.allow_list { + if list.contains(&server_host_name) { + return true; + } + allow_list_enabled = true; + } + + //check database + match self.db.check_acl(&server_host_name) { + Err(error) => { + error!("database failed with {}",error); + false + } + Ok(None) => false, + Ok(Some(data::AclMode::Block)) => false, + Ok(Some(data::AclMode::Allow)) if allow_list_enabled => true, + Ok(Some(data::AclMode::Allow)) => { + warn!("allowlist value found in database for {} but allow list is not enabled, denied request", server_host_name); + false + } + + } + } +} \ No newline at end of file diff --git a/src/service/mod.rs b/src/service/mod.rs index 74f120f9..a2d71e95 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -6,7 +6,7 @@ use std::{ use lru_cache::LruCache; use crate::{Config, Result}; - +pub mod acl; pub mod account_data; pub mod admin; pub mod appservice; @@ -34,6 +34,7 @@ pub struct Services { pub key_backups: key_backups::Service, pub media: media::Service, pub sending: Arc, + pub acl: acl::Service } impl Services { @@ -49,11 +50,13 @@ impl Services { + key_backups::Data + media::Data + sending::Data + + acl::Data + 'static, >( db: &'static D, config: Config, ) -> Result { + let acl_conf = config.acl.clone(); Ok(Self { appservice: appservice::Service { db }, pusher: pusher::Service { db }, @@ -118,6 +121,7 @@ impl Services { sending: sending::Service::build(db, &config), globals: globals::Service::load(db, config)?, + acl: acl::Service { db: db, acl_config: acl_conf }, }) } fn memory_usage(&self) -> String { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 729a7a2b..90605aa1 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1645,6 +1645,17 @@ impl Service { /// Returns Ok if the acl allows the server pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + if !services().acl.is_federation_with_allowed_server_name(server_name) { + info!( + "Server {} was denied by server ACL in {}", + server_name, room_id + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server was denied by Server ACL", + )); + } + let acl_event = match services().rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomServerAcl, diff --git a/src/utils/error.rs b/src/utils/error.rs index d821fe66..6beb1b32 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -84,6 +84,8 @@ pub enum Error { RedactionError(OwnedServerName, ruma::canonical_json::RedactionError), #[error("{0} in {1}")] InconsistentRoomState(&'static str, ruma::OwnedRoomId), + #[error("blocked {0}")] + ACLBlock(OwnedServerName) } impl Error {