diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index 2ad37cf3..a3038f26 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -2,8 +2,10 @@ use std::cmp::Ordering; use axum::extract::State; use conduwuit::{Err, Result, err}; +use conduwuit_service::Services; +use futures::{FutureExt, future::try_join}; use ruma::{ - UInt, + UInt, UserId, api::client::backup::{ add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, delete_backup_keys, delete_backup_keys_for_room, @@ -58,21 +60,9 @@ pub(crate) async fn get_latest_backup_info_route( .await .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; - Ok(get_latest_backup_info::v3::Response { - algorithm, - count: (UInt::try_from( - services - .key_backups - .count_keys(body.sender_user(), &version) - .await, - ) - .expect("user backup keys count should not be that high")), - etag: services - .key_backups - .get_etag(body.sender_user(), &version) - .await, - version, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &version).await?; + + Ok(get_latest_backup_info::v3::Response { algorithm, count, etag, version }) } /// # `GET /_matrix/client/v3/room_keys/version/{version}` @@ -90,17 +80,12 @@ pub(crate) async fn get_backup_info_route( err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))) })?; + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + Ok(get_backup_info::v3::Response { algorithm, - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, + count, + etag, version: body.version.clone(), }) } @@ -155,17 +140,9 @@ pub(crate) async fn add_backup_keys_route( } } - Ok(add_backup_keys::v3::Response { - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + + Ok(add_backup_keys::v3::Response { count, etag }) } /// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}` @@ -198,17 +175,9 @@ pub(crate) async fn add_backup_keys_for_room_route( .await?; } - Ok(add_backup_keys_for_room::v3::Response { - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + + Ok(add_backup_keys_for_room::v3::Response { count, etag }) } /// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` @@ -306,17 +275,9 @@ pub(crate) async fn add_backup_keys_for_session_route( .await?; } - Ok(add_backup_keys_for_session::v3::Response { - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + + Ok(add_backup_keys_for_session::v3::Response { count, etag }) } /// # `GET /_matrix/client/r0/room_keys/keys` @@ -379,17 +340,9 @@ pub(crate) async fn delete_backup_keys_route( .delete_all_keys(body.sender_user(), &body.version) .await; - Ok(delete_backup_keys::v3::Response { - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + + Ok(delete_backup_keys::v3::Response { count, etag }) } /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}` @@ -404,17 +357,9 @@ pub(crate) async fn delete_backup_keys_for_room_route( .delete_room_keys(body.sender_user(), &body.version, &body.room_id) .await; - Ok(delete_backup_keys_for_room::v3::Response { - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + + Ok(delete_backup_keys_for_room::v3::Response { count, etag }) } /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` @@ -429,15 +374,22 @@ pub(crate) async fn delete_backup_keys_for_session_route( .delete_room_key(body.sender_user(), &body.version, &body.room_id, &body.session_id) .await; - Ok(delete_backup_keys_for_session::v3::Response { - count: services - .key_backups - .count_keys(body.sender_user(), &body.version) - .await - .try_into()?, - etag: services - .key_backups - .get_etag(body.sender_user(), &body.version) - .await, - }) + let (count, etag) = get_count_etag(&services, body.sender_user(), &body.version).await?; + + Ok(delete_backup_keys_for_session::v3::Response { count, etag }) +} + +async fn get_count_etag( + services: &Services, + sender_user: &UserId, + version: &str, +) -> Result<(UInt, String)> { + let count = services + .key_backups + .count_keys(sender_user, version) + .map(TryInto::try_into); + + let etag = services.key_backups.get_etag(sender_user, version).map(Ok); + + Ok(try_join(count, etag).await?) }