diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index fd4cef90..2aab4227 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -6,13 +6,12 @@ use std::{ use ruma::{ api::client::{ error::ErrorKind, - filter::{RoomEventFilter, UrlFilter}, message::{get_message_events, send_message_event}, }, events::{MessageLikeEventType, StateEventType}, RoomId, UserId, }; -use serde_json::{from_str, Value}; +use serde_json::from_str; use crate::{ service::{pdu::PduBuilder, rooms::timeline::PduCount}, @@ -176,7 +175,7 @@ pub(crate) async fn get_message_events_route( .timeline .pdus_after(sender_user, &body.room_id, from)? .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| contains_url_filter(pdu, &body.filter)) + .filter(|(_, pdu)| filter.pdu_event_allowed(pdu)) .filter(|(_, pdu)| visibility_filter(pdu, sender_user, &body.room_id)) .take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take(limit) @@ -222,7 +221,7 @@ pub(crate) async fn get_message_events_route( .timeline .pdus_until(sender_user, &body.room_id, from)? .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| contains_url_filter(pdu, &body.filter)) + .filter(|(_, pdu)| filter.pdu_event_allowed(pdu)) .filter(|(_, pdu)| visibility_filter(pdu, sender_user, &body.room_id)) .take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take(limit) @@ -291,16 +290,3 @@ fn visibility_filter(pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool .user_can_see_event(user_id, room_id, &pdu.event_id) .unwrap_or(false) } - -fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool { - if filter.url_filter.is_none() { - return true; - } - - let content: Value = from_str(pdu.content.get()).unwrap(); - match filter.url_filter { - Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(), - Some(UrlFilter::EventsWithUrl) => content["url"].is_string(), - None => true, - } -} diff --git a/src/utils/filter.rs b/src/utils/filter.rs index 3e5b4131..c102e899 100644 --- a/src/utils/filter.rs +++ b/src/utils/filter.rs @@ -14,9 +14,12 @@ use std::{collections::HashSet, hash::Hash}; -use ruma::{api::client::filter::RoomEventFilter, RoomId}; +use ruma::{ + api::client::filter::{RoomEventFilter, UrlFilter}, + RoomId, +}; -use crate::Error; +use crate::{Error, PduEvent}; /// Structure for testing against an allowlist and a denylist with a single /// `HashSet` lookup. @@ -60,6 +63,7 @@ impl<'a, T: ?Sized + Hash + PartialEq + Eq> AllowDenyList<'a, T> { pub(crate) struct CompiledRoomEventFilter<'a> { rooms: AllowDenyList<'a, RoomId>, + url_filter: Option, } impl<'a> TryFrom<&'a RoomEventFilter> for CompiledRoomEventFilter<'a> { @@ -68,6 +72,7 @@ impl<'a> TryFrom<&'a RoomEventFilter> for CompiledRoomEventFilter<'a> { fn try_from(source: &'a RoomEventFilter) -> Result, Error> { Ok(CompiledRoomEventFilter { rooms: AllowDenyList::from_slices(source.rooms.as_deref(), &source.not_rooms), + url_filter: source.url_filter, }) } } @@ -81,4 +86,26 @@ impl CompiledRoomEventFilter<'_> { /// rejected by the top-level filter using /// [`CompiledRoomFilter::room_allowed`], if applicable. pub(crate) fn room_allowed(&self, room_id: &RoomId) -> bool { self.rooms.allowed(room_id) } + + /// Returns `true` if a PDU event is allowed by the filter. + /// + /// This tests against the `url_filter` field. + /// + /// This does *not* check whether the event's room is allowed. It is + /// expected that callers have already filtered out rejected rooms using + /// [`CompiledRoomEventFilter::room_allowed`] and + /// [`CompiledRoomFilter::room_allowed`]. + pub(crate) fn pdu_event_allowed(&self, pdu: &PduEvent) -> bool { self.allowed_by_url_filter(pdu) } + + fn allowed_by_url_filter(&self, pdu: &PduEvent) -> bool { + let Some(filter) = self.url_filter else { + return true; + }; + // TODO: is this unwrap okay? + let content: serde_json::Value = serde_json::from_str(pdu.content.get()).unwrap(); + match filter { + UrlFilter::EventsWithoutUrl => !content["url"].is_string(), + UrlFilter::EventsWithUrl => content["url"].is_string(), + } + } }