mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
1768 lines
49 KiB
Rust
1768 lines
49 KiB
Rust
pub(super) mod dehydrated_device;
|
|
|
|
use std::{
|
|
collections::BTreeMap,
|
|
mem,
|
|
net::IpAddr,
|
|
sync::Arc,
|
|
time::{Duration, SystemTime},
|
|
};
|
|
|
|
use conduwuit::{
|
|
Err, Error, Result, debug_error, debug_warn, err, info, trace,
|
|
utils::{self, ReadyExt, stream::TryIgnore, string::Unquoted},
|
|
warn,
|
|
};
|
|
use database::{Deserialized, Ignore, Interfix, Json, Map};
|
|
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
|
|
use lettre::Address;
|
|
use ruma::{
|
|
DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName,
|
|
OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedOneTimeKeyId, OwnedUserId, RoomId, UInt, UserId,
|
|
api::{
|
|
client::{device::Device, filter::FilterDefinition},
|
|
error::ErrorKind,
|
|
},
|
|
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
|
events::{
|
|
AnyToDeviceEvent, GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent,
|
|
push_rules::PushRulesEvent, room::message::RoomMessageEventContent,
|
|
},
|
|
push::Ruleset,
|
|
serde::Raw,
|
|
uint,
|
|
};
|
|
use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigEvent};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
use tracing::error;
|
|
|
|
use crate::{
|
|
Dep, account_data, admin,
|
|
appservice::{self, RegistrationInfo},
|
|
config, firstrun, globals, oauth,
|
|
rooms::{self, alias, membership},
|
|
threepid,
|
|
};
|
|
|
|
/// Record describing a restriction placed on a user account.
///
/// Stored as JSON under the user ID in `userid_suspension` by
/// `Service::suspend_account`, and reused as the record shape for account locks
/// in `userid_lock` by `Service::lock_account`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSuspension {
	/// Whether the user is currently suspended
	pub suspended: bool,
	/// When the user was suspended (Unix timestamp in milliseconds)
	pub suspended_at: u64,
	/// User ID of who suspended this user
	pub suspended_by: String,
}
|
|
|
|
/// A password hash. This is only for use when setting a user's password,
/// if the hash needs to be kept around for a while without keeping the password
/// in memory.
///
/// Built with `HashedPassword::new`; the wrapped string is the output of
/// `utils::hash::password` and is written verbatim to `userid_password` by
/// `Service::set_password`.
///
/// NOTE(review): there is no `Debug` derive, so the hash cannot be formatted
/// into logs by accident — confirm this is intentional before adding one.
#[derive(Serialize, Deserialize)]
pub struct HashedPassword(String);
|
|
|
|
impl HashedPassword {
|
|
pub fn new(password: &str) -> Result<Self> {
|
|
Ok(Self(utils::hash::password(password).map_err(|e| {
|
|
err!(Request(InvalidParam("Password does not meet the requirements: {e}")))
|
|
})?))
|
|
}
|
|
}
|
|
|
|
/// The status of an access token, as reported by `Service::find_from_token`.
///
/// Derives the usual value traits so callers can compare statuses, copy them
/// freely and include them in debug output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AccessTokenStatus {
	/// The token is known and either has no stored expiry or has not expired.
	Valid,
	/// The token is known but its stored expiry time has passed.
	Expired,
}
|
|
|
|
/// The users service: account creation and lifecycle (deactivation,
/// suspension, locking), passwords, profiles, devices and access tokens, and
/// storage of end-to-end encryption keys.
pub struct Service {
	/// Other services this one depends on.
	services: Services,
	/// Database maps owned by this service.
	db: Data,
}
|
|
|
|
/// Handles to the other services used by the users service, resolved through
/// `args.depend` when the service is built.
struct Services {
	/// Global account data (ignored users, invite permission config, push rules).
	account_data: Dep<account_data::Service>,
	/// Admin checks and admin-room notices.
	admin: Dep<admin::Service>,
	/// Resolves auto-join room aliases to room IDs.
	alias: Dep<alias::Service>,
	/// Appservice namespaces and appservice token lookups.
	appservice: Dep<appservice::Service>,
	config: Dep<config::Service>,
	/// Empowers the first registered user.
	firstrun: Dep<firstrun::Service>,
	/// Server name, locality checks, forbidden usernames and the global counter.
	globals: Dep<globals::Service>,
	/// Joins new users to auto-join rooms.
	membership: Dep<membership::Service>,
	/// OAuth session cleanup when a device is removed.
	oauth: Dep<oauth::Service>,
	state_accessor: Dep<rooms::state_accessor::Service>,
	/// Checks whether this server participates in a room.
	state_cache: Dep<rooms::state_cache::Service>,
	/// Email (third-party ID) association at registration.
	threepid: Dep<threepid::Service>,
}
|
|
|
|
/// Database maps used by the users service.
///
/// Comments describe only the key/value layouts visible in this file; maps
/// without a comment are used elsewhere.
struct Data {
	keychangeid_userid: Arc<Map>,
	keyid_key: Arc<Map>,
	/// `user_id 0xFF device_id 0xFF "<json key id>"` -> JSON one-time key.
	onetimekeyid_onetimekeys: Arc<Map>,
	/// `(user_id, device_id, algorithm)` -> `(used, key id, JSON fallback key)`.
	fallbackkeyid_fallbackkey: Arc<Map>,
	openidtoken_expiresatuserid: Arc<Map>,
	logintoken_expiresatuserid: Arc<Map>,
	/// To-device events, keyed under a `(user_id, device_id)` prefix.
	todeviceid_events: Arc<Map>,
	/// Access token -> `(user_id, device_id)`.
	token_userdeviceid: Arc<Map>,
	/// `(user_id, device_id)` -> token expiry in seconds since the Unix epoch.
	userdeviceid_tokenexpires: Arc<Map>,
	/// `(user_id, device_id)` -> JSON `Device` metadata; present for every device.
	userdeviceid_metadata: Arc<Map>,
	/// `(user_id, device_id)` -> current access token.
	userdeviceid_token: Arc<Map>,
	userfilterid_filter: Arc<Map>,
	/// User ID -> avatar MXC URI.
	userid_avatarurl: Arc<Map>,
	userid_dehydrateddevice: Arc<Map>,
	/// User ID -> counter bumped when a device is created or removed.
	userid_devicelistversion: Arc<Map>,
	/// User ID -> display name.
	userid_displayname: Arc<Map>,
	/// User ID -> global count at the user's last one-time-key change.
	userid_lastonetimekeyupdate: Arc<Map>,
	userid_masterkeyid: Arc<Map>,
	/// User ID -> password hash; an empty value marks a deactivated or
	/// password-less account.
	userid_password: Arc<Map>,
	/// User ID -> JSON `UserSuspension`.
	userid_suspension: Arc<Map>,
	/// User ID -> JSON `UserSuspension` describing an account lock.
	userid_lock: Arc<Map>,
	/// User ID -> empty value; presence means login is disabled.
	userid_logindisabled: Arc<Map>,
	userid_selfsigningkeyid: Arc<Map>,
	userid_usersigningkeyid: Arc<Map>,
	useridprofilekey_value: Arc<Map>,
}
|
|
|
|
impl crate::Service for Service {
|
|
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
|
Ok(Arc::new(Self {
|
|
services: Services {
|
|
account_data: args.depend::<account_data::Service>("account_data"),
|
|
admin: args.depend::<admin::Service>("admin"),
|
|
alias: args.depend::<alias::Service>("alias"),
|
|
appservice: args.depend::<appservice::Service>("appservice"),
|
|
config: args.depend::<config::Service>("config"),
|
|
firstrun: args.depend::<firstrun::Service>("firstrun"),
|
|
globals: args.depend::<globals::Service>("globals"),
|
|
membership: args.depend::<membership::Service>("membership"),
|
|
oauth: args.depend::<oauth::Service>("oauth"),
|
|
state_accessor: args
|
|
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
|
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
|
threepid: args.depend::<threepid::Service>("threepid"),
|
|
},
|
|
db: Data {
|
|
keychangeid_userid: args.db["keychangeid_userid"].clone(),
|
|
keyid_key: args.db["keyid_key"].clone(),
|
|
onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(),
|
|
fallbackkeyid_fallbackkey: args.db["fallbackkeyid_fallbackkey"].clone(),
|
|
openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(),
|
|
logintoken_expiresatuserid: args.db["logintoken_expiresatuserid"].clone(),
|
|
todeviceid_events: args.db["todeviceid_events"].clone(),
|
|
token_userdeviceid: args.db["token_userdeviceid"].clone(),
|
|
userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(),
|
|
userdeviceid_token: args.db["userdeviceid_token"].clone(),
|
|
userfilterid_filter: args.db["userfilterid_filter"].clone(),
|
|
userid_avatarurl: args.db["userid_avatarurl"].clone(),
|
|
userid_dehydrateddevice: args.db["userid_dehydrateddevice"].clone(),
|
|
userid_devicelistversion: args.db["userid_devicelistversion"].clone(),
|
|
userid_displayname: args.db["userid_displayname"].clone(),
|
|
userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(),
|
|
userid_masterkeyid: args.db["userid_masterkeyid"].clone(),
|
|
userid_password: args.db["userid_password"].clone(),
|
|
userid_suspension: args.db["userid_suspension"].clone(),
|
|
userid_lock: args.db["userid_lock"].clone(),
|
|
userid_logindisabled: args.db["userid_logindisabled"].clone(),
|
|
userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
|
|
userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
|
|
useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
|
|
userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(),
|
|
},
|
|
}))
|
|
}
|
|
|
|
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
|
}
|
|
|
|
impl Service {
|
|
/// Returns true/false based on whether the recipient/receiving user has
|
|
/// blocked the sender
|
|
pub async fn user_is_ignored(&self, sender_user: &UserId, recipient_user: &UserId) -> bool {
|
|
self.services
|
|
.account_data
|
|
.get_global(recipient_user, GlobalAccountDataEventType::IgnoredUserList)
|
|
.await
|
|
.is_ok_and(|ignored: IgnoredUserListEvent| {
|
|
ignored
|
|
.content
|
|
.ignored_users
|
|
.keys()
|
|
.any(|blocked_user| blocked_user == sender_user)
|
|
})
|
|
}
|
|
|
|
/// Returns the recipient's filter level for an invite from the sender.
|
|
pub async fn invite_filter_level(
|
|
&self,
|
|
sender_user: &UserId,
|
|
recipient_user: &UserId,
|
|
) -> FilterLevel {
|
|
if self.user_is_ignored(sender_user, recipient_user).await {
|
|
FilterLevel::Ignore
|
|
} else {
|
|
self.services
|
|
.account_data
|
|
.get_global(recipient_user, GlobalAccountDataEventType::InvitePermissionConfig)
|
|
.await
|
|
.map_or(FilterLevel::Allow, |config: InvitePermissionConfigEvent| {
|
|
config.content.user_filter_level(sender_user)
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Check if a user is an admin
|
|
#[inline]
|
|
pub async fn is_admin(&self, user_id: &UserId) -> bool {
|
|
self.services.admin.user_is_admin(user_id).await
|
|
}
|
|
|
|
/// Create a new user account on this homeserver.
|
|
#[inline]
|
|
pub async fn create(&self, user_id: &UserId, password: Option<HashedPassword>) -> Result<()> {
|
|
if !self.services.globals.user_is_local(user_id) && password.is_some() {
|
|
return Err!("Cannot create a nonlocal user with a set password");
|
|
}
|
|
|
|
self.set_password(user_id, password);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Create a new account for a local human or bot user.
|
|
pub async fn create_local_account(
|
|
&self,
|
|
user_id: &UserId,
|
|
password: HashedPassword,
|
|
email: Option<Address>,
|
|
) {
|
|
self.create(user_id, Some(password))
|
|
.await
|
|
.expect("should be able to save a new local user. what happened?");
|
|
|
|
// Set an initial display name
|
|
{
|
|
let mut displayname = user_id.localpart().to_owned();
|
|
|
|
let suffix = &self.services.config.new_user_displayname_suffix;
|
|
if !suffix.is_empty() {
|
|
displayname.push_str(suffix);
|
|
}
|
|
|
|
self.set_displayname(user_id, Some(displayname));
|
|
};
|
|
|
|
// Set default push rules
|
|
self.services
|
|
.account_data
|
|
.update(
|
|
None,
|
|
user_id,
|
|
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
&serde_json::to_value(PushRulesEvent::new(
|
|
Ruleset::server_default(user_id).into(),
|
|
))
|
|
.expect("should be able to serialize push rules"),
|
|
)
|
|
.await
|
|
.expect("should be able to update account data");
|
|
|
|
// If the user registered with an email, associate it with their account.
|
|
if let Some(email) = email {
|
|
// This may fail if the email is already in use, but we should have already
|
|
// checked that when we sent the validation email, so ignoring the error is
|
|
// acceptable here in the rare case that an email is sniped by another user
|
|
// between the validation email being sent and the account being created.
|
|
let _ = self
|
|
.services
|
|
.threepid
|
|
.associate_localpart_email(user_id.localpart(), &email)
|
|
.await;
|
|
}
|
|
|
|
// Attempt to empower the first user and disable first-run mode.
|
|
let was_first_user = self.services.firstrun.empower_first_user(user_id).await;
|
|
|
|
// If the registering user was not the first and we're suspending users on
|
|
// register, suspend them.
|
|
if !was_first_user && self.services.config.suspend_on_register {
|
|
// Note that we can still do auto joins for suspended users
|
|
self.suspend_account(user_id, &self.services.globals.server_user)
|
|
.await;
|
|
|
|
// And send an @room notice to the admin room, to prompt admins to review the
|
|
// new user and ideally unsuspend them if deemed appropriate.
|
|
if self.services.config.admin_room_notices {
|
|
self.services
|
|
.admin
|
|
.send_loud_message(RoomMessageEventContent::text_plain(format!(
|
|
"User {user_id} has been suspended as they are not the first user on \
|
|
this server. Please review and unsuspend them if appropriate."
|
|
)))
|
|
.await
|
|
.ok();
|
|
}
|
|
}
|
|
|
|
// Autojoin the user to the configured autojoin rooms
|
|
for room in &self.services.config.auto_join_rooms {
|
|
let Ok(room_id) = self.services.alias.resolve(room).await else {
|
|
error!(
|
|
"Failed to resolve room alias to room ID when attempting to auto join \
|
|
{room}, skipping"
|
|
);
|
|
continue;
|
|
};
|
|
|
|
if !self
|
|
.services
|
|
.state_cache
|
|
.server_in_room(self.services.globals.server_name(), &room_id)
|
|
.await
|
|
{
|
|
warn!(
|
|
"Skipping room {room} to automatically join as we have never joined before."
|
|
);
|
|
continue;
|
|
}
|
|
|
|
if let Some(room_server_name) = room.server_name() {
|
|
match self
|
|
.services
|
|
.membership
|
|
.join_room(
|
|
user_id,
|
|
&room_id,
|
|
Some("Automatically joining this room upon registration".to_owned()),
|
|
&[
|
|
self.services.globals.server_name().to_owned(),
|
|
room_server_name.to_owned(),
|
|
],
|
|
)
|
|
.boxed()
|
|
.await
|
|
{
|
|
| Err(e) => {
|
|
// don't return this error so we don't fail registrations
|
|
error!(
|
|
"Failed to automatically join room {room} for user {user_id}: {e}"
|
|
);
|
|
},
|
|
| _ => {
|
|
info!("Automatically joined room {room} for user {user_id}");
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
info!("Created new user account for {user_id}");
|
|
}
|
|
|
|
pub async fn determine_registration_user_id(
|
|
&self,
|
|
supplied_username: Option<String>,
|
|
email: Option<&Address>,
|
|
appservice_info: Option<&RegistrationInfo>,
|
|
) -> Result<OwnedUserId> {
|
|
const RANDOM_USER_ID_LENGTH: usize = 10;
|
|
|
|
let emergency_mode_enabled = self.services.config.emergency_password.is_some();
|
|
|
|
let supplied_username = supplied_username.or_else(|| {
|
|
// If the user didn't supply a username but did supply an email, use
|
|
// the email's user part to avoid falling back to a random username
|
|
email.map(|address| address.user().to_owned())
|
|
});
|
|
|
|
if let Some(supplied_username) = supplied_username {
|
|
// The user gets to pick their username. Do some validation to make sure it's
|
|
// acceptable.
|
|
|
|
// Don't allow registration with forbidden usernames.
|
|
if self
|
|
.services
|
|
.globals
|
|
.forbidden_usernames()
|
|
.is_match(&supplied_username)
|
|
&& !emergency_mode_enabled
|
|
{
|
|
return Err!(Request(Forbidden("Username is forbidden")));
|
|
}
|
|
|
|
// Create and validate the user ID
|
|
let user_id = match UserId::parse_with_server_name(
|
|
&supplied_username,
|
|
self.services.globals.server_name(),
|
|
) {
|
|
| Ok(user_id) => {
|
|
if let Err(e) = user_id.validate_strict() {
|
|
// Unless we are in emergency mode, we should follow synapse's behaviour
|
|
// on not allowing things like spaces and UTF-8 characters in
|
|
// usernames
|
|
if !emergency_mode_enabled {
|
|
return Err!(Request(InvalidUsername(debug_warn!(
|
|
"Username {supplied_username} contains disallowed characters or \
|
|
spaces: {e}"
|
|
))));
|
|
}
|
|
}
|
|
|
|
// Don't allow registration with user IDs that aren't local
|
|
if !self.services.globals.user_is_local(&user_id) {
|
|
return Err!(Request(InvalidUsername(
|
|
"Username {supplied_username} is not local to this server"
|
|
)));
|
|
}
|
|
|
|
user_id
|
|
},
|
|
| Err(e) => {
|
|
return Err!(Request(InvalidUsername(debug_warn!(
|
|
"Username {supplied_username} is not valid: {e}"
|
|
))));
|
|
},
|
|
};
|
|
|
|
if self.exists(&user_id).await {
|
|
return Err!(Request(UserInUse("User ID is not available.")));
|
|
}
|
|
|
|
// Check that the user ID is/is not in an appservice's namespace
|
|
if let Some(appservice_info) = appservice_info {
|
|
if !appservice_info.is_user_match(&user_id) && !emergency_mode_enabled {
|
|
return Err!(Request(Exclusive(
|
|
"Username is not in this appservice's namespace."
|
|
)));
|
|
}
|
|
} else if self
|
|
.services
|
|
.appservice
|
|
.is_exclusive_user_id(&user_id)
|
|
.await && !emergency_mode_enabled
|
|
{
|
|
return Err!(Request(Exclusive("Username is reserved by an appservice.")));
|
|
}
|
|
|
|
Ok(user_id)
|
|
} else {
|
|
// The user didn't specify a username. Generate a username for
|
|
// them.
|
|
|
|
loop {
|
|
let user_id = UserId::parse_with_server_name(
|
|
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
|
self.services.globals.server_name(),
|
|
)
|
|
.unwrap();
|
|
|
|
if !self.exists(&user_id).await {
|
|
break Ok(user_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Deactivate account
|
|
pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
|
|
// Remove all associated devices
|
|
self.all_device_ids(user_id)
|
|
.for_each(async |device_id| self.remove_device(user_id, &device_id).await)
|
|
.await;
|
|
|
|
// Set the password to "" to indicate a deactivated account. Hashes will never
|
|
// result in an empty string, so the user will not be able to log in again.
|
|
// Systems like changing the password without logging in should check if the
|
|
// account is deactivated.
|
|
self.set_password(user_id, None);
|
|
|
|
// TODO: Unhook 3PID
|
|
Ok(())
|
|
}
|
|
|
|
/// Suspend account, placing it in a read-only state
|
|
pub async fn suspend_account(&self, user_id: &UserId, suspending_user: &UserId) {
|
|
self.db.userid_suspension.raw_put(
|
|
user_id,
|
|
Json(UserSuspension {
|
|
suspended: true,
|
|
suspended_at: MilliSecondsSinceUnixEpoch::now().get().into(),
|
|
suspended_by: suspending_user.to_string(),
|
|
}),
|
|
);
|
|
}
|
|
|
|
/// Unsuspend account, placing it in a read-write state
|
|
pub async fn unsuspend_account(&self, user_id: &UserId) {
|
|
self.db.userid_suspension.remove(user_id);
|
|
}
|
|
|
|
pub async fn lock_account(&self, user_id: &UserId, locking_user: &UserId) {
|
|
// NOTE: Locking is basically just suspension with a more severe effect,
|
|
// so we'll just re-use the suspension data structure to store the lock state.
|
|
let suspension = self
|
|
.db
|
|
.userid_lock
|
|
.get(user_id)
|
|
.await
|
|
.deserialized::<UserSuspension>()
|
|
.unwrap_or_else(|_| UserSuspension {
|
|
suspended: true,
|
|
suspended_at: MilliSecondsSinceUnixEpoch::now().get().into(),
|
|
suspended_by: locking_user.to_string(),
|
|
});
|
|
|
|
self.db.userid_lock.raw_put(user_id, Json(suspension));
|
|
}
|
|
|
|
/// Removes the lock record for `user_id`, if there is one.
pub async fn unlock_account(&self, user_id: &UserId) {
	let locks = &self.db.userid_lock;
	locks.remove(user_id);
}
|
|
|
|
/// Check if the provided user ID belongs to an existing (possibly
|
|
/// deactivated) account on this homeserver.
|
|
#[inline]
|
|
pub async fn exists(&self, user_id: &UserId) -> bool {
|
|
self.services.globals.user_is_local(user_id)
|
|
&& self.db.userid_password.get(user_id).await.is_ok()
|
|
}
|
|
|
|
/// Check if account is deactivated
|
|
pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
|
self.db
|
|
.userid_password
|
|
.get(user_id)
|
|
.map_ok(|val| val.is_empty())
|
|
.map_err(|_| err!(Request(NotFound("User does not exist."))))
|
|
.await
|
|
}
|
|
|
|
/// Check if account is suspended
|
|
pub async fn is_suspended(&self, user_id: &UserId) -> Result<bool> {
|
|
match self
|
|
.db
|
|
.userid_suspension
|
|
.get(user_id)
|
|
.await
|
|
.deserialized::<UserSuspension>()
|
|
{
|
|
| Ok(s) => Ok(s.suspended),
|
|
| Err(e) =>
|
|
if e.is_not_found() {
|
|
Ok(false)
|
|
} else {
|
|
Err(e)
|
|
},
|
|
}
|
|
}
|
|
|
|
pub async fn is_locked(&self, user_id: &UserId) -> Result<bool> {
|
|
match self
|
|
.db
|
|
.userid_lock
|
|
.get(user_id)
|
|
.await
|
|
.deserialized::<UserSuspension>()
|
|
{
|
|
| Ok(s) => Ok(s.suspended),
|
|
| Err(e) =>
|
|
if e.is_not_found() {
|
|
Ok(false)
|
|
} else {
|
|
Err(e)
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Prevents `user_id` from logging in. The flag is simply the presence of a
/// key in `userid_logindisabled`; the stored value is empty.
pub fn disable_login(&self, user_id: &UserId) {
	let flags = &self.db.userid_logindisabled;
	flags.insert(user_id, "");
}
|
|
|
|
/// Clears the login-disabled flag for `user_id`.
pub fn enable_login(&self, user_id: &UserId) {
	let flags = &self.db.userid_logindisabled;
	flags.remove(user_id);
}
|
|
|
|
/// Whether a login-disabled flag is currently set for `user_id`.
pub async fn is_login_disabled(&self, user_id: &UserId) -> bool {
	let flagged = self
		.db
		.userid_logindisabled
		.exists(user_id.as_str())
		.await;

	flagged.is_ok()
}
|
|
|
|
/// Check if account is active, infallible
|
|
pub async fn is_active(&self, user_id: &UserId) -> bool {
|
|
!self.is_deactivated(user_id).await.unwrap_or(true)
|
|
}
|
|
|
|
/// Check if account is active, infallible
|
|
pub async fn is_active_local(&self, user_id: &UserId) -> bool {
|
|
self.services.globals.user_is_local(user_id) && self.is_active(user_id).await
|
|
}
|
|
|
|
/// Returns the number of users registered on this server.
|
|
#[inline]
|
|
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
|
|
|
|
/// Find out which user an access token belongs to.
|
|
pub async fn find_from_token(
|
|
&self,
|
|
token: &str,
|
|
) -> Option<(OwnedUserId, OwnedDeviceId, AccessTokenStatus)> {
|
|
let user = self
|
|
.db
|
|
.token_userdeviceid
|
|
.get(token)
|
|
.await
|
|
.deserialized()
|
|
.ok();
|
|
|
|
// Check if the token has expired
|
|
if let Some((user_id, device_id)) = user {
|
|
if let Some(expires) = self
|
|
.db
|
|
.userdeviceid_tokenexpires
|
|
.qry(&(&user_id, &device_id))
|
|
.await
|
|
.deserialized::<u64>()
|
|
.ok()
|
|
.map(Duration::from_secs)
|
|
{
|
|
let expires_at = SystemTime::UNIX_EPOCH
|
|
.checked_add(expires)
|
|
.expect("expiry time should not overflow SystemTime");
|
|
|
|
if SystemTime::now() > expires_at {
|
|
return Some((user_id, device_id, AccessTokenStatus::Expired));
|
|
}
|
|
}
|
|
|
|
Some((user_id, device_id, AccessTokenStatus::Valid))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Returns an iterator over all users on this homeserver.
|
|
pub fn stream(&self) -> impl Stream<Item = OwnedUserId> + Send {
|
|
self.db.userid_password.keys().ignore_err()
|
|
}
|
|
|
|
/// Returns a list of local users as list of usernames.
|
|
///
|
|
/// A user account is considered `local` if the length of it's password is
|
|
/// greater then zero.
|
|
pub fn list_local_users(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ {
|
|
self.db
|
|
.userid_password
|
|
.stream()
|
|
.ignore_err()
|
|
.ready_filter_map(|(u, p): (OwnedUserId, &[u8])| (!p.is_empty()).then_some(u))
|
|
}
|
|
|
|
/// Set a user's password.
|
|
pub fn set_password(&self, user_id: &UserId, password: Option<HashedPassword>) {
|
|
if let Some(hash) = password {
|
|
self.db.userid_password.insert(user_id, hash.0);
|
|
} else {
|
|
self.db.userid_password.insert(user_id, b"");
|
|
}
|
|
}
|
|
|
|
/// Check a user's password.
|
|
pub async fn check_password(&self, user_id: &UserId, password: &str) -> Result<OwnedUserId> {
|
|
let (hash, user_id): (String, OwnedUserId) = if let Ok(hash) =
|
|
self.db.userid_password.get(user_id).await.deserialized()
|
|
{
|
|
(hash, user_id.to_owned())
|
|
} else {
|
|
// We also check the lowercased version of the user ID to handle legacy user IDs
|
|
// better
|
|
let lowercase_user_id = UserId::parse(user_id.as_str().to_lowercase()).unwrap();
|
|
|
|
if let Ok(hash) = self.db.userid_password.get(user_id).await.deserialized() {
|
|
(hash, lowercase_user_id)
|
|
} else {
|
|
return Err!(Request(InvalidParam("This user cannot log in with a password.")));
|
|
}
|
|
};
|
|
|
|
if hash.is_empty() {
|
|
return Err!(Request(UserDeactivated("This user is deactivated")));
|
|
}
|
|
|
|
utils::hash::verify_password(password, &hash)
|
|
.inspect_err(|e| debug_error!("{e}"))
|
|
.map_err(|_| err!(Request(Forbidden("Invalid identifier or password."))))?;
|
|
|
|
Ok(user_id)
|
|
}
|
|
|
|
/// Returns the displayname of a user on this homeserver.
|
|
pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
|
|
self.db.userid_displayname.get(user_id).await.deserialized()
|
|
}
|
|
|
|
/// Sets a new displayname or removes it if displayname is None. You still
|
|
/// need to notify all rooms of this change.
|
|
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
|
|
if let Some(displayname) = displayname {
|
|
self.db.userid_displayname.insert(user_id, displayname);
|
|
} else {
|
|
self.db.userid_displayname.remove(user_id);
|
|
}
|
|
}
|
|
|
|
/// Get the `avatar_url` of a user.
|
|
pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
|
|
self.db.userid_avatarurl.get(user_id).await.deserialized()
|
|
}
|
|
|
|
/// Sets a new avatar_url or removes it if avatar_url is None.
|
|
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
|
|
match avatar_url {
|
|
| Some(avatar_url) => {
|
|
self.db.userid_avatarurl.insert(user_id, &avatar_url);
|
|
},
|
|
| _ => {
|
|
self.db.userid_avatarurl.remove(user_id);
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Adds a new device to a user.
|
|
pub async fn create_device(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
token: &str,
|
|
token_max_age: Option<Duration>,
|
|
initial_device_display_name: Option<String>,
|
|
client_ip: Option<String>,
|
|
) -> Result<()> {
|
|
if !self.exists(user_id).await {
|
|
return Err!(Request(InvalidParam(error!(
|
|
"Called create_device for non-existent user {user_id}"
|
|
))));
|
|
}
|
|
|
|
let key = (user_id, device_id);
|
|
let mut device = Device::new(device_id.into());
|
|
device.display_name = initial_device_display_name;
|
|
device.last_seen_ip = client_ip;
|
|
device.last_seen_ts = Some(MilliSecondsSinceUnixEpoch::now());
|
|
|
|
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
|
self.db.userdeviceid_metadata.put(key, Json(device));
|
|
self.set_token(user_id, device_id, token, token_max_age)
|
|
.await
|
|
}
|
|
|
|
/// Removes a device from a user.
|
|
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
|
|
// Remove dehydrated device if this is the dehydrated device
|
|
let _: Result<_> = self
|
|
.remove_dehydrated_device(user_id, Some(device_id))
|
|
.await;
|
|
|
|
let userdeviceid = (user_id, device_id);
|
|
|
|
// Remove tokens
|
|
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
|
|
self.db.userdeviceid_token.del(userdeviceid);
|
|
self.db.token_userdeviceid.remove(&old_token);
|
|
self.db.userdeviceid_tokenexpires.del(userdeviceid);
|
|
}
|
|
|
|
// Remove todevice events
|
|
let prefix = (user_id, device_id, Interfix);
|
|
self.db
|
|
.todeviceid_events
|
|
.keys_prefix_raw(&prefix)
|
|
.ignore_err()
|
|
.ready_for_each(|key| self.db.todeviceid_events.remove(key))
|
|
.await;
|
|
|
|
// TODO: Remove onetimekeys
|
|
|
|
// Remove OAuth session information
|
|
self.services.oauth.remove_session(user_id, device_id).await;
|
|
|
|
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
|
|
|
self.db.userdeviceid_metadata.del(userdeviceid);
|
|
self.mark_device_key_update(user_id).await;
|
|
}
|
|
|
|
/// Returns an iterator over all device ids of this user.
|
|
pub fn all_device_ids<'a>(
|
|
&'a self,
|
|
user_id: &'a UserId,
|
|
) -> impl Stream<Item = OwnedDeviceId> + Send + 'a {
|
|
let prefix = (user_id, Interfix);
|
|
self.db
|
|
.userdeviceid_metadata
|
|
.keys_prefix(&prefix)
|
|
.ignore_err()
|
|
.map(|(_, device_id): (Ignore, OwnedDeviceId)| device_id)
|
|
}
|
|
|
|
/// Returns the current access token of `device_id` for `user_id`.
///
/// # Errors
/// Fails when the device has no token or it cannot be decoded.
pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
	let device_key = (user_id, device_id);
	let stored = self.db.userdeviceid_token.qry(&device_key).await;
	stored.deserialized()
}
|
|
|
|
/// Generate a unique access token that doesn't collide with existing tokens
|
|
pub async fn generate_unique_token(&self) -> String {
|
|
loop {
|
|
let token = utils::random_string(32);
|
|
|
|
// Check for collision with appservice tokens
|
|
if self
|
|
.services
|
|
.appservice
|
|
.find_from_token(&token)
|
|
.await
|
|
.is_ok()
|
|
{
|
|
continue;
|
|
}
|
|
|
|
// Check for collision with user tokens
|
|
if self.db.token_userdeviceid.get(&token).await.is_ok() {
|
|
continue;
|
|
}
|
|
|
|
return token;
|
|
}
|
|
}
|
|
|
|
/// Replaces the access token of one device.
|
|
pub async fn set_token(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
token: &str,
|
|
token_max_age: Option<Duration>,
|
|
) -> Result<()> {
|
|
let key = (user_id, device_id);
|
|
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
|
|
return Err!(Database(error!(
|
|
%user_id,
|
|
%device_id,
|
|
"User does not exist or device has no metadata."
|
|
)));
|
|
}
|
|
|
|
// Check for token collision with appservices
|
|
if self
|
|
.services
|
|
.appservice
|
|
.find_from_token(token)
|
|
.await
|
|
.is_ok()
|
|
{
|
|
return Err!(Request(InvalidParam(
|
|
"Token conflicts with an existing appservice token"
|
|
)));
|
|
}
|
|
|
|
// Remove old token
|
|
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
|
|
self.db.token_userdeviceid.remove(&old_token);
|
|
self.db.userdeviceid_tokenexpires.remove(&old_token);
|
|
// It will be removed from userdeviceid_token by the insert later
|
|
}
|
|
|
|
// Assign token to user device combination
|
|
self.db.userdeviceid_token.put_raw(key, token);
|
|
self.db.token_userdeviceid.raw_put(token, key);
|
|
|
|
if let Some(max_age) = token_max_age {
|
|
let expires = SystemTime::now()
|
|
.duration_since(SystemTime::UNIX_EPOCH)
|
|
.expect("system time should not be before the epoch")
|
|
.saturating_add(max_age)
|
|
.as_secs();
|
|
|
|
self.db.userdeviceid_tokenexpires.put(key, expires);
|
|
} else {
|
|
self.db.userdeviceid_tokenexpires.del(key);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn add_one_time_key(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
one_time_key_key: &OneTimeKeyId,
|
|
one_time_key_value: &Raw<OneTimeKey>,
|
|
) -> Result {
|
|
// All devices have metadata
|
|
// Only existing devices should be able to call this, but we shouldn't assert
|
|
// either...
|
|
let key = (user_id, device_id);
|
|
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
|
|
return Err!(Database(error!(
|
|
%user_id,
|
|
%device_id,
|
|
"User does not exist or device has no metadata."
|
|
)));
|
|
}
|
|
|
|
let mut key = user_id.as_bytes().to_vec();
|
|
key.push(0xFF);
|
|
key.extend_from_slice(device_id.as_bytes());
|
|
key.push(0xFF);
|
|
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
|
|
// because there are no wrapping quotation marks anymore)
|
|
key.extend_from_slice(
|
|
serde_json::to_string(one_time_key_key)
|
|
.expect("DeviceKeyId::to_string always works")
|
|
.as_bytes(),
|
|
);
|
|
|
|
self.db
|
|
.onetimekeyid_onetimekeys
|
|
.raw_put(key, Json(one_time_key_value));
|
|
|
|
let count = self.services.globals.next_count().unwrap();
|
|
self.db.userid_lastonetimekeyupdate.raw_put(user_id, count);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Save a fallback key for the given user, device, and algorithm
|
|
/// This key will replace an existing fallback key
|
|
pub async fn add_fallback_key(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
fallback_key_id: &OneTimeKeyId,
|
|
fallback_key: &Raw<OneTimeKey>,
|
|
used: bool,
|
|
) -> Result {
|
|
// All devices have metadata
|
|
// Only existing devices should be able to call this, but we shouldn't assert
|
|
// either...
|
|
let key = (user_id, device_id);
|
|
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
|
|
return Err!(Database(error!(
|
|
%user_id,
|
|
%device_id,
|
|
"User does not exist or device has no metadata."
|
|
)));
|
|
}
|
|
|
|
// There is one fallback key slot per user, per device, per algorithm
|
|
// Therefore we use this as the DB key for this column
|
|
let db_key = (user_id, device_id, fallback_key_id.algorithm());
|
|
|
|
self.db
|
|
.fallbackkeyid_fallbackkey
|
|
.put(db_key, (used, fallback_key_id.as_str(), Json(fallback_key)));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns the global count recorded at the user's last one-time-key change,
/// or 0 when nothing (readable) is recorded.
pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
	let recorded = self.db.userid_lastonetimekeyupdate.get(user_id).await;
	recorded.deserialized::<u64>().unwrap_or_default()
}
|
|
|
|
pub async fn take_one_time_key(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
key_algorithm: &OneTimeKeyAlgorithm,
|
|
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
|
|
let count = self.services.globals.next_count()?.to_be_bytes();
|
|
self.db.userid_lastonetimekeyupdate.insert(user_id, count);
|
|
|
|
let mut prefix = user_id.as_bytes().to_vec();
|
|
prefix.push(0xFF);
|
|
prefix.extend_from_slice(device_id.as_bytes());
|
|
prefix.push(0xFF);
|
|
prefix.push(b'"'); // Annoying quotation mark
|
|
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
|
|
prefix.push(b':');
|
|
|
|
let one_time_key = self
|
|
.db
|
|
.onetimekeyid_onetimekeys
|
|
.raw_stream_prefix(&prefix)
|
|
.ignore_err()
|
|
.next()
|
|
.await
|
|
.map(|(key, val)| {
|
|
self.db.onetimekeyid_onetimekeys.remove(key);
|
|
|
|
let key = key
|
|
.rsplit(|&b| b == 0xFF)
|
|
.next()
|
|
.ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))
|
|
.unwrap();
|
|
|
|
let key = serde_json::from_slice(key)
|
|
.map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))
|
|
.unwrap();
|
|
|
|
let val = serde_json::from_slice(val)
|
|
.map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))
|
|
.unwrap();
|
|
|
|
(key, val)
|
|
});
|
|
|
|
if let Some(result) = one_time_key {
|
|
return Ok(result);
|
|
}
|
|
|
|
// No one-time key has been found. Look for a fallback key.
|
|
|
|
let db_key = (user_id, device_id, key_algorithm);
|
|
|
|
let fallback_key = self
|
|
.db
|
|
.fallbackkeyid_fallbackkey
|
|
.qry(&db_key)
|
|
.await
|
|
.ok()
|
|
.and_then(|handle| {
|
|
handle
|
|
.deserialized::<(bool, OwnedOneTimeKeyId, Raw<OneTimeKey>)>()
|
|
.ok()
|
|
});
|
|
|
|
if let Some((used, fallback_key_id, fallback_key_value)) = fallback_key {
|
|
if !used {
|
|
// write the key to the database again to mark it as used
|
|
self.add_fallback_key(
|
|
user_id,
|
|
device_id,
|
|
&fallback_key_id,
|
|
&fallback_key_value,
|
|
true,
|
|
)
|
|
.await?;
|
|
}
|
|
return Ok((fallback_key_id, fallback_key_value));
|
|
}
|
|
|
|
Err(err!(Request(NotFound("No one-time key or fallback key found"))))
|
|
}
|
|
|
|
pub async fn count_one_time_keys(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
|
|
type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
|
|
|
|
let mut algorithm_counts = BTreeMap::<OneTimeKeyAlgorithm, _>::new();
|
|
let query = (user_id, device_id);
|
|
self.db
|
|
.onetimekeyid_onetimekeys
|
|
.stream_prefix(&query)
|
|
.ignore_err()
|
|
.ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| {
|
|
let one_time_key_id: &OneTimeKeyId = device_key_id
|
|
.as_str()
|
|
.try_into()
|
|
.expect("Invalid DeviceKeyID in database");
|
|
|
|
let count: &mut UInt = algorithm_counts
|
|
.entry(one_time_key_id.algorithm())
|
|
.or_default();
|
|
|
|
*count = count.saturating_add(1_u32.into());
|
|
})
|
|
.await;
|
|
|
|
algorithm_counts
|
|
}
|
|
|
|
pub async fn list_unused_fallback_key_types(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
) -> Vec<OneTimeKeyAlgorithm> {
|
|
type KeyVal = ((String, String, OneTimeKeyAlgorithm), (bool, String, Ignore));
|
|
|
|
let mut query = user_id.as_bytes().to_vec();
|
|
query.push(0xFF);
|
|
query.extend_from_slice(device_id.as_bytes());
|
|
query.push(0xFF);
|
|
|
|
let mut unused_algorithms = Vec::new();
|
|
|
|
self.db
|
|
.fallbackkeyid_fallbackkey
|
|
.stream_prefix(&query)
|
|
.ignore_err()
|
|
.ready_for_each(|((_, _, fallback_key_algorithm), (used, ..)): KeyVal| {
|
|
if !used {
|
|
unused_algorithms.push(fallback_key_algorithm);
|
|
}
|
|
})
|
|
.await;
|
|
|
|
unused_algorithms
|
|
}
|
|
|
|
pub async fn add_device_keys(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
device_keys: &Raw<DeviceKeys>,
|
|
) {
|
|
let key = (user_id, device_id);
|
|
|
|
self.db.keyid_key.put(key, Json(device_keys));
|
|
self.mark_device_key_update(user_id).await;
|
|
}
|
|
|
|
pub async fn add_cross_signing_keys(
|
|
&self,
|
|
user_id: &UserId,
|
|
master_key: &Option<Raw<CrossSigningKey>>,
|
|
self_signing_key: &Option<Raw<CrossSigningKey>>,
|
|
user_signing_key: &Option<Raw<CrossSigningKey>>,
|
|
notify: bool,
|
|
) -> Result<()> {
|
|
// TODO: Check signatures
|
|
let mut prefix = user_id.as_bytes().to_vec();
|
|
prefix.push(0xFF);
|
|
|
|
if let Some(master_key) = master_key {
|
|
let (master_key_key, _) = parse_master_key(user_id, master_key)?;
|
|
|
|
self.db
|
|
.keyid_key
|
|
.insert(&master_key_key, master_key.json().get().as_bytes());
|
|
|
|
self.db
|
|
.userid_masterkeyid
|
|
.insert(user_id.as_bytes(), &master_key_key);
|
|
}
|
|
|
|
// Self-signing key
|
|
if let Some(self_signing_key) = self_signing_key {
|
|
let mut self_signing_key_ids = self_signing_key
|
|
.deserialize()
|
|
.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
|
|
.keys
|
|
.into_values();
|
|
|
|
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
|
|
ErrorKind::InvalidParam,
|
|
"Self signing key contained no key.",
|
|
))?;
|
|
|
|
if self_signing_key_ids.next().is_some() {
|
|
return Err(Error::BadRequest(
|
|
ErrorKind::InvalidParam,
|
|
"Self signing key contained more than one key.",
|
|
));
|
|
}
|
|
|
|
let mut self_signing_key_key = prefix.clone();
|
|
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
|
|
|
|
self.db
|
|
.keyid_key
|
|
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
|
|
|
|
self.db
|
|
.userid_selfsigningkeyid
|
|
.insert(user_id.as_bytes(), &self_signing_key_key);
|
|
}
|
|
|
|
// User-signing key
|
|
if let Some(user_signing_key) = user_signing_key {
|
|
let user_signing_key_id = parse_user_signing_key(user_signing_key)?;
|
|
|
|
let user_signing_key_key = (user_id, &user_signing_key_id);
|
|
self.db
|
|
.keyid_key
|
|
.put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());
|
|
|
|
self.db
|
|
.userid_usersigningkeyid
|
|
.raw_put(user_id, user_signing_key_key);
|
|
}
|
|
|
|
if notify {
|
|
self.mark_device_key_update(user_id).await;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn sign_key(
|
|
&self,
|
|
target_id: &UserId,
|
|
key_id: &str,
|
|
signature: (String, String),
|
|
sender_id: &UserId,
|
|
) -> Result {
|
|
let key = (target_id, key_id);
|
|
|
|
let mut cross_signing_key: serde_json::Value = self
|
|
.db
|
|
.keyid_key
|
|
.qry(&key)
|
|
.await
|
|
.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
|
|
.deserialized()
|
|
.map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;
|
|
|
|
let signatures = cross_signing_key
|
|
.get_mut("signatures")
|
|
.ok_or_else(|| {
|
|
err!(Database(debug_warn!("key in keyid_key has no signatures field")))
|
|
})?
|
|
.as_object_mut()
|
|
.ok_or_else(|| {
|
|
err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
|
|
})?
|
|
.entry(sender_id.to_string())
|
|
.or_insert_with(|| serde_json::Map::new().into());
|
|
|
|
signatures
|
|
.as_object_mut()
|
|
.ok_or_else(|| {
|
|
err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
|
|
})?
|
|
.insert(signature.0, signature.1.into());
|
|
|
|
let key = (target_id, key_id);
|
|
self.db.keyid_key.put(key, Json(cross_signing_key));
|
|
|
|
self.mark_device_key_update(target_id).await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[inline]
|
|
pub fn keys_changed<'a>(
|
|
&'a self,
|
|
user_id: &'a UserId,
|
|
from: Option<u64>,
|
|
to: Option<u64>,
|
|
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
|
|
self.keys_changed_user_or_room(user_id.as_str(), from, to)
|
|
.map(|(user_id, ..)| user_id)
|
|
}
|
|
|
|
#[inline]
|
|
pub fn room_keys_changed<'a>(
|
|
&'a self,
|
|
room_id: &'a RoomId,
|
|
from: Option<u64>,
|
|
to: Option<u64>,
|
|
) -> impl Stream<Item = (OwnedUserId, u64)> + Send + 'a {
|
|
self.keys_changed_user_or_room(room_id.as_str(), from, to)
|
|
}
|
|
|
|
fn keys_changed_user_or_room<'a>(
|
|
&'a self,
|
|
user_or_room_id: &'a str,
|
|
from: Option<u64>,
|
|
to: Option<u64>,
|
|
) -> impl Stream<Item = (OwnedUserId, u64)> + Send + 'a {
|
|
type KeyVal<'a> = ((&'a str, u64), OwnedUserId);
|
|
|
|
let from = from.unwrap_or(0);
|
|
let to = to.unwrap_or(u64::MAX);
|
|
let start = (user_or_room_id, from.saturating_add(1));
|
|
self.db
|
|
.keychangeid_userid
|
|
.stream_from(&start)
|
|
.ignore_err()
|
|
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
|
|
*prefix == user_or_room_id && *count <= to
|
|
})
|
|
.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
|
|
}
|
|
|
|
/// Records a key-change entry for `user_id` at a fresh global count.
///
/// The entry is written once for every joined room where
/// `is_encrypted_room` is true, and once under the user's own ID.
///
/// # Panics
/// If the global counter cannot be advanced (`next_count` fails).
pub async fn mark_device_key_update(&self, user_id: &UserId) {
	let count = self.services.globals.next_count().unwrap();

	self.services
		.state_cache
		.rooms_joined(user_id)
		// Don't send key updates to unencrypted rooms
		.filter_map(async |room_id| {
			self.services
				.state_accessor
				.is_encrypted_room(&room_id)
				.await
				.then_some(room_id)
		})
		.ready_for_each(|room_id| {
			self.db
				.keychangeid_userid
				.put_raw((room_id, count), user_id);
		})
		.await;

	self.db
		.keychangeid_userid
		.put_raw((user_id, count), user_id);
}
|
|
|
|
pub async fn get_device_keys<'a>(
|
|
&'a self,
|
|
user_id: &'a UserId,
|
|
device_id: &DeviceId,
|
|
) -> Result<Raw<DeviceKeys>> {
|
|
let key_id = (user_id, device_id);
|
|
self.db.keyid_key.qry(&key_id).await.deserialized()
|
|
}
|
|
|
|
pub async fn get_key<F>(
|
|
&self,
|
|
key_id: &[u8],
|
|
sender_user: Option<&UserId>,
|
|
user_id: &UserId,
|
|
allowed_signatures: &F,
|
|
) -> Result<Raw<CrossSigningKey>>
|
|
where
|
|
F: Fn(&UserId) -> bool + Send + Sync,
|
|
{
|
|
let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?;
|
|
|
|
let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
|
|
let raw_value = serde_json::value::to_raw_value(&cleaned)?;
|
|
Ok(Raw::from_json(raw_value))
|
|
}
|
|
|
|
pub async fn get_master_key<F>(
|
|
&self,
|
|
sender_user: Option<&UserId>,
|
|
user_id: &UserId,
|
|
allowed_signatures: &F,
|
|
) -> Result<Raw<CrossSigningKey>>
|
|
where
|
|
F: Fn(&UserId) -> bool + Send + Sync,
|
|
{
|
|
let key_id = self.db.userid_masterkeyid.get(user_id).await?;
|
|
|
|
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
|
|
.await
|
|
}
|
|
|
|
pub async fn get_self_signing_key<F>(
|
|
&self,
|
|
sender_user: Option<&UserId>,
|
|
user_id: &UserId,
|
|
allowed_signatures: &F,
|
|
) -> Result<Raw<CrossSigningKey>>
|
|
where
|
|
F: Fn(&UserId) -> bool + Send + Sync,
|
|
{
|
|
let key_id = self.db.userid_selfsigningkeyid.get(user_id).await?;
|
|
|
|
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
|
|
.await
|
|
}
|
|
|
|
/// Looks up `user_id`'s user-signing key through `userid_usersigningkeyid`.
/// Unlike the master and self-signing getters, signatures are returned
/// unfiltered.
pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
	let user_signing_key_id = self.db.userid_usersigningkeyid.get(user_id).await?;

	self.db
		.keyid_key
		.get(&*user_signing_key_id)
		.await
		.deserialized()
}
|
|
|
|
pub async fn add_to_device_event(
|
|
&self,
|
|
sender: &UserId,
|
|
target_user_id: &UserId,
|
|
target_device_id: &DeviceId,
|
|
event_type: &str,
|
|
content: serde_json::Value,
|
|
) {
|
|
let count = self.services.globals.next_count().unwrap();
|
|
|
|
let key = (target_user_id, target_device_id, count);
|
|
self.db.todeviceid_events.put(
|
|
key,
|
|
Json(json!({
|
|
"type": event_type,
|
|
"sender": sender,
|
|
"content": content,
|
|
})),
|
|
);
|
|
}
|
|
|
|
pub fn get_to_device_events<'a>(
|
|
&'a self,
|
|
user_id: &'a UserId,
|
|
device_id: &'a DeviceId,
|
|
since: Option<u64>,
|
|
to: Option<u64>,
|
|
) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
|
|
type Key = (OwnedUserId, OwnedDeviceId, u64);
|
|
|
|
let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
|
|
|
|
self.db
|
|
.todeviceid_events
|
|
.stream_from(&from)
|
|
.ignore_err()
|
|
.ready_take_while(move |((user_id_, device_id_, count), _): &(Key, _)| {
|
|
user_id == *user_id_
|
|
&& device_id == *device_id_
|
|
&& to.is_none_or(|to| *count <= to)
|
|
})
|
|
.map(|((_, _, count), event)| (count, event))
|
|
}
|
|
|
|
pub async fn remove_to_device_events<Until>(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
until: Until,
|
|
) where
|
|
Until: Into<Option<u64>> + Send,
|
|
{
|
|
type Key = (OwnedUserId, OwnedDeviceId, u64);
|
|
|
|
let until = until.into().unwrap_or(u64::MAX);
|
|
let from = (user_id, device_id, until);
|
|
self.db
|
|
.todeviceid_events
|
|
.rev_keys_from(&from)
|
|
.ignore_err()
|
|
.ready_take_while(move |(user_id_, device_id_, _): &Key| {
|
|
user_id == *user_id_ && device_id == *device_id_
|
|
})
|
|
.ready_for_each(|key: Key| {
|
|
self.db.todeviceid_events.del(key);
|
|
})
|
|
.await;
|
|
}
|
|
|
|
/// Updates device metadata and increments the device list version.
|
|
pub async fn update_device_metadata(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
device: &Device,
|
|
) -> Result<()> {
|
|
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
|
self.update_device_metadata_no_increment(user_id, device_id, device)
|
|
.await
|
|
}
|
|
|
|
// Updates device metadata without incrementing the device list version.
|
|
// This is namely used for updating the last_seen_ip and last_seen_ts values,
|
|
// as those do not need a device list version bump due to them not being
|
|
// relevant to other consumers.
|
|
pub async fn update_device_metadata_no_increment(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
device: &Device,
|
|
) -> Result<()> {
|
|
let key = (user_id, device_id);
|
|
self.db.userdeviceid_metadata.put(key, Json(device));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn update_device_last_seen(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: Option<&DeviceId>,
|
|
ip: IpAddr,
|
|
) {
|
|
let now = MilliSecondsSinceUnixEpoch::now();
|
|
if let Some(device_id) = device_id {
|
|
if let Ok(mut device) = self.get_device_metadata(user_id, device_id).await {
|
|
device.last_seen_ip = Some(ip.to_string());
|
|
// If the last update was less than 10 seconds ago, don't update the timestamp
|
|
if let Some(prev) = device.last_seen_ts {
|
|
if now.get().saturating_sub(prev.get()) < uint!(10_000) {
|
|
return;
|
|
}
|
|
}
|
|
device.last_seen_ts = Some(now);
|
|
|
|
self.update_device_metadata_no_increment(user_id, device_id, &device)
|
|
.await
|
|
.ok();
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Get device metadata.
|
|
pub async fn get_device_metadata(
|
|
&self,
|
|
user_id: &UserId,
|
|
device_id: &DeviceId,
|
|
) -> Result<Device> {
|
|
self.db
|
|
.userdeviceid_metadata
|
|
.qry(&(user_id, device_id))
|
|
.await
|
|
.deserialized()
|
|
}
|
|
|
|
/// Returns the device list version stored for `user_id`. It is bumped by
/// `update_device_metadata` via `increment`.
pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
	let version = self.db.userid_devicelistversion.get(user_id).await;

	version.deserialized()
}
|
|
|
|
pub fn all_devices_metadata<'a>(
|
|
&'a self,
|
|
user_id: &'a UserId,
|
|
) -> impl Stream<Item = Device> + Send + 'a {
|
|
let key = (user_id, Interfix);
|
|
self.db
|
|
.userdeviceid_metadata
|
|
.stream_prefix(&key)
|
|
.ignore_err()
|
|
.map(|(_, val): (Ignore, Device)| val)
|
|
}
|
|
|
|
/// Creates a new sync filter. Returns the filter id.
|
|
pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
|
|
let filter_id = utils::random_string(4);
|
|
|
|
let key = (user_id, &filter_id);
|
|
self.db.userfilterid_filter.put(key, Json(filter));
|
|
|
|
filter_id
|
|
}
|
|
|
|
pub async fn get_filter(
|
|
&self,
|
|
user_id: &UserId,
|
|
filter_id: &str,
|
|
) -> Result<FilterDefinition> {
|
|
let key = (user_id, filter_id);
|
|
self.db.userfilterid_filter.qry(&key).await.deserialized()
|
|
}
|
|
|
|
/// Creates an OpenID token, which can be used to prove that a user has
|
|
/// access to an account (primarily for integrations)
|
|
pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
|
|
use std::num::Saturating as Sat;
|
|
|
|
let expires_in = self.services.config.openid_token_ttl;
|
|
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
|
|
|
|
let mut value = expires_at.0.to_be_bytes().to_vec();
|
|
value.extend_from_slice(user_id.as_bytes());
|
|
|
|
self.db
|
|
.openidtoken_expiresatuserid
|
|
.insert(token.as_bytes(), value.as_slice());
|
|
|
|
Ok(expires_in)
|
|
}
|
|
|
|
/// Find out which user an OpenID access token belongs to.
|
|
pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
|
|
let Ok(value) = self.db.openidtoken_expiresatuserid.get(token).await else {
|
|
return Err!(Request(Unauthorized("OpenID token is unrecognised")));
|
|
};
|
|
|
|
let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
|
|
let expires_at =
|
|
u64::from_be_bytes(expires_at_bytes.try_into().map_err(|e| {
|
|
err!(Database("expires_at in openid_userid is invalid u64. {e}"))
|
|
})?);
|
|
|
|
if expires_at < utils::millis_since_unix_epoch() {
|
|
debug_warn!("OpenID token is expired, removing");
|
|
self.db.openidtoken_expiresatuserid.remove(token.as_bytes());
|
|
|
|
return Err!(Request(Unauthorized("OpenID token is expired")));
|
|
}
|
|
|
|
let user_string = utils::string_from_bytes(user_bytes)
|
|
.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
|
|
|
|
OwnedUserId::try_from(user_string)
|
|
.map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
|
|
}
|
|
|
|
/// Creates a short-lived login token, which can be used to log in using the
|
|
/// `m.login.token` mechanism.
|
|
pub fn create_login_token(&self, user_id: &UserId, token: &str) -> u64 {
|
|
use std::num::Saturating as Sat;
|
|
|
|
let expires_in = self.services.config.login_token_ttl;
|
|
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in);
|
|
|
|
let value = (expires_at.0, user_id);
|
|
self.db.logintoken_expiresatuserid.raw_put(token, value);
|
|
|
|
expires_in
|
|
}
|
|
|
|
/// Find out which user a login token belongs to.
|
|
/// Removes the token to prevent double-use attacks.
|
|
pub async fn find_from_login_token(&self, token: &str) -> Result<OwnedUserId> {
|
|
let Ok(value) = self.db.logintoken_expiresatuserid.get(token).await else {
|
|
return Err!(Request(Forbidden("Login token is unrecognised")));
|
|
};
|
|
let (expires_at, user_id): (u64, OwnedUserId) = value.deserialized()?;
|
|
|
|
if expires_at < utils::millis_since_unix_epoch() {
|
|
trace!(%user_id, ?token, "Removing expired login token");
|
|
|
|
self.db.logintoken_expiresatuserid.remove(token);
|
|
|
|
return Err!(Request(Forbidden("Login token is expired")));
|
|
}
|
|
|
|
self.db.logintoken_expiresatuserid.remove(token);
|
|
|
|
Ok(user_id)
|
|
}
|
|
|
|
/// Gets a specific user profile key
|
|
pub async fn profile_key(
|
|
&self,
|
|
user_id: &UserId,
|
|
profile_key: &str,
|
|
) -> Result<serde_json::Value> {
|
|
let key = (user_id, profile_key);
|
|
self.db
|
|
.useridprofilekey_value
|
|
.qry(&key)
|
|
.await
|
|
.and_then(|handle| serde_json::from_slice(&handle).map_err(Into::into))
|
|
}
|
|
|
|
/// Gets all the user's profile keys and values in an iterator
|
|
pub fn all_profile_keys<'a>(
|
|
&'a self,
|
|
user_id: &'a UserId,
|
|
) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
|
|
type KeyVal<'a> = ((Ignore, String), &'a [u8]);
|
|
|
|
let prefix = (user_id, Interfix);
|
|
self.db
|
|
.useridprofilekey_value
|
|
.stream_prefix(&prefix)
|
|
.ignore_err()
|
|
.map(|((_, key), value): KeyVal<'_>| Ok((key, serde_json::from_slice(value)?)))
|
|
.ignore_err()
|
|
}
|
|
|
|
/// Sets a new profile key value, removes the key if value is None
|
|
pub fn set_profile_key(
|
|
&self,
|
|
user_id: &UserId,
|
|
profile_key: &str,
|
|
profile_key_value: Option<serde_json::Value>,
|
|
) {
|
|
let key = (user_id, profile_key);
|
|
|
|
if let Some(value) = profile_key_value {
|
|
self.db.useridprofilekey_value.put(key, Json(value));
|
|
} else {
|
|
self.db.useridprofilekey_value.del(key);
|
|
}
|
|
}
|
|
|
|
/// Clears all profile data for a user, including display name and avatar
|
|
/// url.
|
|
pub async fn clear_profile(&self, user_id: &UserId) {
|
|
self.set_displayname(user_id, None);
|
|
self.set_avatar_url(user_id, None);
|
|
self.all_profile_keys(user_id)
|
|
.ready_for_each(|(key, _)| self.set_profile_key(user_id, &key, None))
|
|
.await;
|
|
}
|
|
}
|
|
|
|
pub fn parse_master_key(
|
|
user_id: &UserId,
|
|
master_key: &Raw<CrossSigningKey>,
|
|
) -> Result<(Vec<u8>, CrossSigningKey)> {
|
|
let mut prefix = user_id.as_bytes().to_vec();
|
|
prefix.push(0xFF);
|
|
|
|
let master_key = master_key
|
|
.deserialize()
|
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
|
let mut master_key_ids = master_key.keys.values();
|
|
let master_key_id = master_key_ids
|
|
.next()
|
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
|
|
if master_key_ids.next().is_some() {
|
|
return Err(Error::BadRequest(
|
|
ErrorKind::InvalidParam,
|
|
"Master key contained more than one key.",
|
|
));
|
|
}
|
|
let mut master_key_key = prefix.clone();
|
|
master_key_key.extend_from_slice(master_key_id.as_bytes());
|
|
Ok((master_key_key, master_key))
|
|
}
|
|
|
|
pub fn parse_user_signing_key(user_signing_key: &Raw<CrossSigningKey>) -> Result<String> {
|
|
let mut user_signing_key_ids = user_signing_key
|
|
.deserialize()
|
|
.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
|
|
.keys
|
|
.into_values();
|
|
|
|
let user_signing_key_id = user_signing_key_ids
|
|
.next()
|
|
.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
|
|
|
|
if user_signing_key_ids.next().is_some() {
|
|
return Err!(Request(InvalidParam("User signing key contained more than one key.")));
|
|
}
|
|
|
|
Ok(user_signing_key_id)
|
|
}
|
|
|
|
/// Ensure that a user only sees signatures from themselves and the target user
|
|
fn clean_signatures<F>(
|
|
mut cross_signing_key: serde_json::Value,
|
|
sender_user: Option<&UserId>,
|
|
user_id: &UserId,
|
|
allowed_signatures: &F,
|
|
) -> Result<serde_json::Value>
|
|
where
|
|
F: Fn(&UserId) -> bool + Send + Sync,
|
|
{
|
|
if let Some(signatures) = cross_signing_key
|
|
.get_mut("signatures")
|
|
.and_then(|v| v.as_object_mut())
|
|
{
|
|
// Don't allocate for the full size of the current signatures, but require
|
|
// at most one resize if nothing is dropped
|
|
let new_capacity = signatures.len() / 2;
|
|
for (user, signature) in
|
|
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
|
|
{
|
|
let sid = <&UserId>::try_from(user.as_str())
|
|
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
|
|
if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
|
|
signatures.insert(user, signature);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(cross_signing_key)
|
|
}
|
|
|
|
//TODO: this is an ABA
|
|
fn increment(db: &Arc<Map>, key: &[u8]) {
|
|
let old = db.get_blocking(key);
|
|
let new = utils::increment(old.ok().as_deref());
|
|
db.insert(key, new);
|
|
}
|