pub(super) mod dehydrated_device;

use std::{
    collections::BTreeMap,
    mem,
    net::IpAddr,
    sync::Arc,
    time::{Duration, SystemTime},
};

use conduwuit::{
    Err, Error, Result, debug_error, debug_warn, err, info, trace,
    utils::{self, ReadyExt, stream::TryIgnore, string::Unquoted},
    warn,
};
use database::{Deserialized, Ignore, Interfix, Json, Map};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use lettre::Address;
use ruma::{
    DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName,
    OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedOneTimeKeyId, OwnedUserId, RoomId, UInt, UserId,
    api::{
        client::{device::Device, filter::FilterDefinition},
        error::ErrorKind,
    },
    encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
    events::{
        AnyToDeviceEvent, GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent,
        push_rules::PushRulesEvent, room::message::RoomMessageEventContent,
    },
    push::Ruleset,
    serde::Raw,
    uint,
};
use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigEvent};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;

use crate::{
    Dep, account_data, admin,
    appservice::{self, RegistrationInfo},
    config, firstrun, globals, oauth,
    rooms::{self, alias, membership},
    threepid,
};

/// Record describing a restriction placed on a user account.
///
/// `suspend_account` stores this as JSON in the `userid_suspension` map.
/// `lock_account` reuses the same structure to record the lock state in the
/// `userid_lock` map. The field layout is therefore part of the persisted
/// database format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSuspension {
    /// Whether the user is currently suspended. When the record lives in
    /// `userid_lock`, this flag is what `is_locked` reports.
    pub suspended: bool,
    /// When the user was suspended (Unix timestamp in milliseconds)
    pub suspended_at: u64,
    /// User ID (stringified) of the user who applied the suspension or lock
    pub suspended_by: String,
}

/// A hashed password, for use when setting a user's password.
///
/// This allows the hash to be kept around for a while without holding the
/// plaintext password in memory.
#[derive(Serialize, Deserialize)] pub struct HashedPassword(String); impl HashedPassword { pub fn new(password: &str) -> Result { Ok(Self(utils::hash::password(password).map_err(|e| { err!(Request(InvalidParam("Password does not meet the requirements: {e}"))) })?)) } } /// The status of an access token. pub enum AccessTokenStatus { Valid, Expired, } pub struct Service { services: Services, db: Data, } struct Services { account_data: Dep, admin: Dep, alias: Dep, appservice: Dep, config: Dep, firstrun: Dep, globals: Dep, membership: Dep, oauth: Dep, state_accessor: Dep, state_cache: Dep, threepid: Dep, } struct Data { keychangeid_userid: Arc, keyid_key: Arc, onetimekeyid_onetimekeys: Arc, fallbackkeyid_fallbackkey: Arc, openidtoken_expiresatuserid: Arc, logintoken_expiresatuserid: Arc, todeviceid_events: Arc, token_userdeviceid: Arc, userdeviceid_tokenexpires: Arc, userdeviceid_metadata: Arc, userdeviceid_token: Arc, userfilterid_filter: Arc, userid_avatarurl: Arc, userid_dehydrateddevice: Arc, userid_devicelistversion: Arc, userid_displayname: Arc, userid_lastonetimekeyupdate: Arc, userid_masterkeyid: Arc, userid_password: Arc, userid_suspension: Arc, userid_lock: Arc, userid_logindisabled: Arc, userid_selfsigningkeyid: Arc, userid_usersigningkeyid: Arc, useridprofilekey_value: Arc, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { services: Services { account_data: args.depend::("account_data"), admin: args.depend::("admin"), alias: args.depend::("alias"), appservice: args.depend::("appservice"), config: args.depend::("config"), firstrun: args.depend::("firstrun"), globals: args.depend::("globals"), membership: args.depend::("membership"), oauth: args.depend::("oauth"), state_accessor: args .depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), threepid: args.depend::("threepid"), }, db: Data { keychangeid_userid: args.db["keychangeid_userid"].clone(), keyid_key: 
args.db["keyid_key"].clone(), onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(), fallbackkeyid_fallbackkey: args.db["fallbackkeyid_fallbackkey"].clone(), openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(), logintoken_expiresatuserid: args.db["logintoken_expiresatuserid"].clone(), todeviceid_events: args.db["todeviceid_events"].clone(), token_userdeviceid: args.db["token_userdeviceid"].clone(), userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(), userdeviceid_token: args.db["userdeviceid_token"].clone(), userfilterid_filter: args.db["userfilterid_filter"].clone(), userid_avatarurl: args.db["userid_avatarurl"].clone(), userid_dehydrateddevice: args.db["userid_dehydrateddevice"].clone(), userid_devicelistversion: args.db["userid_devicelistversion"].clone(), userid_displayname: args.db["userid_displayname"].clone(), userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), userid_masterkeyid: args.db["userid_masterkeyid"].clone(), userid_password: args.db["userid_password"].clone(), userid_suspension: args.db["userid_suspension"].clone(), userid_lock: args.db["userid_lock"].clone(), userid_logindisabled: args.db["userid_logindisabled"].clone(), userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(), userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(), useridprofilekey_value: args.db["useridprofilekey_value"].clone(), userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(), }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { /// Returns true/false based on whether the recipient/receiving user has /// blocked the sender pub async fn user_is_ignored(&self, sender_user: &UserId, recipient_user: &UserId) -> bool { self.services .account_data .get_global(recipient_user, GlobalAccountDataEventType::IgnoredUserList) .await .is_ok_and(|ignored: IgnoredUserListEvent| { ignored .content .ignored_users .keys() 
.any(|blocked_user| blocked_user == sender_user) }) } /// Returns the recipient's filter level for an invite from the sender. pub async fn invite_filter_level( &self, sender_user: &UserId, recipient_user: &UserId, ) -> FilterLevel { if self.user_is_ignored(sender_user, recipient_user).await { FilterLevel::Ignore } else { self.services .account_data .get_global(recipient_user, GlobalAccountDataEventType::InvitePermissionConfig) .await .map_or(FilterLevel::Allow, |config: InvitePermissionConfigEvent| { config.content.user_filter_level(sender_user) }) } } /// Check if a user is an admin #[inline] pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await } /// Create a new user account on this homeserver. #[inline] pub async fn create(&self, user_id: &UserId, password: Option) -> Result<()> { if !self.services.globals.user_is_local(user_id) && password.is_some() { return Err!("Cannot create a nonlocal user with a set password"); } self.set_password(user_id, password); Ok(()) } /// Create a new account for a local human or bot user. pub async fn create_local_account( &self, user_id: &UserId, password: HashedPassword, email: Option
, ) { self.create(user_id, Some(password)) .await .expect("should be able to save a new local user. what happened?"); // Set an initial display name { let mut displayname = user_id.localpart().to_owned(); let suffix = &self.services.config.new_user_displayname_suffix; if !suffix.is_empty() { displayname.push_str(suffix); } self.set_displayname(user_id, Some(displayname)); }; // Set default push rules self.services .account_data .update( None, user_id, GlobalAccountDataEventType::PushRules.to_string().into(), &serde_json::to_value(PushRulesEvent::new( Ruleset::server_default(user_id).into(), )) .expect("should be able to serialize push rules"), ) .await .expect("should be able to update account data"); // If the user registered with an email, associate it with their account. if let Some(email) = email { // This may fail if the email is already in use, but we should have already // checked that when we sent the validation email, so ignoring the error is // acceptable here in the rare case that an email is sniped by another user // between the validation email being sent and the account being created. let _ = self .services .threepid .associate_localpart_email(user_id.localpart(), &email) .await; } // Attempt to empower the first user and disable first-run mode. let was_first_user = self.services.firstrun.empower_first_user(user_id).await; // If the registering user was not the first and we're suspending users on // register, suspend them. if !was_first_user && self.services.config.suspend_on_register { // Note that we can still do auto joins for suspended users self.suspend_account(user_id, &self.services.globals.server_user) .await; // And send an @room notice to the admin room, to prompt admins to review the // new user and ideally unsuspend them if deemed appropriate. 
if self.services.config.admin_room_notices { self.services .admin .send_loud_message(RoomMessageEventContent::text_plain(format!( "User {user_id} has been suspended as they are not the first user on \ this server. Please review and unsuspend them if appropriate." ))) .await .ok(); } } // Autojoin the user to the configured autojoin rooms for room in &self.services.config.auto_join_rooms { let Ok(room_id) = self.services.alias.resolve(room).await else { error!( "Failed to resolve room alias to room ID when attempting to auto join \ {room}, skipping" ); continue; }; if !self .services .state_cache .server_in_room(self.services.globals.server_name(), &room_id) .await { warn!( "Skipping room {room} to automatically join as we have never joined before." ); continue; } if let Some(room_server_name) = room.server_name() { match self .services .membership .join_room( user_id, &room_id, Some("Automatically joining this room upon registration".to_owned()), &[ self.services.globals.server_name().to_owned(), room_server_name.to_owned(), ], ) .boxed() .await { | Err(e) => { // don't return this error so we don't fail registrations error!( "Failed to automatically join room {room} for user {user_id}: {e}" ); }, | _ => { info!("Automatically joined room {room} for user {user_id}"); }, } } } info!("Created new user account for {user_id}"); } pub async fn determine_registration_user_id( &self, supplied_username: Option, email: Option<&Address>, appservice_info: Option<&RegistrationInfo>, ) -> Result { const RANDOM_USER_ID_LENGTH: usize = 10; let emergency_mode_enabled = self.services.config.emergency_password.is_some(); let supplied_username = supplied_username.or_else(|| { // If the user didn't supply a username but did supply an email, use // the email's user part to avoid falling back to a random username email.map(|address| address.user().to_owned()) }); if let Some(supplied_username) = supplied_username { // The user gets to pick their username. 
Do some validation to make sure it's // acceptable. // Don't allow registration with forbidden usernames. if self .services .globals .forbidden_usernames() .is_match(&supplied_username) && !emergency_mode_enabled { return Err!(Request(Forbidden("Username is forbidden"))); } // Create and validate the user ID let user_id = match UserId::parse_with_server_name( &supplied_username, self.services.globals.server_name(), ) { | Ok(user_id) => { if let Err(e) = user_id.validate_strict() { // Unless we are in emergency mode, we should follow synapse's behaviour // on not allowing things like spaces and UTF-8 characters in // usernames if !emergency_mode_enabled { return Err!(Request(InvalidUsername(debug_warn!( "Username {supplied_username} contains disallowed characters or \ spaces: {e}" )))); } } // Don't allow registration with user IDs that aren't local if !self.services.globals.user_is_local(&user_id) { return Err!(Request(InvalidUsername( "Username {supplied_username} is not local to this server" ))); } user_id }, | Err(e) => { return Err!(Request(InvalidUsername(debug_warn!( "Username {supplied_username} is not valid: {e}" )))); }, }; if self.exists(&user_id).await { return Err!(Request(UserInUse("User ID is not available."))); } // Check that the user ID is/is not in an appservice's namespace if let Some(appservice_info) = appservice_info { if !appservice_info.is_user_match(&user_id) && !emergency_mode_enabled { return Err!(Request(Exclusive( "Username is not in this appservice's namespace." ))); } } else if self .services .appservice .is_exclusive_user_id(&user_id) .await && !emergency_mode_enabled { return Err!(Request(Exclusive("Username is reserved by an appservice."))); } Ok(user_id) } else { // The user didn't specify a username. Generate a username for // them. 
loop { let user_id = UserId::parse_with_server_name( utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), self.services.globals.server_name(), ) .unwrap(); if !self.exists(&user_id).await { break Ok(user_id); } } } } /// Deactivate account pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { // Remove all associated devices self.all_device_ids(user_id) .for_each(async |device_id| self.remove_device(user_id, &device_id).await) .await; // Set the password to "" to indicate a deactivated account. Hashes will never // result in an empty string, so the user will not be able to log in again. // Systems like changing the password without logging in should check if the // account is deactivated. self.set_password(user_id, None); // TODO: Unhook 3PID Ok(()) } /// Suspend account, placing it in a read-only state pub async fn suspend_account(&self, user_id: &UserId, suspending_user: &UserId) { self.db.userid_suspension.raw_put( user_id, Json(UserSuspension { suspended: true, suspended_at: MilliSecondsSinceUnixEpoch::now().get().into(), suspended_by: suspending_user.to_string(), }), ); } /// Unsuspend account, placing it in a read-write state pub async fn unsuspend_account(&self, user_id: &UserId) { self.db.userid_suspension.remove(user_id); } pub async fn lock_account(&self, user_id: &UserId, locking_user: &UserId) { // NOTE: Locking is basically just suspension with a more severe effect, // so we'll just re-use the suspension data structure to store the lock state. 
let suspension = self .db .userid_lock .get(user_id) .await .deserialized::() .unwrap_or_else(|_| UserSuspension { suspended: true, suspended_at: MilliSecondsSinceUnixEpoch::now().get().into(), suspended_by: locking_user.to_string(), }); self.db.userid_lock.raw_put(user_id, Json(suspension)); } pub async fn unlock_account(&self, user_id: &UserId) { self.db.userid_lock.remove(user_id); } /// Check if the provided user ID belongs to an existing (possibly /// deactivated) account on this homeserver. #[inline] pub async fn exists(&self, user_id: &UserId) -> bool { self.services.globals.user_is_local(user_id) && self.db.userid_password.get(user_id).await.is_ok() } /// Check if account is deactivated pub async fn is_deactivated(&self, user_id: &UserId) -> Result { self.db .userid_password .get(user_id) .map_ok(|val| val.is_empty()) .map_err(|_| err!(Request(NotFound("User does not exist.")))) .await } /// Check if account is suspended pub async fn is_suspended(&self, user_id: &UserId) -> Result { match self .db .userid_suspension .get(user_id) .await .deserialized::() { | Ok(s) => Ok(s.suspended), | Err(e) => if e.is_not_found() { Ok(false) } else { Err(e) }, } } pub async fn is_locked(&self, user_id: &UserId) -> Result { match self .db .userid_lock .get(user_id) .await .deserialized::() { | Ok(s) => Ok(s.suspended), | Err(e) => if e.is_not_found() { Ok(false) } else { Err(e) }, } } pub fn disable_login(&self, user_id: &UserId) { self.db.userid_logindisabled.insert(user_id, ""); } pub fn enable_login(&self, user_id: &UserId) { self.db.userid_logindisabled.remove(user_id); } pub async fn is_login_disabled(&self, user_id: &UserId) -> bool { self.db .userid_logindisabled .exists(user_id.as_str()) .await .is_ok() } /// Check if account is active, infallible pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) } /// Check if account is active, infallible pub async fn is_active_local(&self, user_id: &UserId) -> bool { 
self.services.globals.user_is_local(user_id) && self.is_active(user_id).await } /// Returns the number of users registered on this server. #[inline] pub async fn count(&self) -> usize { self.db.userid_password.count().await } /// Find out which user an access token belongs to. pub async fn find_from_token( &self, token: &str, ) -> Option<(OwnedUserId, OwnedDeviceId, AccessTokenStatus)> { let user = self .db .token_userdeviceid .get(token) .await .deserialized() .ok(); // Check if the token has expired if let Some((user_id, device_id)) = user { if let Some(expires) = self .db .userdeviceid_tokenexpires .qry(&(&user_id, &device_id)) .await .deserialized::() .ok() .map(Duration::from_secs) { let expires_at = SystemTime::UNIX_EPOCH .checked_add(expires) .expect("expiry time should not overflow SystemTime"); if SystemTime::now() > expires_at { return Some((user_id, device_id, AccessTokenStatus::Expired)); } } Some((user_id, device_id, AccessTokenStatus::Valid)) } else { None } } /// Returns an iterator over all users on this homeserver. pub fn stream(&self) -> impl Stream + Send { self.db.userid_password.keys().ignore_err() } /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is /// greater then zero. pub fn list_local_users(&self) -> impl Stream + Send + '_ { self.db .userid_password .stream() .ignore_err() .ready_filter_map(|(u, p): (OwnedUserId, &[u8])| (!p.is_empty()).then_some(u)) } /// Set a user's password. pub fn set_password(&self, user_id: &UserId, password: Option) { if let Some(hash) = password { self.db.userid_password.insert(user_id, hash.0); } else { self.db.userid_password.insert(user_id, b""); } } /// Check a user's password. 
pub async fn check_password(&self, user_id: &UserId, password: &str) -> Result { let (hash, user_id): (String, OwnedUserId) = if let Ok(hash) = self.db.userid_password.get(user_id).await.deserialized() { (hash, user_id.to_owned()) } else { // We also check the lowercased version of the user ID to handle legacy user IDs // better let lowercase_user_id = UserId::parse(user_id.as_str().to_lowercase()).unwrap(); if let Ok(hash) = self.db.userid_password.get(user_id).await.deserialized() { (hash, lowercase_user_id) } else { return Err!(Request(InvalidParam("This user cannot log in with a password."))); } }; if hash.is_empty() { return Err!(Request(UserDeactivated("This user is deactivated"))); } utils::hash::verify_password(password, &hash) .inspect_err(|e| debug_error!("{e}")) .map_err(|_| err!(Request(Forbidden("Invalid identifier or password."))))?; Ok(user_id) } /// Returns the displayname of a user on this homeserver. pub async fn displayname(&self, user_id: &UserId) -> Result { self.db.userid_displayname.get(user_id).await.deserialized() } /// Sets a new displayname or removes it if displayname is None. You still /// need to notify all rooms of this change. pub fn set_displayname(&self, user_id: &UserId, displayname: Option) { if let Some(displayname) = displayname { self.db.userid_displayname.insert(user_id, displayname); } else { self.db.userid_displayname.remove(user_id); } } /// Get the `avatar_url` of a user. pub async fn avatar_url(&self, user_id: &UserId) -> Result { self.db.userid_avatarurl.get(user_id).await.deserialized() } /// Sets a new avatar_url or removes it if avatar_url is None. pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) { match avatar_url { | Some(avatar_url) => { self.db.userid_avatarurl.insert(user_id, &avatar_url); }, | _ => { self.db.userid_avatarurl.remove(user_id); }, } } /// Adds a new device to a user. 
///
/// The device is stored with its initial display name, the client's IP
/// address, and the current time as its last-seen timestamp. The user's
/// device list version is bumped before the metadata is written. The access
/// token is then assigned through `set_token`, which also records the optional
/// `token_max_age` expiry.
///
/// # Errors
/// Returns an `InvalidParam` error if `user_id` does not exist; otherwise
/// propagates any error from `set_token`.
pub async fn create_device(
    &self,
    user_id: &UserId,
    device_id: &DeviceId,
    token: &str,
    token_max_age: Option<Duration>,
    initial_device_display_name: Option<String>,
    client_ip: Option<String>,
) -> Result<()> {
    let user_exists = self.exists(user_id).await;
    if !user_exists {
        return Err!(Request(InvalidParam(error!(
            "Called create_device for non-existent user {user_id}"
        ))));
    }

    let metadata = {
        let mut new_device = Device::new(device_id.to_owned());
        new_device.display_name = initial_device_display_name;
        new_device.last_seen_ip = client_ip;
        new_device.last_seen_ts = Some(MilliSecondsSinceUnixEpoch::now());
        new_device
    };

    increment(&self.db.userid_devicelistversion, user_id.as_bytes());
    self.db
        .userdeviceid_metadata
        .put((user_id, device_id), Json(metadata));

    self.set_token(user_id, device_id, token, token_max_age)
        .await
}

/// Removes a device from a user.
///
/// Besides deleting the device metadata, this tears down the per-device state
/// handled here:
/// - a matching dehydrated device (errors ignored);
/// - the access token, its reverse mapping, and its expiry;
/// - queued to-device events;
/// - the OAuth session.
///
/// Finally it bumps the user's device list version and records a device key
/// update. One-time keys are not removed yet.
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
    // Attempt to drop the dehydrated device in case this is it; failure is
    // deliberately ignored.
    let _: Result<_> = self
        .remove_dehydrated_device(user_id, Some(device_id))
        .await;

    let device_key = (user_id, device_id);

    // Revoke the access token, its token -> device mapping, and its expiry.
    if let Ok(token) = self.db.userdeviceid_token.qry(&device_key).await {
        self.db.userdeviceid_token.del(device_key);
        self.db.token_userdeviceid.remove(&token);
        self.db.userdeviceid_tokenexpires.del(device_key);
    }

    // Drop any to-device events still queued for this device.
    let events_prefix = (user_id, device_id, Interfix);
    self.db
        .todeviceid_events
        .keys_prefix_raw(&events_prefix)
        .ignore_err()
        .ready_for_each(|raw_key| self.db.todeviceid_events.remove(raw_key))
        .await;

    // TODO: Remove onetimekeys

    // Forget the OAuth session bound to this device.
    self.services.oauth.remove_session(user_id, device_id).await;

    increment(&self.db.userid_devicelistversion, user_id.as_bytes());
    self.db.userdeviceid_metadata.del(device_key);
    self.mark_device_key_update(user_id).await;
}

/// Returns an iterator over all device ids of this user.
pub fn all_device_ids<'a>( &'a self, user_id: &'a UserId, ) -> impl Stream + Send + 'a { let prefix = (user_id, Interfix); self.db .userdeviceid_metadata .keys_prefix(&prefix) .ignore_err() .map(|(_, device_id): (Ignore, OwnedDeviceId)| device_id) } pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result { let key = (user_id, device_id); self.db.userdeviceid_token.qry(&key).await.deserialized() } /// Generate a unique access token that doesn't collide with existing tokens pub async fn generate_unique_token(&self) -> String { loop { let token = utils::random_string(32); // Check for collision with appservice tokens if self .services .appservice .find_from_token(&token) .await .is_ok() { continue; } // Check for collision with user tokens if self.db.token_userdeviceid.get(&token).await.is_ok() { continue; } return token; } } /// Replaces the access token of one device. pub async fn set_token( &self, user_id: &UserId, device_id: &DeviceId, token: &str, token_max_age: Option, ) -> Result<()> { let key = (user_id, device_id); if self.db.userdeviceid_metadata.qry(&key).await.is_err() { return Err!(Database(error!( %user_id, %device_id, "User does not exist or device has no metadata." 
))); } // Check for token collision with appservices if self .services .appservice .find_from_token(token) .await .is_ok() { return Err!(Request(InvalidParam( "Token conflicts with an existing appservice token" ))); } // Remove old token if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { self.db.token_userdeviceid.remove(&old_token); self.db.userdeviceid_tokenexpires.remove(&old_token); // It will be removed from userdeviceid_token by the insert later } // Assign token to user device combination self.db.userdeviceid_token.put_raw(key, token); self.db.token_userdeviceid.raw_put(token, key); if let Some(max_age) = token_max_age { let expires = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("system time should not be before the epoch") .saturating_add(max_age) .as_secs(); self.db.userdeviceid_tokenexpires.put(key, expires); } else { self.db.userdeviceid_tokenexpires.del(key); } Ok(()) } pub async fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &OneTimeKeyId, one_time_key_value: &Raw, ) -> Result { // All devices have metadata // Only existing devices should be able to call this, but we shouldn't assert // either... let key = (user_id, device_id); if self.db.userdeviceid_metadata.qry(&key).await.is_err() { return Err!(Database(error!( %user_id, %device_id, "User does not exist or device has no metadata." 
)));
	}

	let mut key = user_id.as_bytes().to_vec();
	key.push(0xFF);
	key.extend_from_slice(device_id.as_bytes());
	key.push(0xFF);
	// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
	// because there are no wrapping quotation marks anymore)
	key.extend_from_slice(
		serde_json::to_string(one_time_key_key)
			.expect("DeviceKeyId::to_string always works")
			.as_bytes(),
	);

	self.db
		.onetimekeyid_onetimekeys
		.raw_put(key, Json(one_time_key_value));

	let count = self.services.globals.next_count().unwrap();
	self.db.userid_lastonetimekeyupdate.raw_put(user_id, count);

	Ok(())
}

/// Save a fallback key for the given user, device, and algorithm.
/// This key will replace an existing fallback key.
///
/// `used` is stored alongside the key; `take_one_time_key` re-saves the key
/// with `used = true` once it has been handed out.
///
/// # Errors
/// Returns a `Database` error when the device has no stored metadata.
pub async fn add_fallback_key(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
	fallback_key_id: &OneTimeKeyId,
	fallback_key: &Raw<OneTimeKey>,
	used: bool,
) -> Result {
	// All devices have metadata
	// Only existing devices should be able to call this, but we shouldn't assert
	// either...
	let key = (user_id, device_id);
	if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
		return Err!(Database(error!(
			%user_id,
			%device_id,
			"User does not exist or device has no metadata."
		)));
	}

	// There is one fallback key slot per user, per device, per algorithm
	// Therefore we use this as the DB key for this column
	let db_key = (user_id, device_id, fallback_key_id.algorithm());
	self.db
		.fallbackkeyid_fallbackkey
		.put(db_key, (used, fallback_key_id.as_str(), Json(fallback_key)));

	Ok(())
}

/// Returns the counter value stored the last time a one-time key of
/// `user_id` was added or claimed, or 0 if nothing has been recorded.
pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
	self.db
		.userid_lastonetimekeyupdate
		.get(user_id)
		.await
		.deserialized()
		.unwrap_or(0)
}

/// Claims one key of `key_algorithm` for the given device.
///
/// A stored one-time key is preferred and is removed from the database. If
/// none is left, the device's fallback key for that algorithm is returned
/// instead; it is marked as used but stays available.
///
/// # Errors
/// `NotFound` when neither a one-time key nor a fallback key exists;
/// `Database` when a stored one-time key entry is malformed.
pub async fn take_one_time_key(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
	key_algorithm: &OneTimeKeyAlgorithm,
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
	let count = self.services.globals.next_count()?.to_be_bytes();
	self.db.userid_lastonetimekeyupdate.insert(user_id, count);

	let mut prefix = user_id.as_bytes().to_vec();
	prefix.push(0xFF);
	prefix.extend_from_slice(device_id.as_bytes());
	prefix.push(0xFF);
	prefix.push(b'"'); // Annoying quotation mark
	prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
	prefix.push(b':');

	let one_time_key = self
		.db
		.onetimekeyid_onetimekeys
		.raw_stream_prefix(&prefix)
		.ignore_err()
		.next()
		.await
		.map(|(key, val)| -> Result<(OwnedOneTimeKeyId, Raw<OneTimeKey>)> {
			// The entry is consumed before parsing, so a malformed entry is
			// dropped rather than being hit again on every claim.
			self.db.onetimekeyid_onetimekeys.remove(key);

			let key_id = key
				.rsplit(|&b| b == 0xFF)
				.next()
				.ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?;

			let key_id = serde_json::from_slice(key_id)
				.map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?;

			let val = serde_json::from_slice(val)
				.map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?;

			Ok((key_id, val))
		})
		.transpose()?;

	if let Some(result) = one_time_key {
		return Ok(result);
	}

	// No one-time key has been found. Look for a fallback key.
	let db_key = (user_id, device_id, key_algorithm);
	let fallback_key = self
		.db
		.fallbackkeyid_fallbackkey
		.qry(&db_key)
		.await
		.ok()
		.and_then(|handle| {
			handle
				.deserialized::<(bool, OwnedOneTimeKeyId, Raw<OneTimeKey>)>()
				.ok()
		});

	if let Some((used, fallback_key_id, fallback_key_value)) = fallback_key {
		if !used {
			// write the key to the database again to mark it as used
			self.add_fallback_key(
				user_id,
				device_id,
				&fallback_key_id,
				&fallback_key_value,
				true,
			)
			.await?;
		}

		return Ok((fallback_key_id, fallback_key_value));
	}

	Err(err!(Request(NotFound("No one-time key or fallback key found"))))
}

/// Counts the device's stored one-time keys per algorithm.
///
/// # Panics
/// If a stored key id cannot be parsed as a `OneTimeKeyId`.
pub async fn count_one_time_keys(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
	type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);

	let mut algorithm_counts = BTreeMap::<OneTimeKeyAlgorithm, UInt>::new();
	// `Interfix` terminates the device id with a separator so that keys of a
	// device whose id merely starts with `device_id` are not counted as well.
	let query = (user_id, device_id, Interfix);
	self.db
		.onetimekeyid_onetimekeys
		.stream_prefix(&query)
		.ignore_err()
		.ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| {
			let one_time_key_id: &OneTimeKeyId = device_key_id
				.as_str()
				.try_into()
				.expect("Invalid DeviceKeyID in database");

			let count: &mut UInt = algorithm_counts
				.entry(one_time_key_id.algorithm())
				.or_default();

			*count = count.saturating_add(1_u32.into());
		})
		.await;

	algorithm_counts
}

/// Returns the algorithms for which the device has a fallback key that has
/// not been marked as used.
pub async fn list_unused_fallback_key_types(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
) -> Vec<OneTimeKeyAlgorithm> {
	type KeyVal = ((String, String, OneTimeKeyAlgorithm), (bool, String, Ignore));

	// Prefix `user_id 0xFF device_id 0xFF`; the trailing separator keeps other
	// devices with a longer id out of the scan.
	let query = (user_id, device_id, Interfix);

	let mut unused_algorithms = Vec::new();
	self.db
		.fallbackkeyid_fallbackkey
		.stream_prefix(&query)
		.ignore_err()
		.ready_for_each(|((_, _, fallback_key_algorithm), (used, ..)): KeyVal| {
			if !used {
				unused_algorithms.push(fallback_key_algorithm);
			}
		})
		.await;

	unused_algorithms
}

/// Stores the device's keys and records a key change for `user_id`.
pub async fn add_device_keys(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
	device_keys: &Raw<DeviceKeys>,
) {
	let key = (user_id, device_id);

	self.db.keyid_key.put(key, Json(device_keys));
	self.mark_device_key_update(user_id).await;
}

/// Stores any of the supplied cross-signing keys for `user_id`.
///
/// Every supplied key is parsed and validated before anything is written, so
/// an invalid key in the request leaves the stored keys untouched. When
/// `notify` is set, a key change is recorded for the user afterwards.
///
/// # Errors
/// `InvalidParam`-style request errors when a key cannot be deserialized or
/// does not contain exactly one key.
pub async fn add_cross_signing_keys(
	&self,
	user_id: &UserId,
	master_key: &Option<Raw<CrossSigningKey>>,
	self_signing_key: &Option<Raw<CrossSigningKey>>,
	user_signing_key: &Option<Raw<CrossSigningKey>>,
	notify: bool,
) -> Result<()> {
	// TODO: Check signatures

	// Master key: parse only.
	let master = master_key
		.as_ref()
		.map(|key| parse_master_key(user_id, key).map(|(key_key, _)| (key_key, key)))
		.transpose()?;

	// Self-signing key: parse only.
	let self_signing = match self_signing_key {
		Some(self_signing_key) => {
			let mut self_signing_key_ids = self_signing_key
				.deserialize()
				.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
				.keys
				.into_values();

			let self_signing_key_id = self_signing_key_ids.next().ok_or_else(|| {
				Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key.")
			})?;

			if self_signing_key_ids.next().is_some() {
				return Err(Error::BadRequest(
					ErrorKind::InvalidParam,
					"Self signing key contained more than one key.",
				));
			}

			Some((self_signing_key_id, self_signing_key))
		},
		None => None,
	};

	// User-signing key: parse only.
	let user_signing = user_signing_key
		.as_ref()
		.map(|key| parse_user_signing_key(key).map(|key_id| (key_id, key)))
		.transpose()?;

	// Everything is valid; persist.
	if let Some((master_key_key, master_key)) = master {
		self.db
			.keyid_key
			.insert(&master_key_key, master_key.json().get().as_bytes());

		self.db
			.userid_masterkeyid
			.insert(user_id.as_bytes(), &master_key_key);
	}

	if let Some((self_signing_key_id, self_signing_key)) = self_signing {
		// Database key: `user_id 0xFF <key>`.
		let mut self_signing_key_key = user_id.as_bytes().to_vec();
		self_signing_key_key.push(0xFF);
		self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());

		self.db
			.keyid_key
			.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());

		self.db
			.userid_selfsigningkeyid
			.insert(user_id.as_bytes(), &self_signing_key_key);
	}

	if let Some((user_signing_key_id, user_signing_key)) = user_signing {
		let user_signing_key_key = (user_id, &user_signing_key_id);
		self.db
			.keyid_key
			.put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());

		self.db
			.userid_usersigningkeyid
			.raw_put(user_id, user_signing_key_key);
	}

	if notify {
		self.mark_device_key_update(user_id).await;
	}

	Ok(())
}

/// Adds `signature` (a `(key, value)` pair) under `sender_id` to the
/// `signatures` object of the stored key `key_id` of `target_id`, then
/// records a key change for `target_id`.
///
/// # Errors
/// `InvalidParam` if the key does not exist; `Database` if the stored key or
/// its `signatures` field is malformed.
pub async fn sign_key(
	&self,
	target_id: &UserId,
	key_id: &str,
	signature: (String, String),
	sender_id: &UserId,
) -> Result {
	let key = (target_id, key_id);

	let mut cross_signing_key: serde_json::Value = self
		.db
		.keyid_key
		.qry(&key)
		.await
		.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
		.deserialized()
		.map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;

	let signatures = cross_signing_key
		.get_mut("signatures")
		.ok_or_else(|| {
			err!(Database(debug_warn!("key in keyid_key has no signatures field")))
		})?
		.as_object_mut()
		.ok_or_else(|| {
			err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
		})?
		.entry(sender_id.to_string())
		.or_insert_with(|| serde_json::Map::new().into());

	signatures
		.as_object_mut()
		.ok_or_else(|| {
			err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
		})?
		.insert(signature.0, signature.1.into());

	let key = (target_id, key_id);
	self.db.keyid_key.put(key, Json(cross_signing_key));
	self.mark_device_key_update(target_id).await;

	Ok(())
}

/// Streams the user ids recorded for `user_id` in `keychangeid_userid` with a
/// count in `(from, to]` (unbounded where `None`).
#[inline]
pub fn keys_changed<'a>(
	&'a self,
	user_id: &'a UserId,
	from: Option<u64>,
	to: Option<u64>,
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
	self.keys_changed_user_or_room(user_id.as_str(), from, to)
		.map(|(user_id, ..)| user_id)
}

/// Streams `(user_id, count)` pairs recorded for `room_id` in
/// `keychangeid_userid` with a count in `(from, to]` (unbounded where `None`).
#[inline]
pub fn room_keys_changed<'a>(
	&'a self,
	room_id: &'a RoomId,
	from: Option<u64>,
	to: Option<u64>,
) -> impl Stream<Item = (OwnedUserId, u64)> + Send + 'a {
	self.keys_changed_user_or_room(room_id.as_str(), from, to)
}

fn keys_changed_user_or_room<'a>(
	&'a self,
	user_or_room_id: &'a str,
	from: Option<u64>,
	to: Option<u64>,
) -> impl Stream<Item = (OwnedUserId, u64)> + Send + 'a {
	type KeyVal<'a> = ((&'a str, u64), OwnedUserId);

	let from = from.unwrap_or(0);
	let to = to.unwrap_or(u64::MAX);
	let start = (user_or_room_id, from.saturating_add(1));
	self.db
		.keychangeid_userid
		.stream_from(&start)
		.ignore_err()
		.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
			*prefix == user_or_room_id && *count <= to
		})
		.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
}

/// Records a key change for `user_id` under a fresh count, both for every
/// encrypted room the user has joined and for the user itself.
pub async fn mark_device_key_update(&self, user_id: &UserId) {
	let count = self.services.globals.next_count().unwrap();

	self.services
		.state_cache
		.rooms_joined(user_id)
		// Don't send key updates to unencrypted rooms
		.filter_map(async |room_id| {
			if self.services.state_accessor.is_encrypted_room(&room_id).await {
				Some(room_id)
			} else {
				None
			}
		})
		.ready_for_each(|room_id| {
			let key = (room_id, count);
			self.db.keychangeid_userid.put_raw(key, user_id);
		})
		.await;

	let key = (user_id, count);
	self.db.keychangeid_userid.put_raw(key, user_id);
}

/// Loads the stored keys of the given device.
pub async fn get_device_keys<'a>(
	&'a self,
	user_id: &'a UserId,
	device_id: &DeviceId,
) -> Result<Raw<DeviceKeys>> {
	let key_id = (user_id, device_id);
	self.db.keyid_key.qry(&key_id).await.deserialized()
}

/// Loads the key stored under the raw `key_id`, keeping only the signatures
/// accepted by `clean_signatures`.
pub async fn get_key<F>(
	&self,
	key_id: &[u8],
	sender_user: Option<&UserId>,
	user_id: &UserId,
	allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
	F: Fn(&UserId) -> bool + Send + Sync,
{
	let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?;

	let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
	let raw_value = serde_json::value::to_raw_value(&cleaned)?;
	Ok(Raw::from_json(raw_value))
}

/// Loads the master key of `user_id` (see `get_key` for signature filtering).
pub async fn get_master_key<F>(
	&self,
	sender_user: Option<&UserId>,
	user_id: &UserId,
	allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
	F: Fn(&UserId) -> bool + Send + Sync,
{
	let key_id = self.db.userid_masterkeyid.get(user_id).await?;

	self.get_key(&key_id, sender_user, user_id, allowed_signatures)
		.await
}

/// Loads the self-signing key of `user_id` (see `get_key` for signature
/// filtering).
pub async fn get_self_signing_key<F>(
	&self,
	sender_user: Option<&UserId>,
	user_id: &UserId,
	allowed_signatures: &F,
) -> Result<Raw<CrossSigningKey>>
where
	F: Fn(&UserId) -> bool + Send + Sync,
{
	let key_id = self.db.userid_selfsigningkeyid.get(user_id).await?;

	self.get_key(&key_id, sender_user, user_id, allowed_signatures)
		.await
}

/// Loads the user-signing key of `user_id` without any signature filtering.
pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
	self.db
		.userid_usersigningkeyid
		.get(user_id)
		.and_then(|key_id| self.db.keyid_key.get(&*key_id))
		.await
		.deserialized()
}

/// Queues a to-device event for the target device under a fresh count.
pub async fn add_to_device_event(
	&self,
	sender: &UserId,
	target_user_id: &UserId,
	target_device_id: &DeviceId,
	event_type: &str,
	content: serde_json::Value,
) {
	let count = self.services.globals.next_count().unwrap();

	let key = (target_user_id, target_device_id, count);
	self.db.todeviceid_events.put(
		key,
		Json(json!({
			"type": event_type,
			"sender": sender,
			"content": content,
		})),
	);
}

/// Streams `(count, event)` for the device's queued to-device events with a
/// count in `(since, to]` (unbounded where `None`).
pub fn get_to_device_events<'a>(
	&'a self,
	user_id: &'a UserId,
	device_id: &'a DeviceId,
	since: Option<u64>,
	to: Option<u64>,
) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
	type Key = (OwnedUserId, OwnedDeviceId, u64);

	let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));

	self.db
		.todeviceid_events
		.stream_from(&from)
		.ignore_err()
		.ready_take_while(move |((user_id_, device_id_, count), _): &(Key, _)| {
			user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to)
		})
		.map(|((_, _, count), event)| (count, event))
}

/// Deletes the device's queued to-device events with a count `<= until`
/// (all of them when `until` is `None`).
pub async fn remove_to_device_events<Until>(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
	until: Until,
) where
	Until: Into<Option<u64>> + Send,
{
	type Key = (OwnedUserId, OwnedDeviceId, u64);

	let until = until.into().unwrap_or(u64::MAX);
	let from = (user_id, device_id, until);
	self.db
		.todeviceid_events
		.rev_keys_from(&from)
		.ignore_err()
		.ready_take_while(move |(user_id_, device_id_, _): &Key| {
			user_id == *user_id_ && device_id == *device_id_
		})
		.ready_for_each(|key: Key| {
			self.db.todeviceid_events.del(key);
		})
		.await;
}

/// Updates device metadata and increments the device list version.
pub async fn update_device_metadata(
	&self,
	user_id: &UserId,
	device_id: &DeviceId,
	device: &Device,
) -> Result<()> {
	increment(&self.db.userid_devicelistversion, user_id.as_bytes());
	self.update_device_metadata_no_increment(user_id, device_id, device)
		.await
}

/// Updates device metadata without incrementing the device list version.
/// This is namely used for updating the last_seen_ip and last_seen_ts values,
/// as those do not need a device list version bump due to them not being
/// relevant to other consumers.
pub async fn update_device_metadata_no_increment( &self, user_id: &UserId, device_id: &DeviceId, device: &Device, ) -> Result<()> { let key = (user_id, device_id); self.db.userdeviceid_metadata.put(key, Json(device)); Ok(()) } pub async fn update_device_last_seen( &self, user_id: &UserId, device_id: Option<&DeviceId>, ip: IpAddr, ) { let now = MilliSecondsSinceUnixEpoch::now(); if let Some(device_id) = device_id { if let Ok(mut device) = self.get_device_metadata(user_id, device_id).await { device.last_seen_ip = Some(ip.to_string()); // If the last update was less than 10 seconds ago, don't update the timestamp if let Some(prev) = device.last_seen_ts { if now.get().saturating_sub(prev.get()) < uint!(10_000) { return; } } device.last_seen_ts = Some(now); self.update_device_metadata_no_increment(user_id, device_id, &device) .await .ok(); } } } /// Get device metadata. pub async fn get_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result { self.db .userdeviceid_metadata .qry(&(user_id, device_id)) .await .deserialized() } pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { self.db .userid_devicelistversion .get(user_id) .await .deserialized() } pub fn all_devices_metadata<'a>( &'a self, user_id: &'a UserId, ) -> impl Stream + Send + 'a { let key = (user_id, Interfix); self.db .userdeviceid_metadata .stream_prefix(&key) .ignore_err() .map(|(_, val): (Ignore, Device)| val) } /// Creates a new sync filter. Returns the filter id. 
pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String { let filter_id = utils::random_string(4); let key = (user_id, &filter_id); self.db.userfilterid_filter.put(key, Json(filter)); filter_id } pub async fn get_filter( &self, user_id: &UserId, filter_id: &str, ) -> Result { let key = (user_id, filter_id); self.db.userfilterid_filter.qry(&key).await.deserialized() } /// Creates an OpenID token, which can be used to prove that a user has /// access to an account (primarily for integrations) pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { use std::num::Saturating as Sat; let expires_in = self.services.config.openid_token_ttl; let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); let mut value = expires_at.0.to_be_bytes().to_vec(); value.extend_from_slice(user_id.as_bytes()); self.db .openidtoken_expiresatuserid .insert(token.as_bytes(), value.as_slice()); Ok(expires_in) } /// Find out which user an OpenID access token belongs to. pub async fn find_from_openid_token(&self, token: &str) -> Result { let Ok(value) = self.db.openidtoken_expiresatuserid.get(token).await else { return Err!(Request(Unauthorized("OpenID token is unrecognised"))); }; let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); let expires_at = u64::from_be_bytes(expires_at_bytes.try_into().map_err(|e| { err!(Database("expires_at in openid_userid is invalid u64. {e}")) })?); if expires_at < utils::millis_since_unix_epoch() { debug_warn!("OpenID token is expired, removing"); self.db.openidtoken_expiresatuserid.remove(token.as_bytes()); return Err!(Request(Unauthorized("OpenID token is expired"))); } let user_string = utils::string_from_bytes(user_bytes) .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?; OwnedUserId::try_from(user_string) .map_err(|e| err!(Database("User ID in openid_userid is invalid. 
{e}"))) } /// Creates a short-lived login token, which can be used to log in using the /// `m.login.token` mechanism. pub fn create_login_token(&self, user_id: &UserId, token: &str) -> u64 { use std::num::Saturating as Sat; let expires_in = self.services.config.login_token_ttl; let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in); let value = (expires_at.0, user_id); self.db.logintoken_expiresatuserid.raw_put(token, value); expires_in } /// Find out which user a login token belongs to. /// Removes the token to prevent double-use attacks. pub async fn find_from_login_token(&self, token: &str) -> Result { let Ok(value) = self.db.logintoken_expiresatuserid.get(token).await else { return Err!(Request(Forbidden("Login token is unrecognised"))); }; let (expires_at, user_id): (u64, OwnedUserId) = value.deserialized()?; if expires_at < utils::millis_since_unix_epoch() { trace!(%user_id, ?token, "Removing expired login token"); self.db.logintoken_expiresatuserid.remove(token); return Err!(Request(Forbidden("Login token is expired"))); } self.db.logintoken_expiresatuserid.remove(token); Ok(user_id) } /// Gets a specific user profile key pub async fn profile_key( &self, user_id: &UserId, profile_key: &str, ) -> Result { let key = (user_id, profile_key); self.db .useridprofilekey_value .qry(&key) .await .and_then(|handle| serde_json::from_slice(&handle).map_err(Into::into)) } /// Gets all the user's profile keys and values in an iterator pub fn all_profile_keys<'a>( &'a self, user_id: &'a UserId, ) -> impl Stream + 'a + Send { type KeyVal<'a> = ((Ignore, String), &'a [u8]); let prefix = (user_id, Interfix); self.db .useridprofilekey_value .stream_prefix(&prefix) .ignore_err() .map(|((_, key), value): KeyVal<'_>| Ok((key, serde_json::from_slice(value)?))) .ignore_err() } /// Sets a new profile key value, removes the key if value is None pub fn set_profile_key( &self, user_id: &UserId, profile_key: &str, profile_key_value: Option, ) { let key = (user_id, 
profile_key); if let Some(value) = profile_key_value { self.db.useridprofilekey_value.put(key, Json(value)); } else { self.db.useridprofilekey_value.del(key); } } /// Clears all profile data for a user, including display name and avatar /// url. pub async fn clear_profile(&self, user_id: &UserId) { self.set_displayname(user_id, None); self.set_avatar_url(user_id, None); self.all_profile_keys(user_id) .ready_for_each(|(key, _)| self.set_profile_key(user_id, &key, None)) .await; } } pub fn parse_master_key( user_id: &UserId, master_key: &Raw, ) -> Result<(Vec, CrossSigningKey)> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); let master_key = master_key .deserialize() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; let mut master_key_ids = master_key.keys.values(); let master_key_id = master_key_ids .next() .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; if master_key_ids.next().is_some() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Master key contained more than one key.", )); } let mut master_key_key = prefix.clone(); master_key_key.extend_from_slice(master_key_id.as_bytes()); Ok((master_key_key, master_key)) } pub fn parse_user_signing_key(user_signing_key: &Raw) -> Result { let mut user_signing_key_ids = user_signing_key .deserialize() .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))? 
.keys .into_values(); let user_signing_key_id = user_signing_key_ids .next() .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?; if user_signing_key_ids.next().is_some() { return Err!(Request(InvalidParam("User signing key contained more than one key."))); } Ok(user_signing_key_id) } /// Ensure that a user only sees signatures from themselves and the target user fn clean_signatures( mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F, ) -> Result where F: Fn(&UserId) -> bool + Send + Sync, { if let Some(signatures) = cross_signing_key .get_mut("signatures") .and_then(|v| v.as_object_mut()) { // Don't allocate for the full size of the current signatures, but require // at most one resize if nothing is dropped let new_capacity = signatures.len() / 2; for (user, signature) in mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) { let sid = <&UserId>::try_from(user.as_str()) .map_err(|_| Error::bad_database("Invalid user ID in database."))?; if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { signatures.insert(user, signature); } } } Ok(cross_signing_key) } //TODO: this is an ABA fn increment(db: &Arc, key: &[u8]) { let old = db.get_blocking(key); let new = utils::increment(old.ok().as_deref()); db.insert(key, new); }