diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 8dde57cbf..bc05ee9db 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -64,6 +64,27 @@ pub(crate) async fn upload_keys_route( .await?; } + for (key_id, fallback_key) in &body.fallback_keys { + if fallback_key + .deserialize() + .inspect_err(|e| { + debug_warn!( + %key_id, + ?fallback_key, + "Invalid one time key JSON submitted by client, skipping: {e}" + ); + }) + .is_err() + { + continue; + } + + services + .users + .add_fallback_key(sender_user, sender_device, key_id, fallback_key, false) + .await?; + } + if let Some(device_keys) = &body.device_keys { let deser_device_keys = device_keys.deserialize().map_err(|e| { err!(Request(BadJson(debug_warn!( diff --git a/src/api/client/sync/v3/mod.rs b/src/api/client/sync/v3/mod.rs index 0d4644230..1c3612b97 100644 --- a/src/api/client/sync/v3/mod.rs +++ b/src/api/client/sync/v3/mod.rs @@ -395,6 +395,10 @@ pub(crate) async fn build_sync_events( .users .count_one_time_keys(syncing_user, syncing_device); + let unused_fallback_key_types = services + .users + .list_unused_fallback_key_types(syncing_user, syncing_device); + let ( (joined_rooms, mut device_list_updates), left_rooms, @@ -405,6 +409,7 @@ pub(crate) async fn build_sync_events( to_device_events, keys_changed, device_one_time_keys_count, + unused_fallback_key_types, ) = async { futures::join!( joined_rooms, @@ -415,7 +420,8 @@ pub(crate) async fn build_sync_events( account_data, to_device_events, keys_changed, - device_one_time_keys_count + device_one_time_keys_count, + unused_fallback_key_types, ) } .boxed() @@ -433,8 +439,7 @@ pub(crate) async fn build_sync_events( account_data: assign!(GlobalAccountData::new(), { events: account_data }), device_lists: device_list_updates.into(), device_one_time_keys_count, - // Fallback keys are not yet supported - device_unused_fallback_key_types: None, + device_unused_fallback_key_types: Some(unused_fallback_key_types), presence: assign!(Presence::new(), { events: presence_updates .into_iter() diff --git a/src/database/maps.rs b/src/database/maps.rs index cbde6223e..2bb42ef59 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -120,6 +120,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "onetimekeyid_onetimekeys", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "fallbackkeyid_fallbackkey", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "passwordresettoken_info", ..descriptor::RANDOM_SMALL diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 0a95d6108..cea0693f1 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -9,8 +9,8 @@ use conduwuit::{ use database::{Deserialized, Ignore, Interfix, Json, Map}; use futures::{Stream, StreamExt, TryFutureExt}; use ruma::{ - DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, - OneTimeKeyName, OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId, + DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, + OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedOneTimeKeyId, OwnedUserId, RoomId, UInt, UserId, api::{ client::{device::Device, filter::FilterDefinition}, error::ErrorKind, @@ -57,6 +57,7 @@ struct Data { keychangeid_userid: Arc, keyid_key: Arc, onetimekeyid_onetimekeys: Arc, + fallbackkeyid_fallbackkey: Arc, openidtoken_expiresatuserid: Arc, logintoken_expiresatuserid: Arc, todeviceid_events: Arc, @@ -97,6 +98,7 @@ impl crate::Service for Service { keychangeid_userid: args.db["keychangeid_userid"].clone(), keyid_key: args.db["keyid_key"].clone(), onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(), + fallbackkeyid_fallbackkey: args.db["fallbackkeyid_fallbackkey"].clone(), openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(), logintoken_expiresatuserid: args.db["logintoken_expiresatuserid"].clone(), todeviceid_events: args.db["todeviceid_events"].clone(), @@ -550,7 +552,7 @@ impl Service { &self, user_id: &UserId, device_id: &DeviceId, - one_time_key_key: &KeyId, + one_time_key_key: &OneTimeKeyId, one_time_key_value: &Raw, ) -> Result { // All devices have metadata @@ -587,6 +589,39 @@ impl Service { Ok(()) } + /// Save a fallback key for the given user, device, and algorithm + /// This key will replace an existing fallback key + pub async fn add_fallback_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + fallback_key_id: &OneTimeKeyId, + fallback_key: &Raw, + used: bool, + ) -> Result { + // All devices have metadata + // Only existing devices should be able to call this, but we shouldn't assert + // either... + let key = (user_id, device_id); + if self.db.userdeviceid_metadata.qry(&key).await.is_err() { + return Err!(Database(error!( + %user_id, + %device_id, + "User does not exist or device has no metadata." + ))); + } + + // There is one fallback key slot per user, per device, per algorithm + // Therefore we use this as the DB key for this column + let db_key = (user_id, device_id, fallback_key_id.algorithm()); + + self.db + .fallbackkeyid_fallbackkey + .put(db_key, (used, fallback_key_id.as_str(), Json(fallback_key))); + + Ok(()) + } + pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 { self.db .userid_lastonetimekeyupdate @@ -618,6 +653,8 @@ impl Service { .onetimekeyid_onetimekeys .raw_stream_prefix(&prefix) .ignore_err() + .next() + .await .map(|(key, val)| { self.db.onetimekeyid_onetimekeys.remove(key); @@ -636,11 +673,44 @@ impl Service { .unwrap(); (key, val) - }) - .next() - .await; + }); - one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found")))) + if let Some(result) = one_time_key { + return Ok(result); + } + + // No one-time key has been found. Look for a fallback key. + + let db_key = (user_id, device_id, key_algorithm); + + let fallback_key = self + .db + .fallbackkeyid_fallbackkey + .qry(&db_key) + .await + .ok() + .and_then(|handle| { + handle + .deserialized::<(bool, OwnedOneTimeKeyId, Raw)>() + .ok() + }); + + if let Some((used, fallback_key_id, fallback_key_value)) = fallback_key { + if !used { + // write the key to the database again to mark it as used + self.add_fallback_key( + user_id, + device_id, + &fallback_key_id, + &fallback_key_value, + true, + ) + .await?; + } + return Ok((fallback_key_id, fallback_key_value)); + } + + Err(err!(Request(NotFound("No one-time key or fallback key found")))) } pub async fn count_one_time_keys( @@ -673,6 +743,34 @@ impl Service { algorithm_counts } + pub async fn list_unused_fallback_key_types( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Vec { + type KeyVal = ((String, String, OneTimeKeyAlgorithm), (bool, String, Ignore)); + + let mut query = user_id.as_bytes().to_vec(); + query.push(0xFF); + query.extend_from_slice(device_id.as_bytes()); + query.push(0xFF); + + let mut unused_algorithms = Vec::new(); + + self.db + .fallbackkeyid_fallbackkey + .stream_prefix(&query) + .ignore_err() + .ready_for_each(|((_, _, fallback_key_algorithm), (used, ..)): KeyVal| { + if !used { + unused_algorithms.push(fallback_key_algorithm); + } + }) + .await; + + unused_algorithms + } + pub async fn add_device_keys( &self, user_id: &UserId,