Compare commits

...

9 Commits

Author SHA1 Message Date
timedout 1ad0bd5d0d fix: Don't be so aggressive when validating policy server signatures 2026-05-26 08:20:06 -07:00
Jacob Taylor 8bea04b1ed feat: Merge ginger/oauth 2026-05-26 08:14:45 -07:00
Jacob Taylor 2e08ffe646 feat: Merge ginger/kill-sync-tokens 2026-05-26 08:14:16 -07:00
Jacob Taylor 3f83989c83 fix: Pre-Commit Lint Compliance Maneuver 2026-05-26 08:14:16 -07:00
Jacob Taylor e898d147ce feat: Bump one cache a bit 2026-05-26 08:14:16 -07:00
Jacob Taylor 4379d662b8 upgrade some logs to info 2026-05-26 08:14:16 -07:00
Jacob Taylor 5958c6c2dd exponential backoff is now just bees. did you want bees? no? well you have them now. congrats 2026-05-26 08:14:16 -07:00
Jacob Taylor 0b4c91ca2b enable converged 6g at the edge in continuwuity
sender_workers scaling. this time, with feeling!
2026-05-26 08:14:16 -07:00
Jacob Taylor b92f693717 bump the number of allowed immutable memtables by 1, to allow for greater flood protection
this should probably not be applied if you have rocksdb_atomic_flush = false (the default)
2026-05-26 08:14:16 -07:00
132 changed files with 6483 additions and 1631 deletions
Generated
+67
View File
@@ -1088,6 +1088,7 @@ dependencies = [
"serde",
"serde-saphyr",
"serde_json",
"serde_urlencoded",
"sha2 0.11.0",
"termimad",
"tokio",
@@ -1107,18 +1108,29 @@ dependencies = [
"axum",
"axum-extra",
"base64 0.22.1",
"conduwuit_api",
"conduwuit_build_metadata",
"conduwuit_core",
"conduwuit_database",
"conduwuit_service",
"form_urlencoded",
"futures",
"lettre",
"memory-serve",
"rand 0.10.1",
"recaptcha-verify",
"reqwest 0.12.28",
"ruma",
"serde",
"serde_json",
"serde_urlencoded",
"thiserror",
"tower-http",
"tower-sec-fetch",
"tower-sessions",
"tower-sessions-core",
"tracing",
"url",
"validator",
]
@@ -1526,6 +1538,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c"
dependencies = [
"powerfmt",
"serde_core",
]
[[package]]
@@ -5543,6 +5556,22 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-cookies"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36"
dependencies = [
"axum-core",
"cookie",
"futures-util",
"http",
"parking_lot",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.6.11"
@@ -5591,6 +5620,44 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tower-sessions"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "518dca34b74a17cadfcee06e616a09d2bd0c3984eff1769e1e76d58df978fc78"
dependencies = [
"async-trait",
"http",
"time",
"tokio",
"tower-cookies",
"tower-layer",
"tower-service",
"tower-sessions-core",
"tracing",
]
[[package]]
name = "tower-sessions-core"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "568531ec3dfcf3ffe493de1958ae5662a0284ac5d767476ecdb6a34ff8c6b06c"
dependencies = [
"async-trait",
"axum-core",
"base64 0.22.1",
"futures",
"http",
"parking_lot",
"rand 0.9.4",
"serde",
"serde_json",
"thiserror",
"time",
"tokio",
"tracing",
]
[[package]]
name = "tracing"
version = "0.1.44"
+3
View File
@@ -559,6 +559,9 @@ features = ["std"]
[workspace.dependencies.nonzero_ext]
version = "0.3.0"
[workspace.dependencies.serde_urlencoded]
version = "0.7.1"
#
# Patches
#
+1
View File
@@ -0,0 +1 @@
Users may now be forbidden from deactivating their own accounts with the new `allow_deactivation` config option. Contributed by @ginger.
+1
View File
@@ -0,0 +1 @@
Added support for Matrix 1.16's `state_after` feature, allowing clients which understand it to sync room state changes more reliably. Contributed by @ginger.
+1
View File
@@ -0,0 +1 @@
Added support for authenticating clients using the new OAuth 2.0 login API. Contributed by @ginger.
+1
View File
@@ -0,0 +1 @@
Adjusted legacy sync logic to no longer use the `roomsynctoken_shortstatehash` database column. Once this change has been confirmed to be stable and reliable, a future update will remove it entirely, significantly decreasing database sizes. Contributed by @ginger.
+39 -13
View File
@@ -521,17 +521,15 @@
#
#recaptcha_private_site_key =
# Policy documents, such as terms and conditions or a privacy policy,
# which users must agree to when registering an account.
# Controls whether users are allowed to deactivate their own accounts
# through the account management panel or their Matrix clients. Server
# admins can always deactivate users using the relevant admin commands.
#
# Example:
# ```ignore
# [global.registration_terms.privacy_policy]
# en = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
# es = { name = "Política de Privacidad", url = "https://homeserver.example/es/privacy_policy.html" }
# ```
# Note that, in some jurisdictions, you may be legally required to honor
# users who request to deactivate their accounts if you set this option
# to `false`.
#
#registration_terms = {}
#allow_deactivation = true
# Controls whether encrypted rooms and events are allowed.
#
@@ -1795,11 +1793,9 @@
#stream_amplification = 1024
# Number of sender task workers; determines sender parallelism. Default is
# '0' which means the value is determined internally, likely matching the
# number of tokio worker-threads or number of cores, etc. Override by
# setting a non-zero value.
# core count. Override by setting a different value.
#
#sender_workers = 0
#sender_workers = core count
# Enables listener sockets; can be set to false to disable listening. This
# option is intended for developer/diagnostic purposes only.
@@ -1987,3 +1983,33 @@
# `require_email_for_registration`.
#
#require_email_for_token_registration = false
#[global.registration_terms]
# The language code to provide to clients along with the policy documents.
#
#language = "en"
# Policy documents, such as terms and conditions or a privacy policy,
# which users must agree to when registering an account.
#
# Example:
# ```ignore
# [global.registration_terms.documents]
# privacy_policy = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
# ```
#
#documents = {}
#[global.oauth]
# The compatibility mode to use for OAuth.
#
# - "disabled": OAuth will be unavailable. Users will only be able to log
# in using legacy authentication.
# - "hybrid": OAuth and legacy authentication will both be available. Some
# clients may only use one or the other.
# - "exclusive": Only OAuth will be available. Clients which require
# legacy authentication will be unable to log in.
#
#compatibility_mode = "hybrid"
+1 -1
View File
@@ -16,7 +16,7 @@ use crate::{
};
#[derive(Debug, Parser)]
#[command(name = conduwuit_core::name(), version = conduwuit_core::version())]
#[command(name = conduwuit_core::BRANDING, version = conduwuit_core::version())]
pub enum AdminCommand {
#[command(subcommand)]
/// Commands for managing appservices
+25 -2
View File
@@ -30,14 +30,37 @@ pub(super) async fn issue_token(&self, expires: super::TokenExpires) -> Result {
.issue_token(self.sender_or_service_user().into(), expires);
self.write_str(&format!(
"New registration token issued: `{token}`. {}.",
"New registration token issued: `{token}` . {}.",
if let Some(expires) = info.expires {
format!("{expires}")
} else {
"Never expires".to_owned()
}
))
.await
.await?;
if self
.services
.config
.oauth
.compatibility_mode
.oauth_available()
{
self.write_str(&format!(
"\nInvite link using this token: {}",
self.services
.config
.get_client_domain()
.join(&format!(
"{}/account/register/?flow=trusted&token={token}",
conduwuit::ROUTE_PREFIX
))
.unwrap()
))
.await?;
}
Ok(())
}
#[admin_command]
+13 -149
View File
@@ -1,13 +1,10 @@
use std::{
collections::{BTreeMap, HashSet},
fmt::Write as _,
};
use std::collections::{BTreeMap, HashSet};
use api::client::{
full_user_deactivate, leave_room, recreate_push_rules_and_return, remote_leave_room,
};
use conduwuit::{
Err, Result, debug_warn, error, info,
Err, Result, debug_warn, info,
matrix::{Event, pdu::PartialPdu},
utils::{self, ReadyExt},
warn,
@@ -53,130 +50,22 @@ pub(super) async fn list_users(&self) -> Result {
#[admin_command]
pub(super) async fn create_user(&self, username: String, password: Option<String>) -> Result {
// Validate user id
let user_id = parse_local_user_id(self.services, &username)?;
if let Err(e) = user_id.validate_strict() {
if self.services.config.emergency_password.is_none() {
return Err!("Username {user_id} contains disallowed characters or spaces: {e}");
}
}
if self.services.users.exists(&user_id).await {
return Err!("User {user_id} already exists");
}
let password = password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH));
// Create user
self.services
.users
.create(&user_id, Some(HashedPassword::new(&password)?))
.await?;
// Default to pretty displayname
let mut displayname = user_id.localpart().to_owned();
// If `new_user_displayname_suffix` is set, registration will push whatever
// content is set to the user's display name with a space before it
if !self
let user_id = self
.services
.server
.config
.new_user_displayname_suffix
.is_empty()
{
write!(displayname, " {}", self.services.server.config.new_user_displayname_suffix)?;
}
.users
.determine_registration_user_id(Some(username), None, None)
.await?;
let password = HashedPassword::new(
&password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)),
)?;
self.services
.users
.set_displayname(&user_id, Some(displayname));
.create_local_account(&user_id, password, None)
.await;
// Initial account data
self.services
.account_data
.update(
None,
&user_id,
ruma::events::GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent::new(
ruma::events::push_rules::PushRulesEventContent::new(
ruma::push::Ruleset::server_default(&user_id),
),
))
.unwrap(),
)
.await?;
if !self.services.server.config.auto_join_rooms.is_empty() {
for room in &self.services.server.config.auto_join_rooms {
let Ok(room_id) = self.services.rooms.alias.resolve(room).await else {
error!(
%user_id,
"Failed to resolve room alias to room ID when attempting to auto join {room}, skipping"
);
continue;
};
if !self
.services
.rooms
.state_cache
.server_in_room(self.services.globals.server_name(), &room_id)
.await
{
warn!(
"Skipping room {room} to automatically join as we have never joined before."
);
continue;
}
if let Some(room_server_name) = room.server_name() {
match self
.services
.rooms
.membership
.join_room(
&user_id,
&room_id,
Some("Automatically joining this room upon registration".to_owned()),
&[
self.services.globals.server_name().to_owned(),
room_server_name.to_owned(),
],
)
.await
{
| Ok(_response) => {
info!("Automatically joined room {room} for user {user_id}");
},
| Err(e) => {
// don't return this error so we don't fail registrations
error!(
"Failed to automatically join room {room} for user {user_id}: {e}"
);
self.services
.admin
.send_text(&format!(
"Failed to automatically join room {room} for user {user_id}: \
{e}"
))
.await;
},
}
}
}
}
// we dont add a device since we're not the user, just the creator
// Make the first user to register an administrator and disable first-run mode.
self.services.firstrun.empower_first_user(&user_id).await?;
self.write_str(&format!("Created user with user_id: {user_id} and password: `{password}`"))
.await
self.write_str(&format!("Created user {user_id}")).await
}
#[admin_command]
@@ -302,31 +191,6 @@ pub(super) async fn reset_password(
Ok(())
}
#[admin_command]
pub(super) async fn issue_password_reset_link(&self, username: String) -> Result {
use conduwuit_service::password_reset::{PASSWORD_RESET_PATH, RESET_TOKEN_QUERY_PARAM};
self.bail_restricted()?;
let mut reset_url = self
.services
.config
.get_client_domain()
.join(PASSWORD_RESET_PATH)
.unwrap();
let user_id = parse_local_user_id(self.services, &username)?;
let token = self.services.password_reset.issue_token(user_id).await?;
reset_url
.query_pairs_mut()
.append_pair(RESET_TOKEN_QUERY_PARAM, &token.token);
self.write_str(&format!("Password reset link issued for {username}: {reset_url}"))
.await?;
Ok(())
}
#[admin_command]
pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> Result {
if self.body.len() < 2
-6
View File
@@ -29,12 +29,6 @@ pub enum UserCommand {
password: Option<String>,
},
/// Issue a self-service password reset link for a user.
IssuePasswordResetLink {
/// Username of the user who may use the link
username: String,
},
/// Get a user's associated email address.
GetEmail {
user_id: String,
+10 -3
View File
@@ -24,7 +24,7 @@ use ruma::{
power_levels::RoomPowerLevelsEventContent,
},
};
use service::{mailer::messages, uiaa::Identity, users::HashedPassword};
use service::{mailer::messages, uiaa::UiaaInitiator, users::HashedPassword};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::Ruma;
@@ -121,7 +121,7 @@ pub(crate) async fn change_password_route(
&body.auth,
vec![AuthFlow::new(vec![AuthType::Password])],
Box::default(),
Some(Identity::from_user_id(user_id)),
Some(UiaaInitiator::new(user_id, body.sender_device())),
)
.await?
} else {
@@ -270,10 +270,17 @@ pub(crate) async fn deactivate_route(
.as_ref()
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
if !services.config.allow_deactivation {
return Err!(Request(Unauthorized(
"You may not deactivate your own account. Contact your server's administrator for \
assistance."
)));
}
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, body.sender_device(), None)
.await?;
// Remove profile pictures and display name
+56 -294
View File
@@ -1,17 +1,15 @@
use std::{collections::HashMap, fmt::Write};
use std::collections::HashMap;
use axum::extract::State;
use axum_client_ip::ClientIp;
use conduwuit::{
Err, Result, debug_info, error, info,
Err, Result, debug_info, info,
utils::{self},
warn,
};
use conduwuit_service::Services;
use futures::{FutureExt, StreamExt};
use futures::StreamExt;
use lettre::{Address, message::Mailbox};
use ruma::{
OwnedUserId, UserId,
api::client::{
account::{
register::{self, LoginType, RegistrationKind},
@@ -20,11 +18,6 @@ use ruma::{
uiaa::{AuthFlow, AuthType},
},
assign,
events::{
GlobalAccountDataEventType, push_rules::PushRulesEvent,
room::message::RoomMessageEventContent,
},
push,
};
use serde_json::value::RawValue;
use service::{mailer::messages, users::HashedPassword};
@@ -32,8 +25,6 @@ use service::{mailer::messages, users::HashedPassword};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::Ruma;
const RANDOM_USER_ID_LENGTH: usize = 10;
/// # `POST /_matrix/client/v3/register`
///
/// Register an account on this homeserver.
@@ -52,8 +43,6 @@ pub(crate) async fn register_route(
return Err!(Request(GuestAccessForbidden("Guests may not register on this server.")));
}
let emergency_mode_enabled = services.config.emergency_password.is_some();
// Allow registration if it's enabled in the config file or if this is the first
// run (so the first user account can be created)
let allow_registration =
@@ -71,101 +60,59 @@ pub(crate) async fn register_route(
)));
}
let identity = if body.appservice_info.is_some() {
// Appservices can skip auth
None
let user_id = if body.body.login_type == Some(LoginType::ApplicationService) {
let Some(appservice_info) = &body.appservice_info else {
return Err!(Request(Forbidden(
"Only appservices can use the appservice login type."
)));
};
let user_id = services
.users
.determine_registration_user_id(body.username.clone(), None, Some(appservice_info))
.await?;
services.users.create(&user_id, None).await?;
user_id
} else {
// Perform UIAA to determine the user's identity
let (flows, params) = create_registration_uiaa_session(&services).await?;
Some(
services
.uiaa
.authenticate(&body.auth, flows, params, None)
.await?,
)
};
// If the user didn't supply a username but did supply an email, use
// the email's user as their initial localpart to avoid falling back to
// a randomly generated localpart
let supplied_username = body.username.clone().or_else(|| {
if let Some(identity) = &identity
&& let Some(email) = &identity.email
{
Some(email.user().to_owned())
} else {
None
}
});
let user_id =
determine_registration_user_id(&services, supplied_username, emergency_mode_enabled)
let identity = services
.uiaa
.authenticate(&body.auth, flows, params, None)
.await?;
if body.body.login_type == Some(LoginType::ApplicationService) {
// For appservice logins, make sure that the user ID is in the appservice's
// namespace
let password = if let Some(password) = &body.password {
HashedPassword::new(password)?
} else {
return Err!(Request(InvalidParam("A password must be provided.")));
};
match body.appservice_info {
| Some(ref info) =>
if !info.is_user_match(&user_id) && !emergency_mode_enabled {
return Err!(Request(Exclusive(
"Username is not in an appservice namespace."
)));
},
| _ => {
return Err!(Request(MissingToken("Missing appservice token.")));
},
}
} else if services.appservice.is_exclusive_user_id(&user_id).await && !emergency_mode_enabled
{
// For non-appservice logins, ban user IDs which are in an appservice's
// namespace (unless emergency mode is enabled)
return Err!(Request(Exclusive("Username is reserved by an appservice.")));
}
let user_id = services
.users
.determine_registration_user_id(body.username.clone(), identity.email.as_ref(), None)
.await?;
let password = if body.appservice_info.is_some() {
None
} else if let Some(password) = body.password.as_deref() {
Some(HashedPassword::new(password)?)
} else {
return Err!(Request(InvalidParam("A password must be provided")));
services
.users
.create_local_account(&user_id, password, identity.email)
.await;
user_id
};
// Create user
services.users.create(&user_id, password).await?;
// Set an initial display name
let mut displayname = user_id.localpart().to_owned();
// Apply the new user displayname suffix, if it's set
if !services.globals.new_user_displayname_suffix().is_empty()
&& body.appservice_info.is_none()
{
write!(displayname, " {}", services.server.config.new_user_displayname_suffix)?;
}
services
.users
.set_displayname(&user_id, Some(displayname.clone()));
// Initial account data
services
.account_data
.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(PushRulesEvent::new(
push::Ruleset::server_default(&user_id).into(),
))
.expect("should be able to serialize push rules"),
)
.await?;
// Generate new device id if the user didn't specify one
let (token, device) = if !body.inhibit_login {
// If UIAA is disabled, we can't create a device. In that case only appservices
// can reach this point in the first place, so we return an error for them.
if !services.config.oauth.compatibility_mode.uiaa_available() {
return Err!(Request(AppserviceLoginUnsupported(
"User-interactive appservice registration is not available on this server."
)));
}
// Generate new device id if the user didn't specify one
let device_id = body
.device_id
.clone()
@@ -181,6 +128,7 @@ pub(crate) async fn register_route(
&user_id,
&device_id,
&new_token,
None,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
@@ -191,118 +139,7 @@ pub(crate) async fn register_route(
(None, None)
};
debug_info!(%user_id, ?device, "User account was created");
// If the user registered with an email, associate it with their account.
if let Some(identity) = identity
&& let Some(email) = identity.email
{
// This may fail if the email is already in use, but we already check for that
// in `/requestToken`, so ignoring the error is acceptable here in the rare case
// that an email is sniped by another user between the `/requestToken` request
// and the `/register` request.
let _ = services
.threepid
.associate_localpart_email(user_id.localpart(), &email)
.await;
}
let device_display_name = body.initial_device_display_name.as_deref().unwrap_or("");
if body.appservice_info.is_none() {
if !device_display_name.is_empty() {
let notice = format!(
"New user \"{user_id}\" registered on this server from IP {client} and device \
display name \"{device_display_name}\""
);
info!("{notice}");
if services.server.config.admin_room_notices {
services.admin.notice(&notice).await;
}
} else {
let notice = format!("New user \"{user_id}\" registered on this server.");
info!("{notice}");
if services.server.config.admin_room_notices {
services.admin.notice(&notice).await;
}
}
}
// Make the first user to register an administrator and disable first-run mode.
let was_first_user = services.firstrun.empower_first_user(&user_id).await?;
// If the registering user was not the first and we're suspending users on
// register, suspend them.
if !was_first_user && services.config.suspend_on_register {
// Note that we can still do auto joins for suspended users
services
.users
.suspend_account(&user_id, &services.globals.server_user)
.await;
// And send an @room notice to the admin room, to prompt admins to review the
// new user and ideally unsuspend them if deemed appropriate.
if services.server.config.admin_room_notices {
services
.admin
.send_loud_message(RoomMessageEventContent::text_plain(format!(
"User {user_id} has been suspended as they are not the first user on this \
server. Please review and unsuspend them if appropriate."
)))
.await
.ok();
}
}
if body.appservice_info.is_none() && !services.server.config.auto_join_rooms.is_empty() {
for room in &services.server.config.auto_join_rooms {
let Ok(room_id) = services.rooms.alias.resolve(room).await else {
error!(
"Failed to resolve room alias to room ID when attempting to auto join \
{room}, skipping"
);
continue;
};
if !services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), &room_id)
.await
{
warn!(
"Skipping room {room} to automatically join as we have never joined before."
);
continue;
}
if let Some(room_server_name) = room.server_name() {
match services
.rooms
.membership
.join_room(
&user_id,
&room_id,
Some("Automatically joining this room upon registration".to_owned()),
&[services.globals.server_name().to_owned(), room_server_name.to_owned()],
)
.boxed()
.await
{
| Err(e) => {
// don't return this error so we don't fail registrations
error!(
"Failed to automatically join room {room} for user {user_id}: {e}"
);
},
| _ => {
info!("Automatically joined room {room} for user {user_id}");
},
}
}
}
}
debug_info!(%user_id, ?device, "New account created via legacy registration");
Ok(assign!(register::v3::Response::new(user_id), {
access_token: token,
@@ -374,21 +211,21 @@ async fn create_registration_uiaa_session(
// Require all users to agree to the terms and conditions, if configured
let terms = &services.config.registration_terms;
if !terms.is_empty() {
let mut terms =
serde_json::to_value(terms.clone()).expect("failed to serialize terms");
if !terms.documents.is_empty() {
let mut terms_map = HashMap::new();
// Insert a dummy `version` field
for (_, documents) in terms.as_object_mut().unwrap() {
let documents = documents.as_object_mut().unwrap();
documents.insert("version".to_owned(), "latest".into());
for (id, document) in &terms.documents {
terms_map.insert(id.to_owned(), serde_json::json!({
terms.language.clone(): serde_json::to_value(document).expect("should be able to serialize document")
}));
}
terms_map.insert("version".to_owned(), "latest".into());
params.insert(
AuthType::Terms.as_str().to_owned(),
serde_json::json!({
"policies": terms,
"policies": terms_map,
}),
);
@@ -421,81 +258,6 @@ async fn create_registration_uiaa_session(
Ok((flows, params))
}
async fn determine_registration_user_id(
services: &Services,
supplied_username: Option<String>,
emergency_mode_enabled: bool,
) -> Result<OwnedUserId> {
if let Some(supplied_username) = supplied_username {
// The user gets to pick their username. Do some validation to make sure it's
// acceptable.
// Don't allow registration with forbidden usernames.
if services
.globals
.forbidden_usernames()
.is_match(&supplied_username)
&& !emergency_mode_enabled
{
return Err!(Request(Forbidden("Username is forbidden")));
}
// Create and validate the user ID
let user_id = match UserId::parse_with_server_name(
&supplied_username,
services.globals.server_name(),
) {
| Ok(user_id) => {
if let Err(e) = user_id.validate_strict() {
// Unless we are in emergency mode, we should follow synapse's behaviour on
// not allowing things like spaces and UTF-8 characters in usernames
if !emergency_mode_enabled {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} contains disallowed characters or \
spaces: {e}"
))));
}
}
// Don't allow registration with user IDs that aren't local
if !services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {supplied_username} is not local to this server"
)));
}
user_id
},
| Err(e) => {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} is not valid: {e}"
))));
},
};
if services.users.exists(&user_id).await {
return Err!(Request(UserInUse("User ID is not available.")));
}
Ok(user_id)
} else {
// The user didn't specify a username. Generate a username for
// them.
loop {
let user_id = UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
services.globals.server_name(),
)
.unwrap();
if !services.users.exists(&user_id).await {
break Ok(user_id);
}
}
}
}
/// # `POST /_matrix/client/v3/register/email/requestToken`
///
/// Requests a validation email for the purpose of registering a new account.
+5 -4
View File
@@ -11,7 +11,7 @@ use ruma::{
},
thirdparty::{Medium, ThirdPartyIdentifierInit},
};
use service::{mailer::messages, uiaa::Identity};
use service::mailer::messages;
use crate::Ruma;
@@ -116,14 +116,15 @@ pub(crate) async fn add_3pid_route(
// Require password auth to add an email
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, body.sender_device(), None)
.await?;
let email = services
.threepid
.consume_valid_session(&body.sid, &body.client_secret)
.get_valid_session(&body.sid, &body.client_secret)
.await
.map_err(|message| err!(Request(ThreepidAuthFailed("{message}"))))?;
.map_err(|message| err!(Request(ThreepidAuthFailed("{message}"))))?
.consume();
services
.threepid
+3 -3
View File
@@ -8,7 +8,6 @@ use ruma::{
self, delete_device, delete_devices, get_device, get_devices, update_device,
},
};
use service::uiaa::Identity;
use crate::{Ruma, client::DEVICE_ID_LENGTH};
@@ -95,6 +94,7 @@ pub(crate) async fn update_device_route(
&device_id,
&appservice.registration.as_token,
None,
None,
Some(client.to_string()),
)
.await?;
@@ -126,7 +126,7 @@ pub(crate) async fn delete_device_route(
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, body.sender_device(), None)
.await?;
}
@@ -162,7 +162,7 @@ pub(crate) async fn delete_devices_route(
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, body.sender_device(), None)
.await?;
}
+7 -2
View File
@@ -26,7 +26,7 @@ use ruma::{
serde::Raw,
};
use serde_json::json;
use service::uiaa::Identity;
use service::oauth::OAuthTicket;
use crate::Ruma;
@@ -204,7 +204,12 @@ pub(crate) async fn upload_signing_keys_route(
{
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(
&body.auth,
sender_user,
body.sender_device(),
Some(OAuthTicket::CrossSigningReset),
)
.await?;
}
+3
View File
@@ -16,6 +16,7 @@ pub(super) mod media_legacy;
pub(super) mod membership;
pub(super) mod message;
pub(super) mod mutual_rooms;
pub(super) mod oauth;
pub(super) mod openid;
pub(super) mod presence;
pub(super) mod profile;
@@ -61,6 +62,7 @@ pub(super) use membership::*;
pub use membership::{leave_all_rooms, leave_room, remote_leave_room};
pub(super) use message::*;
pub(super) use mutual_rooms::*;
pub(super) use oauth::*;
pub(super) use openid::*;
pub(super) use presence::*;
pub(super) use profile::*;
@@ -73,6 +75,7 @@ pub(super) use report::*;
pub(super) use room::*;
pub(super) use search::*;
pub(super) use send::*;
pub use session::handle_login;
pub(super) use session::*;
pub(super) use space::*;
pub(super) use state::*;
+56
View File
@@ -0,0 +1,56 @@
use axum::{
Json, Router,
extract::{Request, State},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::method_routing::{get, post},
};
use const_str::concat;
use http::StatusCode;
use serde_json::json;
pub(crate) use server_metadata::*;
mod register_client;
mod server_metadata;
mod token;
const BASE_PATH: &str = concat!(conduwuit_core::ROUTE_PREFIX, "/oauth2/");
const AUTH_CODE_PATH: &str = "grant/authorization_code";
const JWKS_URI_PATH: &str = "client/keys.json";
const CLIENT_REGISTER_PATH: &str = "client/register";
const TOKEN_REVOKE_PATH: &str = "client/revoke";
const TOKEN_PATH: &str = "grant/token";
const ACCOUNT_MANAGEMENT_PATH: &str = concat!(conduwuit_core::ROUTE_PREFIX, "/account/deeplink");
pub(crate) fn router(state: crate::State) -> Router<crate::State> {
Router::new()
.nest(BASE_PATH, oauth_router())
.route(
"/.well-known/openid-configuration",
get(
// TODO(unspecced): used by old versions of the matrix-js-sdk
async |State(services): State<crate::State>| {
Json(authorization_server_metadata(&services).await)
},
),
)
.layer(middleware::from_fn_with_state(
state,
async |State(state): State<crate::State>, request: Request, next: Next| -> Response {
if state.config.oauth.compatibility_mode.oauth_available() {
next.run(request).await
} else {
(StatusCode::NOT_FOUND, "OAuth is unavailable on this server").into_response()
}
},
))
}
fn oauth_router() -> Router<crate::State> {
Router::new()
.route(concat!("/", CLIENT_REGISTER_PATH), post(register_client::register_client_route))
// TODO(unspecced): used by old versions of the matrix-js-sdk
.route(concat!("/", JWKS_URI_PATH), get(async || Json(json!({"keys": []}))))
.route(concat!("/", TOKEN_PATH), post(token::token_route))
.route(concat!("/", TOKEN_REVOKE_PATH), post(token::revoke_token_route))
}
+28
View File
@@ -0,0 +1,28 @@
use axum::{
Json,
extract::State,
response::{IntoResponse, Response},
};
use http::StatusCode;
use serde::Serialize;
use service::oauth::client_metadata::ClientMetadata;
#[derive(Serialize)]
struct RegisteredClient {
client_id: String,
#[serde(flatten)]
metadata: ClientMetadata,
}
pub(crate) async fn register_client_route(
State(services): State<crate::State>,
Json(metadata): Json<ClientMetadata>,
) -> Result<Response, Response> {
let client_id = services
.oauth
.register_client(&metadata)
.await
.map_err(|err| (StatusCode::BAD_REQUEST, err.to_owned()).into_response())?;
Ok(Json(RegisteredClient { client_id, metadata }).into_response())
}
+62
View File
@@ -0,0 +1,62 @@
use axum::extract::State;
use conduwuit::{Err, Result};
use ruma::{
api::client::discovery::get_authorization_server_metadata::{
self, v1::AccountManagementAction,
},
serde::Raw,
};
use serde_json::{Value, json};
use service::Services;
use crate::{
Ruma,
client::oauth::{
ACCOUNT_MANAGEMENT_PATH, AUTH_CODE_PATH, CLIENT_REGISTER_PATH, JWKS_URI_PATH, TOKEN_PATH,
TOKEN_REVOKE_PATH,
},
};
pub(crate) async fn get_authorization_server_metadata_route(
State(services): State<crate::State>,
_body: Ruma<get_authorization_server_metadata::v1::Request>,
) -> Result<get_authorization_server_metadata::v1::Response> {
if !services.config.oauth.compatibility_mode.oauth_available() {
return Err!(Request(Unrecognized("OAuth is unavailable on this server")));
}
let metadata = Raw::new(&authorization_server_metadata(&services).await).unwrap();
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
}
pub(crate) async fn authorization_server_metadata(services: &Services) -> Value {
let endpoint_base = services
.config
.get_client_domain()
.join(super::BASE_PATH)
.unwrap();
json!({
"account_management_uri": endpoint_base.join(ACCOUNT_MANAGEMENT_PATH).unwrap(),
"account_management_actions_supported": [
AccountManagementAction::AccountDeactivate,
AccountManagementAction::CrossSigningReset,
AccountManagementAction::DeviceDelete,
AccountManagementAction::DeviceView,
AccountManagementAction::DevicesList,
AccountManagementAction::Profile,
],
"authorization_endpoint": endpoint_base.join(AUTH_CODE_PATH).unwrap(),
"code_challenge_methods_supported": ["S256"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"issuer": services.config.get_client_domain(),
"jwks_uri": endpoint_base.join(JWKS_URI_PATH).unwrap(),
"prompt_values_supported": ["create"],
"registration_endpoint": endpoint_base.join(CLIENT_REGISTER_PATH).unwrap(),
"response_modes_supported": ["query", "fragment"],
"response_types_supported": ["code"],
"revocation_endpoint": endpoint_base.join(TOKEN_REVOKE_PATH).unwrap(),
"token_endpoint": endpoint_base.join(TOKEN_PATH).unwrap(),
})
}
+23
View File
@@ -0,0 +1,23 @@
use axum::{Form, Json, extract::State, response::IntoResponse};
use http::StatusCode;
use service::oauth::grant::{RevokeTokenRequest, TokenRequest};
pub(crate) async fn token_route(
State(services): State<crate::State>,
Form(request): Form<TokenRequest>,
) -> impl IntoResponse {
match services.oauth.issue_token(request).await {
| Ok(response) => Ok(Json(response)),
| Err(err) => Err((StatusCode::BAD_REQUEST, err.message())),
}
}
pub(crate) async fn revoke_token_route(
State(services): State<crate::State>,
Form(request): Form<RevokeTokenRequest>,
) -> impl IntoResponse {
match services.oauth.revoke_token(request.token).await {
| Ok(()) => Ok(StatusCode::OK),
| Err(err) => Err((StatusCode::BAD_REQUEST, err.message())),
}
}
+31 -20
View File
@@ -21,7 +21,7 @@ use ruma::{
},
login::{
self,
v3::{DiscoveryInfo, HomeserverInfo},
v3::{DiscoveryInfo, HomeserverInfo, LoginInfo},
},
logout, logout_all,
},
@@ -29,7 +29,6 @@ use ruma::{
},
assign,
};
use service::uiaa::Identity;
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::Ruma;
@@ -44,6 +43,12 @@ pub(crate) async fn get_login_types_route(
ClientIp(client): ClientIp,
_body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> {
if !services.config.oauth.compatibility_mode.uiaa_available() {
return Err!(Request(Unrecognized(
"User-interactive authentication is not available on this server."
)));
}
Ok(get_login_types::v3::Response::new(vec![
get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
@@ -53,7 +58,7 @@ pub(crate) async fn get_login_types_route(
]))
}
pub(crate) async fn handle_login(
pub async fn handle_login(
services: &Services,
identifier: Option<&UserIdentifier>,
password: &str,
@@ -87,10 +92,6 @@ pub(crate) async fn handle_login(
return Err!(Request(InvalidParam("User ID does not belong to this homeserver")));
}
if services.users.is_locked(&user_id).await? {
return Err!(Request(UserLocked("This account has been locked.")));
}
if services.users.is_login_disabled(&user_id).await {
warn!(%user_id, "user attempted to log in with a login-disabled account");
return Err!(Request(Forbidden("This account is not permitted to log in.")));
@@ -119,19 +120,29 @@ pub(crate) async fn login_route(
ClientIp(client): ClientIp,
body: Ruma<login::v3::Request>,
) -> Result<login::v3::Response> {
if !services.config.oauth.compatibility_mode.uiaa_available() {
return match body.login_info {
| LoginInfo::ApplicationService(_) => {
Err!(Request(AppserviceLoginUnsupported(
"User-interactive appservice login is not available on this server."
)))
},
| _ => {
Err!(Request(Unrecognized(
"User-interactive authentication is not available on this server."
)))
},
};
}
let emergency_mode_enabled = services.config.emergency_password.is_some();
// Validate login method
// TODO: Other login methods
let user_id = match &body.login_info {
#[allow(deprecated)]
| login::v3::LoginInfo::Password(login::v3::Password {
identifier,
password,
user,
..
}) => handle_login(&services, identifier.as_ref(), password, user.as_ref()).await?,
| login::v3::LoginInfo::Token(login::v3::Token { token, .. }) => {
| LoginInfo::Password(login::v3::Password { identifier, password, user, .. }) =>
handle_login(&services, identifier.as_ref(), password, user.as_ref()).await?,
| LoginInfo::Token(login::v3::Token { token, .. }) => {
debug!("Got token login type");
if !services.server.config.login_via_existing_session {
return Err!(Request(Unknown("Token login is not enabled.")));
@@ -139,7 +150,7 @@ pub(crate) async fn login_route(
services.users.find_from_login_token(token).await?
},
#[allow(deprecated)]
| login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
| LoginInfo::ApplicationService(login::v3::ApplicationService {
identifier,
user,
..
@@ -173,7 +184,6 @@ pub(crate) async fn login_route(
user_id
},
| _ => {
debug!("/login json_body: {:?}", &body.json_body);
return Err!(Request(Unknown(
debug_warn!(?body.login_info, "Invalid or unsupported login type")
)));
@@ -203,7 +213,7 @@ pub(crate) async fn login_route(
if device_exists {
services
.users
.set_token(&user_id, &device_id, &token)
.set_token(&user_id, &device_id, &token, None)
.await?;
} else {
services
@@ -212,6 +222,7 @@ pub(crate) async fn login_route(
&user_id,
&device_id,
&token,
None,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
@@ -250,7 +261,7 @@ pub(crate) async fn login_token_route(
ClientIp(client): ClientIp,
body: Ruma<get_login_token::v1::Request>,
) -> Result<get_login_token::v1::Response> {
if !services.server.config.login_via_existing_session {
if !services.config.login_via_existing_session {
return Err!(Request(Forbidden("Login via an existing session is not enabled")));
}
@@ -259,7 +270,7 @@ pub(crate) async fn login_token_route(
// Prompt the user to confirm with their password using UIAA
let _ = services
.uiaa
.authenticate_password(&body.auth, Some(Identity::from_user_id(sender_user)))
.authenticate_password(&body.auth, sender_user, body.sender_device(), None)
.await?;
let login_token = utils::random_string(TOKEN_LENGTH);
+7
View File
@@ -48,6 +48,13 @@ async fn load_timeline(
ending_count: Option<PduCount>,
limit: usize,
) -> Result<TimelinePdus> {
if let (Some(starting_count), Some(ending_count)) = (starting_count, ending_count) {
debug_assert!(
starting_count <= ending_count,
"starting count {starting_count} > ending count {ending_count}"
);
}
let mut pdu_stream = match starting_count {
| Some(starting_count) => {
let last_timeline_count = services
+76 -73
View File
@@ -38,6 +38,7 @@ use ruma::{
uint,
};
use service::{account_data::AnyRawAccountDataEvent, rooms::short::ShortStateHash};
use tokio::pin;
use super::{load_timeline, share_encrypted_room};
use crate::client::{
@@ -96,12 +97,19 @@ pub(super) async fn load_joined_room(
);
}
let state_events =
StateEvents::with_events(state_events.into_iter().map(Event::into_format).collect());
let joined_room = assign!(JoinedRoom::new(), {
account_data,
summary: summary.unwrap_or_default(),
unread_notifications: notification_counts.unwrap_or_default(),
timeline,
state: RoomState::Before(StateEvents::with_events(state_events.into_iter().map(Event::into_format).collect())),
state: if sync_context.use_state_after {
RoomState::After(state_events)
} else {
RoomState::Before(state_events)
},
ephemeral,
unread_thread_notifications: BTreeMap::new(),
});
@@ -344,7 +352,7 @@ struct ShortStateHashes {
#[tracing::instrument(level = "debug", skip_all)]
async fn fetch_shortstatehashes(
services: &Services,
SyncContext { last_sync_end_count, current_count, .. }: SyncContext<'_>,
SyncContext { last_sync_end_count, .. }: SyncContext<'_>,
room_id: &RoomId,
) -> Result<ShortStateHashes> {
// the room state currently.
@@ -354,46 +362,41 @@ async fn fetch_shortstatehashes(
.rooms
.state
.get_room_shortstatehash(room_id)
.map_err(|_| err!(Database(error!("Room {room_id} has no state"))));
.map_err(|_| err!(Database(error!("Room {room_id} has no state"))))
.await?;
// the room state as of the end of the last sync.
// this will be None if we are doing an initial sync or if we just joined this
// room.
// The room state as of the end of the last sync.
// This will be None if we are doing an initial sync.
let last_sync_end_shortstatehash =
OptionFuture::from(last_sync_end_count.map(|last_sync_end_count| {
// look up the shortstatehash saved by the last sync's call to
// `associate_token_shortstatehash`
services
.rooms
.user
.get_token_shortstatehash(room_id, last_sync_end_count)
.inspect_err(move |_| {
debug_warn!(
token = last_sync_end_count,
"Room has no shortstatehash for this token"
);
})
.ok()
OptionFuture::from(last_sync_end_count.map(async |last_sync_end_count| {
pin! {
let pdus = services
.rooms
.timeline
.pdus(room_id, Some(PduCount::Normal(last_sync_end_count)))
.ignore_err();
}
match pdus.next().await {
| Some((_, pdu_after_last_sync_end)) => {
trace!(?pdu_after_last_sync_end.event_id, "pdu at last sync end");
services
.rooms
.state_accessor
.pdu_shortstatehash(&pdu_after_last_sync_end.event_id)
.await
.map_err(|err| err!("Last sync end PDU has no shortstatehash: {err}"))
},
| None => {
// No events have been sent since the last sync, or we just joined this room,
// so the state then is the same as the state now
Ok(current_shortstatehash)
},
}
}))
.map(Option::flatten)
.map(Ok);
let (current_shortstatehash, last_sync_end_shortstatehash) =
try_join(current_shortstatehash, last_sync_end_shortstatehash).await?;
/*
associate the `current_count` with the `current_shortstatehash`, so we can
use it on the next sync as the `last_sync_end_shortstatehash`.
TODO: the table written to by this call grows extremely fast, gaining one new entry for each
joined room on _every single sync request_. we need to find a better way to remember the shortstatehash
between syncs.
*/
services
.rooms
.user
.associate_token_shortstatehash(room_id, current_count, current_shortstatehash)
.await;
.await
.transpose()?;
Ok(ShortStateHashes {
current_shortstatehash,
@@ -452,6 +455,7 @@ async fn build_state_events(
syncing_user,
last_sync_end_count,
full_state,
use_state_after,
..
} = sync_context;
@@ -460,32 +464,28 @@ async fn build_state_events(
last_sync_end_shortstatehash,
} = shortstatehashes;
// the spec states that the `state` property only includes state events up to
// the beginning of the timeline, so we determine the state of the syncing room
// as of the first timeline event. NOTE: this explanation is not entirely
// accurate; see the implementation of `build_state_incremental`.
let timeline_start_shortstatehash = async {
if let Some((_, pdu)) = timeline.pdus.front() {
if let Ok(shortstatehash) = services
let timeline_start_shortstatehash = if let Some((count, pdu)) = timeline.pdus.front() {
if matches!(count, PduCount::Backfilled(_)) {
// We don't have shortstatehashes for backfilled PDUs, the best we can
// do is to use the current state
current_shortstatehash
} else {
services
.rooms
.state_accessor
.pdu_shortstatehash(&pdu.event_id)
.await
{
return shortstatehash;
}
.map_err(|err| err!("Timeline start has no shortstatehash: {err}"))?
}
current_shortstatehash
} else {
// if the timeline is empty there can't possibly be any changes to the state
return Ok(vec![]);
};
// the user IDs of members whose membership needs to be sent to the client, if
// lazy-loading is enabled.
let lazily_loaded_members =
prepare_lazily_loaded_members(services, sync_context, room_id, timeline.senders());
let (timeline_start_shortstatehash, lazily_loaded_members) =
join(timeline_start_shortstatehash, lazily_loaded_members).await;
prepare_lazily_loaded_members(services, sync_context, room_id, timeline.senders()).await;
// compute the state delta between the previous sync and this sync.
match (last_sync_end_count, last_sync_end_shortstatehash) {
@@ -494,16 +494,15 @@ async fn build_state_events(
is Some (meaning the syncing user didn't just join this room for the first time ever), and `full_state` is false,
then use `build_state_incremental`.
*/
| (Some(last_sync_end_count), Some(last_sync_end_shortstatehash)) if !full_state =>
| (Some(_), Some(last_sync_end_shortstatehash)) if !full_state =>
build_state_incremental(
services,
syncing_user,
room_id,
PduCount::Normal(last_sync_end_count),
last_sync_end_shortstatehash,
timeline_start_shortstatehash,
current_shortstatehash,
timeline,
use_state_after,
lazily_loaded_members.as_ref(),
)
.boxed()
@@ -518,6 +517,8 @@ async fn build_state_events(
services,
syncing_user,
timeline_start_shortstatehash,
current_shortstatehash,
use_state_after,
lazily_loaded_members.as_ref(),
)
.boxed()
@@ -598,23 +599,25 @@ async fn check_joined_since_last_sync(
ShortStateHashes { last_sync_end_shortstatehash, .. }: ShortStateHashes,
SyncContext { syncing_user, .. }: SyncContext<'_>,
) -> Result<bool> {
// fetch the syncing user's membership event during the last sync.
// this will be None if `previous_sync_end_shortstatehash` is None.
let membership_during_previous_sync = match last_sync_end_shortstatehash {
| Some(last_sync_end_shortstatehash) => services
.rooms
.state_accessor
.state_get_content(
last_sync_end_shortstatehash,
&StateEventType::RoomMember,
syncing_user.as_str(),
)
.await
.inspect_err(|_| debug_warn!("User has no previous membership"))
.ok(),
| None => None,
let Some(last_sync_end_shortstatehash) = last_sync_end_shortstatehash else {
// For initial syncs always return false, since there's no "last sync" for the
// user to have joined since.
return Ok(false);
};
// Fetch the syncing user's membership event during the last sync.
let membership_during_previous_sync = services
.rooms
.state_accessor
.state_get_content(
last_sync_end_shortstatehash,
&StateEventType::RoomMember,
syncing_user.as_str(),
)
.await
.inspect_err(|_| debug_warn!("User has no previous membership"))
.ok();
// TODO: If the requesting user got state-reset out of the room, this
// will be `true` when it shouldn't be. this function should never be called
// in that situation, but it may be if the membership cache didn't get updated.
+10 -1
View File
@@ -181,6 +181,9 @@ pub(super) async fn load_left_room(
.collect::<Vec<_>>()
.await;
let state_events =
StateEvents::with_events(state_events.into_iter().map(Event::into_format).collect());
Ok(Some(assign!(LeftRoom::new(), {
account_data: RoomAccountData::new(),
timeline: assign!(Timeline::new(), {
@@ -188,7 +191,11 @@ pub(super) async fn load_left_room(
prev_batch: Some(current_count.to_string()),
events: raw_timeline_pdus,
}),
state: State::Before(StateEvents::with_events(state_events.into_iter().map(Event::into_format).collect())),
state: if sync_context.use_state_after {
State::After(state_events)
} else {
State::Before(state_events)
},
})))
}
@@ -264,6 +271,8 @@ async fn build_left_state_and_timeline(
services,
syncing_user,
timeline_start_shortstatehash,
leave_shortstatehash,
sync_context.use_state_after,
lazily_loaded_members.as_ref(),
)
.await?;
+7 -4
View File
@@ -11,12 +11,11 @@ use std::{
use axum::extract::State;
use axum_client_ip::ClientIp;
use conduwuit::{
Err, Result, at, extract_variant,
Err, Result, at, error, extract_variant,
utils::{
ReadyExt, TryFutureExtExt,
stream::{BroadbandExt, Tools, WidebandExt},
},
warn,
};
use conduwuit_service::Services;
use futures::{FutureExt, StreamExt, TryFutureExt, future::OptionFuture};
@@ -110,6 +109,9 @@ struct SyncContext<'a> {
/// The sync filter, which the client uses to specify what data should be
/// included in the sync response.
filter: &'a FilterDefinition,
/// Whether the state at the end of the timeline should be used when
/// calculating state diffs for sync.
use_state_after: bool,
}
impl<'a> SyncContext<'a> {
@@ -261,6 +263,7 @@ pub(crate) async fn build_sync_events(
current_count,
full_state,
filter: &filter,
use_state_after: body.use_state_after,
};
let joined_rooms = services
@@ -273,7 +276,7 @@ pub(crate) async fn build_sync_events(
match joined_room {
| Ok((room, updates)) => Some((room_id, room, updates)),
| Err(err) => {
warn!(?err, %room_id, "error loading joined room");
error!(?err, %room_id, "error loading joined room");
None
},
}
@@ -302,7 +305,7 @@ pub(crate) async fn build_sync_events(
| Ok(Some(left_room)) => Some((room_id, left_room)),
| Ok(None) => None,
| Err(err) => {
warn!(?err, %room_id, "error loading joined room");
error!(?err, %room_id, "error loading joined room");
None
},
}
+60 -143
View File
@@ -1,11 +1,8 @@
use std::{collections::BTreeSet, ops::ControlFlow};
use std::collections::HashSet;
use conduwuit::{
Result, at, is_equal_to,
matrix::{
Event,
pdu::{PduCount, PduEvent},
},
Result, at,
matrix::{Event, pdu::PduEvent},
utils::{
BoolExt, IterStream, ReadyExt, TryFutureExtExt,
stream::{BroadbandExt, TryIgnore},
@@ -16,9 +13,7 @@ use conduwuit_service::{
rooms::{lazy_loading::MemberSet, short::ShortStateHash},
};
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use ruma::{OwnedEventId, RoomId, UserId, events::StateEventType};
use service::rooms::short::ShortEventId;
use ruma::{OwnedEventId, UserId, events::StateEventType};
use tracing::trace;
use crate::client::TimelinePdus;
@@ -39,13 +34,19 @@ pub(super) async fn build_state_initial(
services: &Services,
sender_user: &UserId,
timeline_start_shortstatehash: ShortStateHash,
timeline_end_shortstatehash: ShortStateHash,
use_state_after: bool,
lazily_loaded_members: Option<&MemberSet>,
) -> Result<Vec<PduEvent>> {
// load the keys and event IDs of the state events at the start of the timeline
let (shortstatekeys, event_ids): (Vec<_>, Vec<_>) = services
.rooms
.state_accessor
.state_full_ids(timeline_start_shortstatehash)
.state_full_ids(if use_state_after {
timeline_end_shortstatehash
} else {
timeline_start_shortstatehash
})
.unzip()
.await;
@@ -92,82 +93,34 @@ pub(super) async fn build_state_initial(
pub(super) async fn build_state_incremental<'a>(
services: &Services,
sender_user: &'a UserId,
room_id: &RoomId,
last_sync_end_count: PduCount,
last_sync_end_shortstatehash: ShortStateHash,
timeline_start_shortstatehash: ShortStateHash,
timeline_end_shortstatehash: ShortStateHash,
timeline: &TimelinePdus,
use_state_after: bool,
lazily_loaded_members: Option<&'a MemberSet>,
) -> Result<Vec<PduEvent>> {
/*
NB: a limited sync is one where `timeline.limited == true`. Synapse calls this a "gappy" sync internally.
let mut state_event_ids: HashSet<OwnedEventId> = HashSet::new();
The algorithm implemented in this function is, currently, quite different from the algorithm vaguely described
by the Matrix specification. This is because the specification's description of the `state` property does not accurately
reflect how Synapse behaves, and therefore how client SDKs behave. Notable differences include:
1. We do not compute the delta using the naive approach of "every state event from the end of the last sync
up to the start of this sync's timeline". see below for details.
2. If lazy-loading is enabled, we include lazily-loaded membership events. The specific users to include are determined
elsewhere and supplied to this function in the `lazily_loaded_members` parameter.
*/
trace!(
%use_state_after,
%last_sync_end_shortstatehash,
%timeline_start_shortstatehash,
%timeline_end_shortstatehash,
"computing state for incremental sync"
);
/*
the `state` property of an incremental sync which isn't limited are _usually_ empty.
(note: the specification says that the `state` property is _always_ empty for limited syncs, which is incorrect.)
however, if an event in the timeline (`timeline.pdus`) merges a split in the room's DAG (i.e. has multiple `prev_events`),
the state at the _end_ of the timeline may include state events which were merged in and don't exist in the state
at the _start_ of the timeline. because this is uncommon, we check here to see if any events in the timeline
merged a split in the DAG.
// Fetch lazy-loaded membership events if lazy-loading is enabled
if let Some(lazily_loaded_members) = lazily_loaded_members
&& !lazily_loaded_members.is_empty()
{
trace!("including lazy membership events for members: {:?}", lazily_loaded_members);
see: https://github.com/element-hq/synapse/issues/16941
*/
let timeline_is_linear = timeline.pdus.is_empty() || {
let last_pdu_of_last_sync = services
services
.rooms
.timeline
.pdus_rev(room_id, Some(last_sync_end_count.saturating_add(1)))
.boxed()
.next()
.await
.transpose()
.expect("last sync should have had some PDUs")
.map(at!(1));
// make sure the prev_events of each pdu in the timeline refer only to the
// previous pdu
timeline
.pdus
.iter()
.try_fold(last_pdu_of_last_sync.map(|pdu| pdu.event_id), |prev_event_id, (_, pdu)| {
if let Ok(pdu_prev_event_id) = pdu.prev_events.iter().exactly_one() {
if prev_event_id
.as_ref()
.is_none_or(is_equal_to!(pdu_prev_event_id))
{
return ControlFlow::Continue(Some(pdu_prev_event_id.to_owned()));
}
}
trace!(
"pdu {:?} has split prev_events (expected {:?}): {:?}",
pdu.event_id, prev_event_id, pdu.prev_events
);
ControlFlow::Break(())
})
.is_continue()
};
if timeline_is_linear && !timeline.limited {
// if there are no splits in the DAG and the timeline isn't limited, then
// `state` will always be empty unless lazy loading is enabled.
if let Some(lazily_loaded_members) = lazily_loaded_members {
if !timeline.pdus.is_empty() {
// lazy loading is enabled, so we return the membership events which were
// requested by the caller.
let lazy_membership_events: Vec<_> = lazily_loaded_members
.short
.multi_get_eventid_from_short::<'_, OwnedEventId, _>(
lazily_loaded_members
.iter()
.stream()
.broad_filter_map(|user_id| async move {
@@ -178,71 +131,24 @@ pub(super) async fn build_state_incremental<'a>(
services
.rooms
.state_accessor
.state_get(
.state_get_shortid(
timeline_start_shortstatehash,
&StateEventType::RoomMember,
user_id.as_str(),
)
.ok()
.await
})
.collect()
.await;
if !lazy_membership_events.is_empty() {
trace!(
"syncing lazy membership events for members: {:?}",
lazy_membership_events
.iter()
.map(|pdu| pdu.state_key().unwrap())
.collect::<Vec<_>>()
);
}
return Ok(lazy_membership_events);
}
}
// lazy loading is disabled, `state` is empty.
return Ok(vec![]);
}),
)
.ignore_err()
.ready_for_each(|event_id| {
state_event_ids.insert(event_id);
})
.await;
}
/*
at this point, either the timeline is `limited` or the DAG has a split in it. this necessitates
computing the incremental state (which may be empty).
NOTE: this code path does not use the `lazy_membership_events` parameter. any changes to membership will be included
in the incremental state. therefore, the incremental state may include "redundant" membership events,
which we do not filter out because A. the spec forbids lazy-load filtering if the timeline is `limited`,
and B. DAG splits which require sending extra membership state events are (probably) uncommon enough that
the performance penalty is acceptable.
*/
trace!(%timeline_is_linear, %timeline.limited, "computing state for incremental sync");
// fetch the shorteventids of state events in the timeline
let state_events_in_timeline: BTreeSet<ShortEventId> = services
.rooms
.short
.multi_get_or_create_shorteventid(timeline.pdus.iter().filter_map(|(_, pdu)| {
if pdu.state_key().is_some() {
Some(pdu.event_id.as_ref())
} else {
None
}
}))
.collect()
.await;
trace!("{} state events in timeline", state_events_in_timeline.len());
/*
fetch the state events which were added since the last sync.
specifically we fetch the difference between the state at the last sync and the state at the _end_
of the timeline, and then we filter out state events in the timeline itself using the shorteventids we fetched.
this is necessary to account for splits in the DAG, as explained above.
*/
let state_diff = services
// Fetch the state events added since the last sync.
services
.rooms
.short
.multi_get_eventid_from_short::<'_, OwnedEventId, _>(
@@ -252,18 +158,29 @@ pub(super) async fn build_state_incremental<'a>(
.state_added((last_sync_end_shortstatehash, timeline_end_shortstatehash))
.await?
.stream()
.ready_filter_map(|(_, shorteventid)| {
if state_events_in_timeline.contains(&shorteventid) {
None
} else {
Some(shorteventid)
}
}),
.map(at!(1)),
)
.ignore_err();
.ignore_err()
.ready_for_each(|event_id| {
state_event_ids.insert(event_id);
})
.await;
// finally, fetch the PDU contents and collect them into a vec
let state_diff_pdus = state_diff
if !use_state_after {
// If state_after isn't enabled, filter out state events which also exist
// in the timeline. If splits exist in the DAG, this may not be exactly the same
// thing as the state diff ending at the start of the timeline, but Synapse
// also does this and it's technically more useful behavior anyway.
// See: https://github.com/element-hq/synapse/issues/16941
for (_, pdu) in &timeline.pdus {
state_event_ids.remove(pdu.event_id());
}
}
// Finally, fetch the PDU contents and collect them into a vec
let state_diff_pdus = state_event_ids
.stream()
.broad_filter_map(|event_id| async move {
services
.rooms
+23 -8
View File
@@ -15,7 +15,7 @@ use conduwuit::{
BoolExt, FutureBoolExt, IterStream, ReadyExt, TryFutureExtExt,
future::ReadyEqExt,
math::{ruma_from_usize, usize_from_ruma},
stream::WidebandExt,
stream::{TryIgnore, WidebandExt},
},
warn,
};
@@ -41,6 +41,7 @@ use ruma::{
uint,
};
use service::account_data::AnyRawAccountDataEvent;
use tokio::pin;
use super::share_encrypted_room;
use crate::{
@@ -69,7 +70,6 @@ pub(crate) async fn sync_events_v5_route(
ClientIp(client_ip): ClientIp,
body: Ruma<sync_events::v5::Request>,
) -> Result<sync_events::v5::Response> {
debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted");
let ref sender_user = body.sender_user().to_owned();
let ref sender_device = body.sender_device().to_owned();
@@ -858,12 +858,27 @@ where
continue;
};
let since_shortstatehash = services
.rooms
.user
.get_token_shortstatehash(room_id, globalsince)
.await
.ok();
let since_shortstatehash = async {
pin! {
let pdus_rev = services
.rooms
.timeline
.pdus_rev(room_id, Some(PduCount::Normal(globalsince.saturating_sub(1))))
.ignore_err();
}
let (_, pdu_at_last_sync_end) = pdus_rev.next().await?;
Some(
services
.rooms
.state_accessor
.pdu_shortstatehash(&pdu_at_last_sync_end.event_id)
.await
.expect("pdu should have a shortstatehash"),
)
}
.await;
let encrypted_room = services
.rooms
+2 -2
View File
@@ -35,8 +35,8 @@ pub(crate) async fn get_supported_versions_route(
/// `/_matrix/federation/v1/version`
pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> {
Ok(Json(serde_json::json!({
"name": conduwuit::version::name(),
"version": conduwuit::version::version(),
"name": conduwuit::BRANDING,
"version": conduwuit::version(),
})))
}
+2 -42
View File
@@ -3,8 +3,7 @@ use conduwuit::{Err, Result};
use ruma::{
api::client::discovery::{
discover_homeserver::{self, HomeserverInfo},
discover_policy_server,
discover_support::{self, Contact, ContactRole},
discover_policy_server, discover_support,
},
assign,
};
@@ -67,46 +66,7 @@ pub(crate) async fn well_known_support(
.as_ref()
.map(ToString::to_string);
let email_address = services.config.well_known.support_email.clone();
let matrix_id = services.config.well_known.support_mxid.clone();
let pgp_key = services.config.well_known.support_pgp_key.clone();
// TODO: support defining multiple contacts in the config
let mut contacts: Vec<Contact> = vec![];
let role = services
.config
.well_known
.support_role
.clone()
.unwrap_or(ContactRole::Admin);
// Add configured contact if at least one contact method is specified
let configured_contact = match (matrix_id, email_address) {
| (Some(matrix_id), email_address) =>
Some(assign!(Contact::with_matrix_id(role, matrix_id), { email_address })),
| (None, Some(email_address)) => Some(Contact::with_email_address(role, email_address)),
| (None, None) => None,
};
if let Some(mut configured_contact) = configured_contact {
configured_contact.pgp_key = pgp_key;
contacts.push(configured_contact);
}
// Try to add admin users as contacts if no contacts are configured
if contacts.is_empty() {
let admin_users = services.admin.get_admins().await;
for user_id in &admin_users {
if *user_id == services.globals.server_user {
continue;
}
contacts.push(Contact::with_matrix_id(ContactRole::Admin, user_id.to_owned()));
}
}
let contacts = services.admin.get_support_contacts().await;
if contacts.is_empty() && support_page.is_none() {
// No admin room, no configured contacts, and no support page
+1
View File
@@ -1,4 +1,5 @@
#![type_length_limit = "16384"] //TODO: reduce me
#![recursion_limit = "256"] // My Giant Async Function
#![allow(clippy::toplevel_ref_arg)]
extern crate conduwuit_core as conduwuit;
+5 -3
View File
@@ -10,7 +10,7 @@ use axum::{
response::{IntoResponse, Redirect},
routing::{any, get, post},
};
use conduwuit::{Server, err};
use conduwuit::err;
pub(super) use conduwuit_service::state::State;
use http::{Uri, uri};
@@ -18,8 +18,8 @@ use self::handler::RouterExt;
pub(super) use self::{args::Args as Ruma, response::RumaResponse};
use crate::{admin, client, server};
pub fn build(router: Router<State>, server: &Server) -> Router<State> {
let config = &server.config;
pub fn build(router: Router<State>, state: State) -> Router<State> {
let config = &state.server.config;
let mut router = router
.ruma_route(&client::appservice_ping)
.ruma_route(&client::get_supported_versions_route)
@@ -186,6 +186,8 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::well_known_policy_server)
.ruma_route(&client::get_rtc_transports)
.ruma_route(&client::room_initial_sync_route)
.ruma_route(&client::get_authorization_server_metadata_route)
.merge(client::oauth::router(state))
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
.ruma_route(&admin::rooms::ban::ban_room)
+23 -6
View File
@@ -1,6 +1,7 @@
use std::any::{Any, TypeId};
use conduwuit::{Err, Result, err};
use conduwuit::{Err, Error, Result, err};
use http::StatusCode;
use ruma::{
OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
api::{
@@ -10,12 +11,15 @@ use ruma::{
AuthScheme, NoAccessToken, NoAuthentication,
},
client,
error::{ErrorKind, UnknownTokenErrorData},
federation::authentication::ServerSignatures,
},
assign,
};
use service::{
Services,
server_keys::{PubKeyMap, PubKeys},
users::AccessTokenStatus,
};
use crate::{router::args::AuthQueryParams, service::appservice::RegistrationInfo};
@@ -115,12 +119,21 @@ impl CheckAuth for AccessToken {
query: AuthQueryParams,
route: TypeId,
) -> Result<Auth> {
// Check for appservice tokens first
let (sender_user, sender_device, appservice_info) = {
if let Ok((sender_user, sender_device)) =
if let Some((sender_user, sender_device, status)) =
services.users.find_from_token(&output).await
{
// If the token is expired we return a soft logout
if matches!(status, AccessTokenStatus::Expired) {
return Err(Error::Request(
ErrorKind::UnknownToken(
assign!(UnknownTokenErrorData::new(), { soft_logout: true }),
),
"This token has expired".into(),
StatusCode::UNAUTHORIZED,
));
}
// Locked users can only use /logout and /logout/all
if services
.users
@@ -131,7 +144,7 @@ impl CheckAuth for AccessToken {
if !(route == TypeId::of::<client::session::logout::v3::Request>()
|| route == TypeId::of::<client::session::logout_all::v3::Request>())
{
return Err!(Request(Unauthorized("Your account is locked.")));
return Err!(Request(UserLocked("Your account is locked.")));
}
}
@@ -179,7 +192,11 @@ impl CheckAuth for AccessToken {
(Some(sender_user), sender_device, Some(appservice_info))
} else {
return Err!(Request(Unauthorized("Invalid access token.")));
return Err(Error::Request(
ErrorKind::UnknownToken(UnknownTokenErrorData::new()),
"Invalid token".into(),
StatusCode::UNAUTHORIZED,
));
}
};
+2 -2
View File
@@ -11,8 +11,8 @@ pub(crate) async fn get_server_version_route(
) -> Result<get_server_version::v1::Response> {
Ok(assign!(get_server_version::v1::Response::new(), {
server: Some(assign!(get_server_version::v1::Server::new(), {
name: Some(conduwuit::version::name().into()),
version: Some(conduwuit::version::version().into()),
name: Some(conduwuit::BRANDING.into()),
version: Some(conduwuit::version().into()),
})),
}))
}
+117 -37
View File
@@ -4,7 +4,7 @@ pub mod manager;
pub mod proxy;
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::PathBuf,
};
@@ -655,19 +655,25 @@ pub struct Config {
/// even if `recaptcha_site_key` is set.
pub recaptcha_private_site_key: Option<String>,
/// Policy documents, such as terms and conditions or a privacy policy,
/// which users must agree to when registering an account.
///
/// Example:
/// ```ignore
/// [global.registration_terms.privacy_policy]
/// en = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
/// es = { name = "Política de Privacidad", url = "https://homeserver.example/es/privacy_policy.html" }
/// ```
///
/// default: {}
/// display: nested
#[serde(default)]
pub registration_terms: HashMap<String, HashMap<String, TermsDocument>>,
pub registration_terms: RegistrationTerms,
/// display: nested
#[serde(default)]
pub oauth: OauthConfig,
/// Controls whether users are allowed to deactivate their own accounts
/// through the account management panel or their Matrix clients. Server
/// admins can always deactivate users using the relevant admin commands.
///
/// Note that, in some jurisdictions, you may be legally required to honor
/// users who request to deactivate their accounts if you set this option
/// to `false`.
///
/// default: true
#[serde(default = "true_fn")]
pub allow_deactivation: bool,
/// Controls whether encrypted rooms and events are allowed.
#[serde(default = "true_fn")]
@@ -2071,12 +2077,10 @@ pub struct Config {
pub stream_amplification: usize,
/// Number of sender task workers; determines sender parallelism. Default is
/// '0' which means the value is determined internally, likely matching the
/// number of tokio worker-threads or number of cores, etc. Override by
/// setting a non-zero value.
/// core count. Override by setting a different value.
///
/// default: 0
#[serde(default)]
/// default: core count
#[serde(default = "default_sender_workers")]
pub sender_workers: usize,
/// Enables listener sockets; can be set to false to disable listening. This
@@ -2351,6 +2355,30 @@ pub struct SmtpConfig {
pub require_email_for_token_registration: bool,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[config_example_generator(
filename = "conduwuit-example.toml",
section = "global.registration_terms",
optional = "true"
)]
pub struct RegistrationTerms {
/// The language code to provide to clients along with the policy documents.
///
/// default: "en"
pub language: String,
/// Policy documents, such as terms and conditions or a privacy policy,
/// which users must agree to when registering an account.
///
/// Example:
/// ```ignore
/// [global.registration_terms.documents]
/// privacy_policy = { name = "Privacy Policy", url = "https://homeserver.example/en/privacy_policy.html" }
/// ```
///
/// default: {}
pub documents: BTreeMap<String, TermsDocument>,
}
/// A policy document for use with a m.login.terms stage.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TermsDocument {
@@ -2358,6 +2386,43 @@ pub struct TermsDocument {
pub url: String,
}
#[derive(Clone, Debug, Default, Deserialize)]
#[config_example_generator(
filename = "conduwuit-example.toml",
section = "global.oauth",
optional = "true"
)]
pub struct OauthConfig {
/// The compatibility mode to use for OAuth.
///
/// - "disabled": OAuth will be unavailable. Users will only be able to log
/// in using legacy authentication.
/// - "hybrid": OAuth and legacy authentication will both be available. Some
/// clients may only use one or the other.
/// - "exclusive": Only OAuth will be available. Clients which require
/// legacy authentication will be unable to log in.
///
/// default: "hybrid"
pub compatibility_mode: OAuthMode,
}
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OAuthMode {
Disabled,
#[default]
Hybrid,
Exclusive,
}
impl OAuthMode {
#[must_use]
pub fn uiaa_available(&self) -> bool { matches!(self, Self::Disabled | Self::Hybrid) }
#[must_use]
pub fn oauth_available(&self) -> bool { matches!(self, Self::Hybrid | Self::Exclusive) }
}
const DEPRECATED_KEYS: &[&str] = &[
"cache_capacity",
"conduit_cache_capacity_modifier",
@@ -2455,45 +2520,47 @@ fn default_database_backups_to_keep() -> i16 { 1 }
fn default_db_write_buffer_capacity_mb() -> f64 { 48.0 + parallelism_scaled_f64(4.0) }
fn default_db_cache_capacity_mb() -> f64 { 128.0 + parallelism_scaled_f64(64.0) }
fn default_db_cache_capacity_mb() -> f64 { 512.0 + parallelism_scaled_f64(512.0) }
fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(10_000).saturating_add(100_000) }
fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(50_000).saturating_add(100_000) }
fn default_cache_capacity_modifier() -> f64 { 1.0 }
fn default_auth_chain_cache_capacity() -> u32 {
parallelism_scaled_u32(10_000).saturating_add(100_000)
}
fn default_shorteventid_cache_capacity() -> u32 {
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_shorteventid_cache_capacity() -> u32 {
parallelism_scaled_u32(100_000).saturating_add(100_000)
}
fn default_eventidshort_cache_capacity() -> u32 {
parallelism_scaled_u32(25_000).saturating_add(100_000)
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_eventid_pdu_cache_capacity() -> u32 {
parallelism_scaled_u32(25_000).saturating_add(100_000)
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_shortstatekey_cache_capacity() -> u32 {
parallelism_scaled_u32(10_000).saturating_add(100_000)
parallelism_scaled_u32(100_000).saturating_add(100_000)
}
fn default_statekeyshort_cache_capacity() -> u32 {
parallelism_scaled_u32(10_000).saturating_add(100_000)
parallelism_scaled_u32(50_000).saturating_add(100_000)
}
fn default_servernameevent_data_cache_capacity() -> u32 {
parallelism_scaled_u32(100_000).saturating_add(500_000)
parallelism_scaled_u32(100_000).saturating_add(100_000)
}
fn default_stateinfo_cache_capacity() -> u32 { parallelism_scaled_u32(100) }
fn default_stateinfo_cache_capacity() -> u32 { parallelism_scaled_u32(500).clamp(100, 12000) }
fn default_roomid_spacehierarchy_cache_capacity() -> u32 { parallelism_scaled_u32(1000) }
fn default_roomid_spacehierarchy_cache_capacity() -> u32 {
parallelism_scaled_u32(500).clamp(100, 12000)
}
fn default_dns_cache_entries() -> u32 { 32768 }
fn default_dns_cache_entries() -> u32 { 327_680 }
fn default_dns_min_ttl() -> u64 { 60 * 180 }
@@ -2701,15 +2768,26 @@ fn default_admin_log_capture() -> String {
fn default_admin_room_tag() -> String { "m.server_notice".to_owned() }
#[must_use]
#[allow(clippy::as_conversions, clippy::cast_precision_loss)]
fn parallelism_scaled_f64(val: f64) -> f64 { val * (sys::available_parallelism() as f64) }
pub fn parallelism_scaled_f64(val: f64) -> f64 { val * (sys::available_parallelism() as f64) }
fn parallelism_scaled_u32(val: u32) -> u32 {
let val = val.try_into().expect("failed to cast u32 to usize");
parallelism_scaled(val).try_into().unwrap_or(u32::MAX)
#[must_use]
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
pub fn parallelism_scaled_u32(val: u32) -> u32 {
val.saturating_mul(sys::available_parallelism() as u32)
}
fn parallelism_scaled(val: usize) -> usize { val.saturating_mul(sys::available_parallelism()) }
#[must_use]
#[allow(clippy::as_conversions, clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn parallelism_scaled_i32(val: i32) -> i32 {
val.saturating_mul(sys::available_parallelism() as i32)
}
#[must_use]
pub fn parallelism_scaled(val: usize) -> usize {
val.saturating_mul(sys::available_parallelism())
}
fn default_trusted_server_batch_size() -> usize { 256 }
@@ -2729,6 +2807,8 @@ fn default_stream_width_scale() -> f32 { 1.0 }
fn default_stream_amplification() -> usize { 1024 }
fn default_sender_workers() -> usize { parallelism_scaled(1) }
fn default_client_receive_timeout() -> u64 { 75 }
fn default_client_request_timeout() -> u64 { 180 }
+1
View File
@@ -161,6 +161,7 @@ impl Error {
match self {
| Self::Federation(origin, error) => format!("Answer from {origin}: {error}"),
| Self::Ruma(error) => response::ruma_error_message(error),
| Self::Request(_, message, _) => message.clone().into_owned(),
| _ => format!("{self}"),
}
}
+1 -4
View File
@@ -73,11 +73,8 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode {
// 413
| TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
// 405
| Unrecognized => StatusCode::METHOD_NOT_ALLOWED,
// 404
| NotFound => StatusCode::NOT_FOUND,
| Unrecognized | NotFound => StatusCode::NOT_FOUND,
// 403
| GuestAccessForbidden
+6 -9
View File
@@ -7,19 +7,16 @@
use std::sync::OnceLock;
static BRANDING: &str = "continuwuity";
static WEBSITE: &str = "https://continuwuity.org";
static SEMANTIC: &str = env!("CARGO_PKG_VERSION");
pub const BRANDING: &str = "continuwuity";
pub const ROUTE_PREFIX: &str = "/_continuwuity";
pub const WEBSITE: &str = "https://continuwuity.org";
pub const SEMANTIC: &str = env!("CARGO_PKG_VERSION");
static VERSION: OnceLock<String> = OnceLock::new();
static VERSION_UA: OnceLock<String> = OnceLock::new();
static USER_AGENT: OnceLock<String> = OnceLock::new();
static USER_AGENT_MEDIA: OnceLock<String> = OnceLock::new();
#[inline]
#[must_use]
pub fn name() -> &'static str { BRANDING }
#[inline]
pub fn version() -> &'static str { VERSION.get_or_init(init_version) }
@@ -32,10 +29,10 @@ pub fn user_agent() -> &'static str { USER_AGENT.get_or_init(init_user_agent) }
#[inline]
pub fn user_agent_media() -> &'static str { USER_AGENT_MEDIA.get_or_init(init_user_agent_media) }
fn init_user_agent() -> String { format!("{}/{} (bot; +{WEBSITE})", name(), version_ua()) }
fn init_user_agent() -> String { format!("{BRANDING}/{} (bot; +{WEBSITE})", version_ua()) }
fn init_user_agent_media() -> String {
format!("{}/{} (embedbot; facebookexternalhit/1.1; +{WEBSITE})", name(), version_ua())
format!("{BRANDING}/{} (embedbot; facebookexternalhit/1.1; +{WEBSITE})", version_ua())
}
fn init_version_ua() -> String {
+1
View File
@@ -21,6 +21,7 @@ pub fn versions() -> Vec<String> {
"v1.12".to_owned(),
"v1.13".to_owned(),
"v1.14".to_owned(),
"v1.15".to_owned(),
]
}
+1 -4
View File
@@ -34,10 +34,7 @@ pub use ::tracing;
pub use conduwuit_build_metadata as build_metadata;
pub use config::Config;
pub use error::Error;
pub use info::{
version,
version::{name, version},
};
pub use info::version::*;
pub use matrix::{Event, EventTypeExt, Pdu, PduCount, PduEvent, PduId, pdu, state_res};
pub use parking_lot::{Mutex as SyncMutex, RwLock as SyncRwLock};
pub use server::Server;
+3
View File
@@ -5,6 +5,7 @@ pub type DigestOut = [u8; 256 / 8];
/// Sha256 hash (input gather joined by 0xFF bytes)
#[must_use]
#[tracing::instrument(skip(inputs), level = "trace")]
#[allow(clippy::unnecessary_fallible_conversions)]
pub fn delimited<'a, T, I>(mut inputs: I) -> DigestOut
where
I: Iterator<Item = T> + 'a,
@@ -25,6 +26,7 @@ where
/// Sha256 hash (input gather)
#[must_use]
#[tracing::instrument(skip(inputs), level = "trace")]
#[allow(clippy::unnecessary_fallible_conversions)]
pub fn concat<'a, T, I>(inputs: I) -> DigestOut
where
I: Iterator<Item = T> + 'a,
@@ -43,6 +45,7 @@ where
#[inline]
#[must_use]
#[tracing::instrument(skip(input), level = "trace")]
#[allow(clippy::unnecessary_fallible_conversions)]
pub fn hash<T>(input: T) -> DigestOut
where
T: AsRef<[u8]>,
+16 -10
View File
@@ -61,17 +61,23 @@ pub fn format(ts: SystemTime, str: &str) -> String {
pub fn pretty(d: Duration) -> String {
use Unit::*;
let fmt = |w, f, u| format!("{w}.{f} {u}");
let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u);
let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u);
let fmt = |w, u| {
if w == 1 {
format!("{w} {u}")
} else {
format!("{w} {u}s")
}
};
let gen64 = |w, u| fmt(w, u);
let gen128 = |w, u| gen64(u64::try_from(w).expect("u128 to u64"), u);
match whole_and_frac(d) {
| (Days(whole), frac) => gen64(whole, frac, "days"),
| (Hours(whole), frac) => gen64(whole, frac, "hours"),
| (Mins(whole), frac) => gen64(whole, frac, "minutes"),
| (Secs(whole), frac) => gen64(whole, frac, "seconds"),
| (Millis(whole), frac) => gen128(whole, frac, "milliseconds"),
| (Micros(whole), frac) => gen128(whole, frac, "microseconds"),
| (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"),
| (Days(whole), _) => gen64(whole, "day"),
| (Hours(whole), _) => gen64(whole, "hour"),
| (Mins(whole), _) => gen64(whole, "minute"),
| (Secs(whole), _) => gen64(whole, "second"),
| (Millis(whole), _) => gen128(whole, "millisecond"),
| (Micros(whole), _) => gen128(whole, "microsecond"),
| (Nanos(whole), _) => gen128(whole, "nanosecond"),
}
}
+1 -1
View File
@@ -29,7 +29,7 @@ fn descriptor_cf_options(
set_table_options(&mut opts, &desc, cache)?;
opts.set_min_write_buffer_number(1);
opts.set_max_write_buffer_number(2);
opts.set_max_write_buffer_number(3);
opts.set_write_buffer_size(desc.write_size);
opts.set_target_file_size_base(desc.file_size);
+21 -6
View File
@@ -49,6 +49,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "bannedroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "clientid_clientmetadata",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "disabledroomids",
..descriptor::RANDOM_SMALL
@@ -157,6 +161,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "referencedevents",
..descriptor::RANDOM
},
Descriptor {
name: "refreshtoken_refreshtokeninfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "registrationtoken_info",
..descriptor::RANDOM_SMALL
@@ -193,12 +201,7 @@ pub(super) static MAPS: &[Descriptor] = &[
},
Descriptor {
name: "roomsynctoken_shortstatehash",
file_shape: 3,
val_size_hint: Some(8),
block_size: 512,
compression_level: 3,
bottommost_level: Some(6),
..descriptor::SEQUENTIAL
..descriptor::DROPPED
},
Descriptor {
name: "roomuserdataid_accountdata",
@@ -371,6 +374,14 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userdevicetxnid_response",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_oauthsessioninfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_tokenexpires",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userfilterid_filter",
..descriptor::RANDOM_SMALL
@@ -475,4 +486,8 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userroomid_invitesender",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "websessionid_session",
..descriptor::RANDOM_SMALL
},
];
+1 -1
View File
@@ -15,7 +15,7 @@ use conduwuit_core::{
#[clap(
about,
long_about = None,
name = conduwuit_core::name(),
name = conduwuit_core::BRANDING,
version = conduwuit_core::version(),
)]
pub struct Args {
+1 -1
View File
@@ -110,7 +110,7 @@ pub(crate) fn init(
.with_batch_exporter(exporter)
.build();
let tracer = provider.tracer(conduwuit_core::name());
let tracer = provider.tracer(conduwuit_core::BRANDING);
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
+1 -1
View File
@@ -47,7 +47,7 @@ fn options(config: &Config) -> ClientOptions {
traces_sample_rate: config.sentry_traces_sample_rate,
debug: cfg!(debug_assertions),
release: release_name(),
user_agent: conduwuit_core::version::user_agent().into(),
user_agent: conduwuit_core::user_agent().into(),
attach_stacktrace: config.sentry_attach_stacktrace,
before_send: Some(Arc::new(before_send)),
before_breadcrumb: Some(Arc::new(before_breadcrumb)),
+7 -5
View File
@@ -8,7 +8,7 @@ use axum::{
extract::State,
response::{IntoResponse, Response},
};
use conduwuit::{Result, debug, debug_error, debug_warn, err, error, trace};
use conduwuit::{Result, debug_warn, err, error, info, trace};
use conduwuit_service::Services;
use futures::FutureExt;
use http::{Method, StatusCode, Uri};
@@ -102,17 +102,19 @@ fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result<Respons
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
if status.is_server_error() {
error!(%method, %uri, "{code} {reason}");
info!(%method, %uri, "{code} {reason}");
} else if status.is_client_error() {
debug_error!(%method, %uri, "{code} {reason}");
info!(%method, %uri, "{code} {reason}");
} else if status.is_redirection() {
debug!(%method, %uri, "{code} {reason}");
trace!(%method, %uri, "{code} {reason}");
} else {
trace!(%method, %uri, "{code} {reason}");
}
if status == StatusCode::METHOD_NOT_ALLOWED {
return Ok(err!(Request(Unrecognized("Method Not Allowed"))).into_response());
return Ok(
err!(Request(Unrecognized("Method not allowed"), METHOD_NOT_ALLOWED)).into_response()
);
}
Ok(result)
+2 -2
View File
@@ -9,8 +9,8 @@ use ruma::api::error::ErrorKind;
pub(crate) fn build(services: &Arc<Services>) -> (Router, Guard) {
let router = Router::<state::State>::new();
let (state, guard) = state::create(services.clone());
let router = conduwuit_api::router::build(router, &services.server)
.merge(conduwuit_web::build())
let router = conduwuit_api::router::build(router, state)
.merge(conduwuit_web::build(services))
.fallback(not_found)
.with_state(state);
+1
View File
@@ -119,6 +119,7 @@ recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
yansi.workspace = true
lettre.workspace = true
serde_urlencoded.workspace = true
[target.'cfg(all(unix, target_os = "linux"))'.dependencies]
sd-notify.workspace = true
+53 -1
View File
@@ -18,6 +18,8 @@ use futures::{Future, FutureExt, StreamExt, TryFutureExt};
use loole::{Receiver, Sender};
use ruma::{
OwnedEventId, OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId,
api::client::discovery::discover_support::{Contact, ContactRole},
assign,
events::{
Mentions,
room::message::{
@@ -28,7 +30,7 @@ use ruma::{
use tokio::sync::RwLock;
use crate::{
Dep, account_data, globals,
Dep, account_data, config, globals,
media::{MXC_LENGTH, mxc::Mxc},
rooms::{self, state::RoomMutexGuard},
};
@@ -44,6 +46,7 @@ pub struct Service {
struct Services {
server: Arc<Server>,
config: Dep<config::Service>,
globals: Dep<globals::Service>,
alias: Dep<rooms::alias::Service>,
timeline: Dep<rooms::timeline::Service>,
@@ -115,6 +118,7 @@ impl crate::Service for Service {
Ok(Arc::new(Self {
services: Services {
server: args.server.clone(),
config: args.depend::<config::Service>("config"),
globals: args.depend::<globals::Service>("globals"),
alias: args.depend::<rooms::alias::Service>("rooms::alias"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
@@ -619,4 +623,52 @@ impl Service {
let weak = services.map(Arc::downgrade);
*receiver = weak;
}
/// Get the server's configured support contacts.
pub async fn get_support_contacts(&self) -> Vec<Contact> {
let email_address = self.services.config.well_known.support_email.clone();
let matrix_id = self.services.config.well_known.support_mxid.clone();
let pgp_key = self.services.config.well_known.support_pgp_key.clone();
// TODO: support defining multiple contacts in the config
let mut contacts: Vec<Contact> = vec![];
let role = self
.services
.config
.well_known
.support_role
.clone()
.unwrap_or(ContactRole::Admin);
// Add configured contact if at least one contact method is specified
let configured_contact = match (matrix_id, email_address) {
| (Some(matrix_id), email_address) =>
Some(assign!(Contact::with_matrix_id(role, matrix_id), { email_address })),
| (None, Some(email_address)) =>
Some(Contact::with_email_address(role, email_address)),
| (None, None) => None,
};
if let Some(mut configured_contact) = configured_contact {
configured_contact.pgp_key = pgp_key;
contacts.push(configured_contact);
}
// Try to add admin users as contacts if no contacts are configured
if contacts.is_empty() {
let admin_users = self.get_admins().await;
for user_id in &admin_users {
if *user_id == self.services.globals.server_user {
continue;
}
contacts.push(Contact::with_matrix_id(ContactRole::Admin, user_id.to_owned()));
}
}
contacts
}
}
+2 -2
View File
@@ -67,7 +67,7 @@ impl crate::Service for Service {
for (id, registration) in appservices {
// During startup, resolve any token collisions in favour of appservices
// by logging out conflicting user devices
if let Ok((user_id, device_id)) = self
if let Some((user_id, device_id, _)) = self
.services
.users
.find_from_token(&registration.as_token)
@@ -158,7 +158,7 @@ impl Service {
.users
.find_from_token(&registration.as_token)
.await
.is_ok()
.is_some()
{
return Err(err!(Request(InvalidParam(
"Cannot register appservice: The provided token is already in use by a user \
+2 -2
View File
@@ -39,7 +39,7 @@ impl crate::Service for Service {
let url_preview_user_agent = config
.url_preview_user_agent
.clone()
.unwrap_or_else(|| conduwuit::version::user_agent_media().to_owned());
.unwrap_or_else(|| conduwuit::user_agent_media().to_owned());
Ok(Arc::new(Self {
default: base(config)?
@@ -149,7 +149,7 @@ fn base(config: &Config) -> Result<reqwest::ClientBuilder> {
.timeout(Duration::from_secs(config.request_total_timeout))
.pool_idle_timeout(Duration::from_secs(config.request_idle_timeout))
.pool_max_idle_per_host(config.request_idle_per_host.into())
.user_agent(conduwuit::version::user_agent())
.user_agent(conduwuit::user_agent())
.redirect(redirect::Policy::limited(6))
.danger_accept_invalid_certs(config.allow_invalid_tls_certificates_yes_i_know_what_the_fuck_i_am_doing_with_this_and_i_know_this_is_insecure)
.connection_verbose(cfg!(debug_assertions));
+12 -7
View File
@@ -6,7 +6,7 @@ use std::{
use askama::Template;
use async_trait::async_trait;
use conduwuit::{Result, info, utils::ReadyExt};
use futures::{FutureExt, StreamExt};
use futures::StreamExt;
use ruma::{UserId, events::room::message::RoomMessageEventContent};
use crate::{
@@ -120,7 +120,7 @@ impl Service {
///
/// Returns Ok(true) if the specified user was the first user, and Ok(false)
/// if they were not.
pub async fn empower_first_user(&self, user: &UserId) -> Result<bool> {
pub async fn empower_first_user(&self, user: &UserId) -> bool {
#[derive(Template)]
#[template(path = "welcome.md")]
struct WelcomeMessage<'a> {
@@ -130,10 +130,14 @@ impl Service {
// If first run mode isn't active, do nothing.
if !self.disable_first_run() {
return Ok(false);
return false;
}
self.services.admin.make_user_admin(user).boxed().await?;
self.services
.admin
.make_user_admin(user)
.await
.expect("should have been able to empower the first user");
// Send the welcome message
let welcome_message = WelcomeMessage {
@@ -146,11 +150,12 @@ impl Service {
self.services
.admin
.send_loud_message(RoomMessageEventContent::text_markdown(welcome_message))
.await?;
.await
.expect("should have been able to send welcome message");
info!("{user} has been invited to the admin room as the first user.");
Ok(true)
true
}
/// Get the single-use registration token which may be used to create the
@@ -181,7 +186,7 @@ impl Service {
eprintln!(
"Welcome to {} {}!",
"Continuwuity".bold().bright_magenta(),
conduwuit::version::version().bold()
conduwuit::version().bold()
);
eprintln!();
eprintln!(
+4 -2
View File
@@ -92,8 +92,8 @@ impl Mailer<'_> {
let message = MessageBuilder::new()
.from(self.sender.clone())
.to(recipient)
.subject(subject)
.to(recipient.clone())
.subject(subject.clone())
.date_now()
.header(ContentType::TEXT_PLAIN)
.body(body)
@@ -104,6 +104,8 @@ impl Mailer<'_> {
.await
.map_err(|err: TransportError| err!("Failed to send message: {err}"))?;
info!(recipient = recipient.to_string(), ?subject, "Email sent");
Ok(())
}
}
+1 -1
View File
@@ -27,7 +27,7 @@ pub mod key_backups;
pub mod mailer;
pub mod media;
pub mod moderation;
pub mod password_reset;
pub mod oauth;
pub mod presence;
pub mod pusher;
pub mod registration_tokens;
+196
View File
@@ -0,0 +1,196 @@
use std::{collections::BTreeSet, hash::Hash};
use itertools::Itertools;
use serde::{Deserialize, Deserializer, Serialize};
use url::Url;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[non_exhaustive]
pub struct ClientMetadata {
#[serde(default)]
pub application_type: ApplicationType,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
pub client_uri: Url,
#[serde(default, deserialize_with = "btreeset_skip_err")]
pub grant_types: BTreeSet<GrantType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub logo_uri: Option<Url>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub policy_uri: Option<Url>,
#[serde(default)]
pub redirect_uris: Vec<Url>,
#[serde(default, deserialize_with = "btreeset_skip_err")]
pub response_types: BTreeSet<ResponseType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_method: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tos_uri: Option<Url>,
}
impl ClientMetadata {
pub(super) const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
pub(super) fn validate(&self) -> Result<(), &'static str> {
let Some(client_domain) = self.client_uri.domain() else {
return Err("Client URI must have a domain.");
};
if self.client_uri.scheme() != "https" {
return Err("Client URI must be HTTPS.");
}
if !self.client_uri.username().is_empty() || self.client_uri.password().is_some() {
return Err("Client URI must not include credentials.");
}
for uri in [&self.logo_uri, &self.policy_uri, &self.tos_uri]
.iter()
.filter_map(|uri| uri.as_ref())
{
if uri.scheme() != "https" {
return Err("All metadata URIs must be HTTPS.");
}
if !uri.username().is_empty() || uri.password().is_some() {
return Err("All metadata URIs must not include credentials.");
}
if !uri
.domain()
.is_some_and(|domain| is_subdomain(domain, client_domain))
{
return Err("All metadata URIs must be subdomains of the client URI.");
}
}
for uri in &self.redirect_uris {
match uri.scheme() {
| "https" => {
// HTTPS URIs are okay for native and web clients
if !uri.username().is_empty() || uri.password().is_some() {
return Err("HTTPS redirect URIs must not contain credentials.");
}
},
| "http" if self.application_type == ApplicationType::Native => {
if uri
.host_str()
.is_none_or(|host| !Self::ACCEPTABLE_LOCALHOSTS.contains(&host))
{
return Err("HTTP redirect URIs for native applications must only \
refer to localhost.");
}
if uri.port().is_some() {
return Err("HTTP redirect URIs for native applications do not need to \
specify a port. All ports will be accepted during \
authorization.");
}
},
| private_scheme if self.application_type == ApplicationType::Native => {
let rdns_client_uri = client_domain.split('.').rev().join(".");
if !private_scheme.starts_with(&rdns_client_uri) {
return Err("Private-use scheme URIs for native applications must \
begin with the application's client URI domain in \
reverse-DNS notation.");
}
if uri.has_authority() {
return Err("Private-use scheme URIs for native applications must not \
have an authority.");
}
},
| _ =>
return Err("A redirect URI's scheme is not valid for this application type."),
}
}
Ok(())
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ApplicationType {
#[default]
Web,
Native,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
AuthorizationCode,
RefreshToken,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseType {
Code,
}
/// Deserialize a BTreeSet from a sequence, skipping items which fail to
/// deserialize. This is used as a deserialize helper for ClientMetadata to
/// ignore unknown enum variants in a few fields.
fn btreeset_skip_err<'de, D, V>(de: D) -> Result<BTreeSet<V>, D::Error>
where
D: Deserializer<'de>,
V: Deserialize<'de> + Hash + Eq + Ord,
{
use std::marker::PhantomData;
use serde::de::{SeqAccess, Visitor};
struct BTreeSetVisitor<V> {
item: PhantomData<V>,
}
impl<'de, V> Visitor<'de> for BTreeSetVisitor<V>
where
V: Deserialize<'de> + Hash + Eq + Ord,
{
type Value = BTreeSet<V>;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut set = BTreeSet::new();
while let Some(element) = seq.next_element().transpose() {
if let Ok(element) = element {
set.insert(element);
}
}
Ok(set)
}
}
de.deserialize_seq(BTreeSetVisitor { item: PhantomData })
}
fn is_subdomain(subdomain: &str, domain: &str) -> bool {
if subdomain == domain {
return true;
}
subdomain.ends_with(&format!(".{domain}"))
}
+162
View File
@@ -0,0 +1,162 @@
use std::{collections::BTreeSet, fmt::Debug, hash::Hash, mem::discriminant};
use regex::Regex;
use ruma::OwnedDeviceId;
use serde::{Deserialize, Serialize};
use url::Url;
use super::client_metadata::ResponseType;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AuthorizationCodeQuery {
pub response_type: ResponseType,
pub client_id: String,
pub redirect_uri: Url,
pub scope: RawScopes,
pub state: String,
#[serde(default)]
pub response_mode: ResponseMode,
pub code_challenge: String,
pub code_challenge_method: CodeChallengeMethod,
#[serde(default)]
pub prompt: Option<Prompt>,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseMode {
#[default]
// default for `code` response type, see https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#:~:text=Client%2E-,For,encoding%2E,-See
Query,
Fragment,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[non_exhaustive]
pub enum CodeChallengeMethod {
S256,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Prompt {
Create,
#[serde(other)]
Unknown,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)]
pub enum Scope {
Device(OwnedDeviceId),
ClientApi,
}
impl PartialEq for Scope {
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
}
impl Eq for Scope {}
impl Hash for Scope {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { discriminant(self).hash(state); }
}
impl std::fmt::Display for Scope {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let urn = match self {
| Self::ClientApi => "urn:matrix:client:api:*".to_owned(),
| Self::Device(device_id) => format!("urn:matrix:client:device:{device_id}"),
};
f.write_str(&urn)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RawScopes(String);
impl RawScopes {
pub fn to_scopes(&self) -> Result<BTreeSet<Scope>, String> {
let client_api_token_regex =
Regex::new(r"urn:matrix:(client|org.matrix.msc2967.client):api:\*").unwrap();
let device_token_regex = Regex::new(
r"urn:matrix:(client|org.matrix.msc2967.client):device:([a-zA-Z0-9-._~]{5,})",
)
.unwrap();
let mut scopes = BTreeSet::new();
for token in self.0.split(' ') {
let scope_was_new = {
if client_api_token_regex.is_match(token) {
scopes.insert(Scope::ClientApi)
} else if let Some(captures) = device_token_regex.captures(token) {
scopes.insert(Scope::Device(captures.get(2).unwrap().as_str().into()))
} else if token == "openid" {
// TODO(unspecced): Element sets this scope but doesn't use it for anything
true
} else {
return Err(format!("Invalid scope: {token}"));
}
};
if !scope_was_new {
return Err("Scope was specified more than once".to_owned());
}
}
Ok(scopes)
}
}
#[derive(Serialize)]
pub struct AuthorizationCodeResponse {
pub state: String,
pub code: String,
}
#[derive(Deserialize)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum TokenRequest {
AuthorizationCode {
code: String,
redirect_uri: Url,
client_id: String,
code_verifier: String,
},
RefreshToken {
client_id: String,
refresh_token: String,
},
}
impl TokenRequest {
#[must_use]
pub fn client_id(&self) -> &str {
match self {
| Self::AuthorizationCode { client_id, .. }
| Self::RefreshToken { client_id, .. } => client_id,
}
}
}
#[derive(Serialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: TokenType,
pub expires_in: u64,
pub refresh_token: String,
pub scope: String,
}
#[derive(Serialize)]
pub enum TokenType {
Bearer,
}
#[derive(Deserialize)]
pub struct RevokeTokenRequest {
pub token: String,
}
+503
View File
@@ -0,0 +1,503 @@
use std::{
collections::{BTreeSet, HashMap},
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use base64::Engine;
use conduwuit::{
Err, Result, err, info,
utils::{self, hash::sha256},
};
use database::{Deserialized, Json, Map};
use itertools::Itertools;
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use url::Url;
use crate::{
Dep,
oauth::{
client_metadata::{ApplicationType, ClientMetadata, ResponseType},
grant::{
AuthorizationCodeQuery, AuthorizationCodeResponse, CodeChallengeMethod, ResponseMode,
Scope, TokenRequest, TokenResponse, TokenType,
},
},
users,
};
pub mod client_metadata;
pub mod grant;
pub struct Service {
services: Services,
db: Data,
tickets: Mutex<HashMap<String, HashMap<OAuthTicket, SystemTime>>>,
pending_code_grants: tokio::sync::Mutex<HashMap<String, PendingCodeGrant>>,
}
struct Data {
clientid_clientmetadata: Arc<Map>,
userdeviceid_oauthsessioninfo: Arc<Map>,
refreshtoken_refreshtokeninfo: Arc<Map>,
}
struct Services {
users: Dep<users::Service>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct SessionInfo {
pub client_id: String,
pub scopes: BTreeSet<Scope>,
current_refresh_token: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct RefreshTokenInfo {
client_id: String,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
}
struct PendingCodeGrant {
authorizing_user: OwnedUserId,
requested_scopes: BTreeSet<Scope>,
client_name: Option<String>,
expected_client_id: String,
expected_redirect_uri: Url,
code_challenge: String,
requested_at: SystemTime,
}
impl PendingCodeGrant {
const MAX_AGE: Duration = Duration::from_mins(1);
const RANDOM_CODE_LENGTH: usize = 32;
#[must_use]
pub(crate) fn generate_code() -> String { utils::random_string(Self::RANDOM_CODE_LENGTH) }
#[must_use]
pub(crate) fn is_valid_for(&self, client_id: &str) -> bool {
let now = SystemTime::now();
self.expected_client_id == client_id
&& now
.duration_since(self.requested_at)
.is_ok_and(|age| age < Self::MAX_AGE)
}
}
/// A time-limited grant for a client to perform some sensitive action.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OAuthTicket {
CrossSigningReset,
}
impl OAuthTicket {
const MAX_AGE: Duration = Duration::from_mins(10);
#[must_use]
pub fn ticket_issue_path(&self) -> &'static str {
match self {
| Self::CrossSigningReset => "/account/cross_signing_reset",
}
}
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
users: args.depend::<users::Service>("users"),
},
db: Data {
clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(),
userdeviceid_oauthsessioninfo: args.db["userdeviceid_oauthsessioninfo"].clone(),
refreshtoken_refreshtokeninfo: args.db["refreshtoken_refreshtokeninfo"].clone(),
},
tickets: Mutex::default(),
pending_code_grants: tokio::sync::Mutex::default(),
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
const ACCESS_TOKEN_MAX_AGE: Duration = Duration::from_hours(1);
const RANDOM_TOKEN_LENGTH: usize = 32;
fn generate_token() -> String { utils::random_string(Self::RANDOM_TOKEN_LENGTH) }
pub async fn register_client(
&self,
metadata: &ClientMetadata,
) -> Result<String, &'static str> {
metadata.validate()?;
let client_id = base64::prelude::BASE64_STANDARD
.encode(sha256::hash(serde_json::to_string(metadata).unwrap().as_bytes()));
if self
.db
.clientid_clientmetadata
.exists(&client_id)
.await
.is_err()
{
self.db
.clientid_clientmetadata
.raw_put(&client_id, Json(metadata.clone()));
}
Ok(client_id)
}
pub async fn get_client_metadata(&self, client_id: &str) -> Option<ClientMetadata> {
self.db
.clientid_clientmetadata
.get(client_id)
.await
.deserialized()
.ok()
}
pub async fn get_session_info_for_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Option<SessionInfo> {
self.db
.userdeviceid_oauthsessioninfo
.qry(&(user_id, device_id))
.await
.deserialized::<SessionInfo>()
.ok()
}
pub async fn request_authorization_code(
&self,
authorizing_user: OwnedUserId,
query: AuthorizationCodeQuery,
) -> Result<String, String> {
let Some(client_metadata) = self.get_client_metadata(&query.client_id).await else {
return Err("Invalid client ID".to_owned());
};
if !(client_metadata
.response_types
.contains(&query.response_type)
&& matches!(query.response_type, ResponseType::Code))
{
return Err("Invalid response type".to_owned());
}
if !matches!(query.code_challenge_method, CodeChallengeMethod::S256) {
return Err("Invalid code challenge type".to_owned());
}
{
let mut stripped_uri = query.redirect_uri.clone();
if client_metadata.application_type == ApplicationType::Native
&& query
.redirect_uri
.host_str()
.is_some_and(|host| ClientMetadata::ACCEPTABLE_LOCALHOSTS.contains(&host))
{
// Remove the port from localhost redirect URIs for native applications when
// checking if it's valid
stripped_uri.set_port(None).unwrap();
}
if !client_metadata.redirect_uris.contains(&stripped_uri) {
return Err("Invalid redirect URI".to_owned());
}
}
let requested_scopes = query.scope.to_scopes()?;
let redirect_uri_query_separator = match query.response_mode {
| ResponseMode::Fragment => '#',
| ResponseMode::Query => '?',
};
let code = PendingCodeGrant::generate_code();
info!(
client_id = &query.client_id,
client_name = &client_metadata.client_name,
?requested_scopes,
?authorizing_user,
"Issuing oauth authorization code"
);
let redirect_uri = format!(
"{}{}{}",
query.redirect_uri,
redirect_uri_query_separator,
serde_urlencoded::to_string(AuthorizationCodeResponse {
state: query.state,
code: code.clone(),
})
.unwrap(),
);
let pending_grant = PendingCodeGrant {
authorizing_user,
requested_scopes,
client_name: client_metadata.client_name,
expected_client_id: query.client_id,
expected_redirect_uri: query.redirect_uri,
code_challenge: query.code_challenge,
requested_at: SystemTime::now(),
};
self.pending_code_grants
.lock()
.await
.insert(code, pending_grant);
Ok(redirect_uri)
}
pub async fn issue_token(&self, request: TokenRequest) -> Result<TokenResponse> {
match request {
| TokenRequest::AuthorizationCode {
code,
redirect_uri,
client_id,
code_verifier,
} => {
let mut pending_grants = self.pending_code_grants.lock().await;
let Some(pending_grant) = pending_grants
.remove(&code)
.filter(|grant| grant.is_valid_for(&client_id))
else {
return Err!("Invalid code");
};
if redirect_uri != pending_grant.expected_redirect_uri {
return Err!("Unexpected redirect uri");
}
let expected_code_challenge =
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sha256::hash(&code_verifier));
if expected_code_challenge != pending_grant.code_challenge {
return Err!("Invalid code challenge");
}
self.create_session(
pending_grant.authorizing_user,
pending_grant.requested_scopes,
pending_grant.client_name,
client_id,
)
.await
},
| TokenRequest::RefreshToken { client_id, refresh_token } =>
self.refresh_session(client_id, refresh_token).await,
}
}
pub async fn revoke_token(&self, token: String) -> Result<()> {
let (user_id, device_id) = if let Ok(refresh_token_info) = self
.db
.refreshtoken_refreshtokeninfo
.get(&token)
.await
.deserialized::<RefreshTokenInfo>()
{
(refresh_token_info.user_id, refresh_token_info.device_id)
} else if let Some((user_id, device_id, _)) =
self.services.users.find_from_token(&token).await
{
(user_id, device_id)
} else {
return Err!("Invalid token");
};
// This will also call [`Self::remove_session`]
self.services
.users
.remove_device(&user_id, &device_id)
.await;
Ok(())
}
async fn create_session(
&self,
authorizing_user: OwnedUserId,
requested_scopes: BTreeSet<Scope>,
client_name: Option<String>,
client_id: String,
) -> Result<TokenResponse> {
let access_token = Self::generate_token();
let refresh_token = Self::generate_token();
let device_id = requested_scopes
.iter()
.find_map(|scope| {
if let Scope::Device(device_id) = scope {
Some(device_id)
} else {
None
}
})
.ok_or_else(|| err!("No device ID scope supplied"))?;
self.services
.users
.create_device(
&authorizing_user,
device_id,
&access_token,
Some(Self::ACCESS_TOKEN_MAX_AGE),
client_name,
None,
)
.await?;
self.db.userdeviceid_oauthsessioninfo.put(
(&authorizing_user, device_id),
Json(SessionInfo {
client_id: client_id.clone(),
current_refresh_token: refresh_token.clone(),
scopes: requested_scopes.clone(),
}),
);
self.db.refreshtoken_refreshtokeninfo.raw_put(
&refresh_token,
Json(RefreshTokenInfo {
client_id: client_id.clone(),
user_id: authorizing_user.clone(),
device_id: device_id.to_owned(),
}),
);
info!(
?client_id,
?authorizing_user,
?device_id,
?requested_scopes,
"Created new oauth session"
);
Ok(TokenResponse {
access_token,
token_type: TokenType::Bearer,
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
scope: requested_scopes.iter().join(" "),
refresh_token,
})
}
async fn refresh_session(
&self,
client_id: String,
refresh_token: String,
) -> Result<TokenResponse> {
let Some(refresh_token_info) = self
.db
.refreshtoken_refreshtokeninfo
.get(&refresh_token)
.await
.deserialized::<RefreshTokenInfo>()
.ok()
else {
return Err!("Invalid refresh token");
};
assert_eq!(&client_id, &refresh_token_info.client_id, "refresh token client id mismatch");
let mut session_info = self
.get_session_info_for_device(
&refresh_token_info.user_id,
&refresh_token_info.device_id,
)
.await
.expect("session info should exist");
assert_eq!(&client_id, &session_info.client_id, "session info client id mismatch");
let new_access_token = Self::generate_token();
let new_refresh_token = Self::generate_token();
let scope = session_info.scopes.iter().join(" ");
session_info
.current_refresh_token
.clone_from(&new_refresh_token);
self.services
.users
.set_token(
&refresh_token_info.user_id,
&refresh_token_info.device_id,
&new_access_token,
Some(Self::ACCESS_TOKEN_MAX_AGE),
)
.await?;
self.db.userdeviceid_oauthsessioninfo.put(
(&refresh_token_info.user_id, &refresh_token_info.device_id),
Json(session_info),
);
self.db.refreshtoken_refreshtokeninfo.remove(&refresh_token);
drop(refresh_token);
self.db
.refreshtoken_refreshtokeninfo
.raw_put(&new_refresh_token, Json(refresh_token_info));
Ok(TokenResponse {
access_token: new_access_token,
token_type: TokenType::Bearer,
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
scope,
refresh_token: new_refresh_token,
})
}
pub async fn remove_session(&self, user_id: &UserId, device_id: &DeviceId) {
let session_info = self.get_session_info_for_device(user_id, device_id).await;
if let Some(session_info) = session_info {
self.db
.refreshtoken_refreshtokeninfo
.remove(&session_info.current_refresh_token);
self.db
.userdeviceid_oauthsessioninfo
.del((user_id, device_id));
info!(?user_id, ?device_id, "Removed OAuth session");
}
}
/// Issue a ticket for `localpart` to perform some action.
pub fn issue_ticket(&self, localpart: String, ticket: OAuthTicket) {
self.tickets
.lock()
.unwrap()
.entry(localpart)
.or_default()
.insert(ticket, SystemTime::now());
}
/// Try to consume an unexpired ticket for `localpart`.
pub fn try_consume_ticket(&self, localpart: &str, ticket: OAuthTicket) -> bool {
let now = SystemTime::now();
self.tickets
.lock()
.unwrap()
.get_mut(localpart)
.and_then(|tickets| tickets.remove(&ticket))
.is_some_and(|issued| {
now.duration_since(issued)
.is_ok_and(|duration| duration < OAuthTicket::MAX_AGE)
})
}
}
-68
View File
@@ -1,68 +0,0 @@
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use conduwuit::utils::{ReadyExt, stream::TryExpect};
use database::{Database, Deserialized, Json, Map};
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
pub(super) struct Data {
passwordresettoken_info: Arc<Map>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ResetTokenInfo {
pub user: OwnedUserId,
pub issued_at: SystemTime,
}
impl ResetTokenInfo {
// one hour
const MAX_TOKEN_AGE: Duration = Duration::from_hours(1);
pub fn is_valid(&self) -> bool {
let now = SystemTime::now();
now.duration_since(self.issued_at)
.is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE)
}
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
passwordresettoken_info: db["passwordresettoken_info"].clone(),
}
}
/// Associate a reset token with its info in the database.
pub(super) fn save_token(&self, token: &str, info: &ResetTokenInfo) {
self.passwordresettoken_info.raw_put(token, Json(info));
}
/// Lookup the info for a reset token.
pub(super) async fn lookup_token_info(&self, token: &str) -> Option<ResetTokenInfo> {
self.passwordresettoken_info
.get(token)
.await
.deserialized()
.ok()
}
/// Find a user's existing reset token, if any.
pub(super) async fn find_token_for_user(
&self,
user: &UserId,
) -> Option<(String, ResetTokenInfo)> {
self.passwordresettoken_info
.stream::<'_, String, ResetTokenInfo>()
.expect_ok()
.ready_find(|(_, info)| info.user == user)
.await
}
/// Remove a reset token.
pub(super) fn remove_token(&self, token: &str) { self.passwordresettoken_info.remove(token); }
}
-111
View File
@@ -1,111 +0,0 @@
mod data;
use std::{sync::Arc, time::SystemTime};
use conduwuit::{Err, Result, utils};
use data::{Data, ResetTokenInfo};
use ruma::OwnedUserId;
use crate::{
Dep, globals,
users::{self, HashedPassword},
};
pub const PASSWORD_RESET_PATH: &str = "/_continuwuity/account/reset_password";
pub const RESET_TOKEN_QUERY_PARAM: &str = "token";
const RESET_TOKEN_LENGTH: usize = 32;
pub struct Service {
db: Data,
services: Services,
}
struct Services {
users: Dep<users::Service>,
globals: Dep<globals::Service>,
}
#[derive(Debug)]
pub struct ValidResetToken {
pub token: String,
pub info: ResetTokenInfo,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
users: args.depend::<users::Service>("users"),
globals: args.depend::<globals::Service>("globals"),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Generate a random string suitable to be used as a password reset token.
#[must_use]
pub fn generate_token_string() -> String { utils::random_string(RESET_TOKEN_LENGTH) }
/// Issue a password reset token for `user`, who must be a local user with
/// the `password` origin.
pub async fn issue_token(&self, user_id: OwnedUserId) -> Result<ValidResetToken> {
if !self.services.globals.user_is_local(&user_id) {
return Err!("Cannot issue a password reset token for remote user {user_id}");
}
if user_id == self.services.globals.server_user {
return Err!("Cannot issue a password reset token for the server user");
}
if self.services.users.is_deactivated(&user_id).await? {
return Err!("Cannot issue a password reset token for deactivated user {user_id}");
}
if let Some((existing_token, _)) = self.db.find_token_for_user(&user_id).await {
self.db.remove_token(&existing_token);
}
let token = Self::generate_token_string();
let info = ResetTokenInfo {
user: user_id,
issued_at: SystemTime::now(),
};
self.db.save_token(&token, &info);
Ok(ValidResetToken { token, info })
}
/// Check if `token` represents a valid, non-expired password reset token.
pub async fn check_token(&self, token: &str) -> Option<ValidResetToken> {
self.db.lookup_token_info(token).await.and_then(|info| {
if info.is_valid() {
Some(ValidResetToken { token: token.to_owned(), info })
} else {
self.db.remove_token(token);
None
}
})
}
/// Consume the supplied valid token, using it to change its user's password
/// to `new_password`.
pub async fn consume_token(
&self,
ValidResetToken { token, info }: ValidResetToken,
new_password: &str,
) -> Result<()> {
if info.is_valid() {
self.db.remove_token(&token);
self.services
.users
.set_password(&info.user, Some(HashedPassword::new(new_password)?));
}
Ok(())
}
}
+1 -1
View File
@@ -100,7 +100,7 @@ impl Service {
/// Pings the presence of the given user in the given room, setting the
/// specified state.
pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> {
const REFRESH_TIMEOUT: u64 = 60 * 1000;
const REFRESH_TIMEOUT: u64 = 60 * 1000 * 4;
let last_presence = self.db.get_presence(user_id).await;
let state_changed = match last_presence {
+3 -2
View File
@@ -10,6 +10,7 @@ use futures::{
stream::{iter, once},
};
use ruma::OwnedUserId;
use serde::{Deserialize, Serialize};
use crate::{Dep, config, firstrun};
@@ -27,7 +28,7 @@ struct Services {
}
/// A validated registration token which may be used to create an account.
#[derive(Debug)]
#[derive(Debug, Deserialize, Serialize)]
pub struct ValidToken {
pub token: String,
pub source: ValidTokenSource,
@@ -44,7 +45,7 @@ impl PartialEq<str> for ValidToken {
}
/// The source of a valid database token.
#[derive(Debug)]
#[derive(Debug, Deserialize, Serialize)]
pub enum ValidTokenSource {
/// The static token set in the homeserver's config file.
Config,
@@ -80,7 +80,7 @@ where
{
// Exponential backoff
const MIN_DURATION: u64 = 60 * 2;
const MAX_DURATION: u64 = 60 * 60 * 8;
const MAX_DURATION: u64 = 60 * 60;
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
@@ -46,7 +46,7 @@ where
{
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
const MAX_DURATION: u64 = 60 * 60;
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
debug!(
?tries,
@@ -204,17 +204,27 @@ pub async fn policy_server_allows_event(
%ps.via,
"Asking policy server to sign event"
);
self.fetch_policy_server_signature(pdu, pdu_json, &ps.via, outgoing, room_id, ps_key, 0)
.await?;
// Verify that the policy server signature was made with the same public key as
// is in the state event, not just that it was signed.
if let Err(e) = self
.fetch_policy_server_signature(pdu, pdu_json, &ps.via, outgoing, room_id, ps_key, 0)
.await
{
if e.is_not_found() {
return Ok(());
}
return Err(e);
}
trace!(
"Got successful response for fetching PS signature, ensuring it is signed with the \
expected key."
);
if verify_policy_signature(&ps.via, ps_key, pdu_json, &room_version_rules.redaction) {
Ok(())
} else if incoming {
Err!(Request(Forbidden("Policy server signature is invalid")))
} else {
Err(Error::Request(
ErrorKind::Unknown,
"Policy server signature was made with a different key to the one advertised".into(),
"Policy server signature is invalid".into(),
StatusCode::BAD_GATEWAY,
))
}
@@ -272,7 +282,7 @@ async fn handle_policy_server_error(
"Policy server is not actually a policy server or is not protecting this room: {}",
error.message()
);
Ok(())
Err(error)
},
| StatusCode::TOO_MANY_REQUESTS => {
if let Some(retry_after) = error.retry_after() {
+2 -46
View File
@@ -1,10 +1,10 @@
use std::sync::Arc;
use conduwuit::{Result, implement};
use database::{Database, Deserialized, Map};
use database::{Deserialized, Map};
use ruma::{RoomId, UserId};
use crate::{Dep, globals, rooms, rooms::short::ShortStateHash};
use crate::{Dep, globals};
pub struct Service {
db: Data,
@@ -12,32 +12,25 @@ pub struct Service {
}
struct Data {
db: Arc<Database>,
userroomid_notificationcount: Arc<Map>,
userroomid_highlightcount: Arc<Map>,
roomuserid_lastnotificationread: Arc<Map>,
roomsynctoken_shortstatehash: Arc<Map>,
}
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
db: args.db.clone(),
userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(),
roomuserid_lastnotificationread: args.db["userroomid_highlightcount"].clone(),
roomsynctoken_shortstatehash: args.db["roomsynctoken_shortstatehash"].clone(),
},
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}))
}
@@ -90,40 +83,3 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -
.deserialized()
.unwrap_or(0)
}
#[implement(Service)]
pub async fn associate_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
shortstatehash: ShortStateHash,
) {
let shortroomid = self
.services
.short
.get_shortroomid(room_id)
.await
.expect("room exists");
let _cork = self.db.db.cork();
let key: &[u64] = &[shortroomid, token];
self.db
.roomsynctoken_shortstatehash
.put(key, shortstatehash);
}
#[implement(Service)]
pub async fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<ShortStateHash> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let key: &[u64] = &[shortroomid, token];
self.db
.roomsynctoken_shortstatehash
.qry(key)
.await
.deserialized()
}
+4 -4
View File
@@ -11,8 +11,8 @@ use crate::{
account_data, admin, announcements, antispam, appservice, client, config, emergency,
federation, firstrun, globals, key_backups, mailer,
manager::Manager,
media, moderation, password_reset, presence, pusher, registration_tokens, resolver, rooms,
sending, server_keys,
media, moderation, oauth, presence, pusher, registration_tokens, resolver, rooms, sending,
server_keys,
service::{self, Args, Map, Service},
sync, threepid, transactions, uiaa, users,
};
@@ -27,7 +27,7 @@ pub struct Services {
pub globals: Arc<globals::Service>,
pub key_backups: Arc<key_backups::Service>,
pub media: Arc<media::Service>,
pub password_reset: Arc<password_reset::Service>,
pub oauth: Arc<oauth::Service>,
pub mailer: Arc<mailer::Service>,
pub presence: Arc<presence::Service>,
pub pusher: Arc<pusher::Service>,
@@ -84,7 +84,7 @@ impl Services {
globals: build!(globals::Service),
key_backups: build!(key_backups::Service),
media: build!(media::Service),
password_reset: build!(password_reset::Service),
oauth: build!(oauth::Service),
mailer: build!(mailer::Service),
presence: build!(presence::Service),
pusher: build!(pusher::Service),
+29 -7
View File
@@ -9,8 +9,9 @@ use ruma::{
ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId,
api::error::{ErrorKind, LimitExceededErrorData},
};
use tokio::sync::MutexGuard;
mod session;
pub mod session;
use crate::{
Args, Dep, config,
@@ -26,6 +27,7 @@ pub struct Service {
ratelimiter: DefaultKeyedRateLimiter<Address>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmailRequirement {
/// Users may change their email, but cannot remove it entirely.
Required,
@@ -219,13 +221,12 @@ impl Service {
Ok(())
}
/// Consume a validated validation session, removing it from the database
/// and returning the newly validated email address.
pub async fn consume_valid_session(
/// Get a validated validation session.
pub async fn get_valid_session(
&self,
session_id: &SessionId,
client_secret: &ClientSecret,
) -> Result<Address, Cow<'static, str>> {
) -> Result<ValidSession<'_>, Cow<'static, str>> {
let mut sessions = self.sessions.lock().await;
let Some(session) = sessions.get_session(session_id) else {
@@ -235,9 +236,13 @@ impl Service {
if session.client_secret == client_secret
&& matches!(session.validation_state, ValidationState::Validated)
{
let session = sessions.remove_session(session_id);
let email = session.email.clone();
Ok(session.email)
Ok(ValidSession {
email,
session_id: session_id.to_owned(),
sessions,
})
} else {
Err("This email address has not been validated. Did you use the link that was sent \
to you?"
@@ -313,3 +318,20 @@ impl Service {
.ok()
}
}
pub struct ValidSession<'lock> {
pub email: Address,
session_id: OwnedSessionId,
sessions: MutexGuard<'lock, ValidationSessions>,
}
impl ValidSession<'_> {
/// Consume this session, removing it from the database and releasing the
/// lock it holds.
#[must_use]
pub fn consume(mut self) -> Address {
self.sessions.remove_session(&self.session_id);
self.email
}
}
+5 -5
View File
@@ -8,14 +8,14 @@ use lettre::Address;
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId};
#[derive(Default)]
pub(super) struct ValidationSessions {
pub struct ValidationSessions {
sessions: HashMap<OwnedSessionId, ValidationSession>,
client_secrets: HashMap<OwnedClientSecret, OwnedSessionId>,
}
/// A pending or completed email validation session.
#[derive(Debug)]
pub(crate) struct ValidationSession {
pub struct ValidationSession {
/// The session's ID
pub session_id: OwnedSessionId,
/// The client's supplied client secret
@@ -28,7 +28,7 @@ pub(crate) struct ValidationSession {
/// The state of an email validation session.
#[derive(Debug)]
pub(crate) enum ValidationState {
pub enum ValidationState {
/// The session is waiting for this validation token to be provided
Pending(ValidationToken),
/// The session has been validated
@@ -36,7 +36,7 @@ pub(crate) enum ValidationState {
}
#[derive(Clone, Debug)]
pub(crate) struct ValidationToken {
pub struct ValidationToken {
pub token: String,
pub issued_at: SystemTime,
}
@@ -69,7 +69,7 @@ impl ValidationSessions {
const RANDOM_SID_LENGTH: usize = 16;
#[must_use]
pub(super) fn generate_session_id() -> OwnedSessionId {
pub fn generate_session_id() -> OwnedSessionId {
SessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap()
}
+297 -151
View File
@@ -7,7 +7,7 @@ use std::{
use conduwuit::{Err, Error, Result, error, utils};
use lettre::Address;
use ruma::{
UserId,
DeviceId, UserId,
api::{
client::uiaa::{
AuthData, AuthFlow, AuthType, EmailIdentity, EmailUserIdentifier,
@@ -16,11 +16,19 @@ use ruma::{
},
error::{ErrorKind, StandardErrorBody},
},
assign,
};
use serde_json::{
json,
value::{RawValue, to_raw_value},
};
use serde_json::value::RawValue;
use tokio::sync::Mutex;
use crate::{Dep, config, globals, registration_tokens, threepid, users};
use crate::{
Dep, config, globals,
oauth::{self, OAuthTicket},
registration_tokens, threepid, users,
};
pub struct Service {
services: Services,
@@ -33,6 +41,7 @@ struct Services {
config: Dep<config::Service>,
registration_tokens: Dep<registration_tokens::Service>,
threepid: Dep<threepid::Service>,
oauth: Dep<oauth::Service>,
}
impl crate::Service for Service {
@@ -45,6 +54,7 @@ impl crate::Service for Service {
registration_tokens: args
.depend::<registration_tokens::Service>("registration_tokens"),
threepid: args.depend::<threepid::Service>("threepid"),
oauth: args.depend::<oauth::Service>("oauth"),
},
uiaa_sessions: Mutex::new(HashMap::new()),
}))
@@ -54,8 +64,56 @@ impl crate::Service for Service {
}
struct UiaaSession {
session_metadata: UiaaSessionMetadata,
info: UiaaInfo,
identity: Identity,
}
#[derive(Clone)]
enum UiaaSessionMetadata {
Legacy {
identity: Identity,
},
OAuth {
localpart: String,
ticket: OAuthTicket,
},
}
impl UiaaSessionMetadata {
fn into_identity(self) -> Identity {
match self {
| Self::Legacy { identity } => identity,
| Self::OAuth { localpart, .. } =>
assign!(Identity::default(), { localpart: Some(localpart) }),
}
}
}
/// Information about the user which is initiating this UIAA session.
pub struct UiaaInitiator<'a> {
user_id: &'a UserId,
device_id: &'a DeviceId,
oauth_ticket: Option<OAuthTicket>,
}
impl<'a> UiaaInitiator<'a> {
#[must_use]
pub fn new(user_id: &'a UserId, device_id: &'a DeviceId) -> Self {
Self { user_id, device_id, oauth_ticket: None }
}
#[must_use]
pub fn with_oauth_ticket(
user_id: &'a UserId,
device_id: &'a DeviceId,
oauth_ticket: OAuthTicket,
) -> Self {
Self {
user_id,
device_id,
oauth_ticket: Some(oauth_ticket),
}
}
}
/// Information about the authenticated user's identity.
@@ -106,7 +164,7 @@ impl Identity {
/// Create an Identity with the localpart of the provided user ID
/// and all other fields set to None.
#[must_use]
pub fn from_user_id(user_id: &UserId) -> Self {
fn from_user_id(user_id: &UserId) -> Self {
Self {
localpart: Some(user_id.localpart().to_owned()),
..Default::default()
@@ -124,11 +182,11 @@ impl Service {
auth: &Option<AuthData>,
flows: Vec<AuthFlow>,
params: Box<RawValue>,
identity: Option<Identity>,
initiator: Option<UiaaInitiator<'_>>,
) -> Result<Identity> {
match auth.as_ref() {
| None => {
let info = self.create_session(flows, params, identity).await;
let info = self.create_session(flows, params, initiator).await?;
Err(Error::Uiaa(info))
},
@@ -140,8 +198,8 @@ impl Service {
// session if they want to start the UIAA exchange with existing
// authentication data. If that happens, we create a new session
// here.
self.create_session(flows, params, identity)
.await
self.create_session(flows, params, initiator)
.await?
.session
.unwrap()
.into()
@@ -161,13 +219,15 @@ impl Service {
pub async fn authenticate_password(
&self,
auth: &Option<AuthData>,
identity: Option<Identity>,
user_id: &UserId,
device_id: &DeviceId,
oauth_ticket: Option<OAuthTicket>,
) -> Result<Identity> {
self.authenticate(
auth,
vec![AuthFlow::new(vec![AuthType::Password])],
Box::default(),
identity,
Some(UiaaInitiator { user_id, device_id, oauth_ticket }),
)
.await
}
@@ -183,20 +243,84 @@ impl Service {
&self,
flows: Vec<AuthFlow>,
params: Box<RawValue>,
identity: Option<Identity>,
) -> UiaaInfo {
initiator: Option<UiaaInitiator<'_>>,
) -> Result<UiaaInfo> {
let mut uiaa_sessions = self.uiaa_sessions.lock().await;
let session_id = utils::random_string(Self::SESSION_ID_LENGTH);
let mut info = assign::assign!(UiaaInfo::new(flows), {params: Some(params)});
info.session = Some(session_id.clone());
uiaa_sessions.insert(session_id, UiaaSession {
info: info.clone(),
identity: identity.unwrap_or_default(),
});
let mut info = assign!(UiaaInfo::new(flows), { params: Some(params), session: Some(session_id.clone()) });
info
let session_metadata = if let Some(initiator) = initiator {
let is_oauth = self
.services
.oauth
.get_session_info_for_device(initiator.user_id, initiator.device_id)
.await
.is_some();
if is_oauth {
if let Some(oauth_ticket) = initiator.oauth_ticket {
let ticket_url = self
.services
.config
.get_client_domain()
.join(&format!(
"{}{}",
conduwuit_core::ROUTE_PREFIX,
oauth_ticket.ticket_issue_path()
))
.unwrap();
info.flows = vec![AuthFlow::new(vec![AuthType::OAuth])];
info.params = Some(
to_raw_value(&json!({
AuthType::OAuth.as_str(): {
"url": ticket_url,
},
// TODO(compat): This is necessary for older versions of matrix-rust-sdk
"org.matrix.cross_signing_reset": {
"url": ticket_url,
}
}))
.unwrap(),
);
UiaaSessionMetadata::OAuth {
localpart: initiator.user_id.localpart().to_owned(),
ticket: oauth_ticket,
}
} else {
return Err!(Request(Forbidden(
"Clients authorized with OAuth cannot use this route."
)));
}
} else {
UiaaSessionMetadata::Legacy {
identity: Identity::from_user_id(initiator.user_id),
}
}
} else {
UiaaSessionMetadata::Legacy { identity: Identity::default() }
};
// Legacy sessions aren't available if OAuth is required
if matches!(&session_metadata, UiaaSessionMetadata::Legacy { .. })
&& !self
.services
.config
.oauth
.compatibility_mode
.uiaa_available()
{
return Err!(Request(Unrecognized(
"User-interactive authentication is unavailable on this server"
)));
}
uiaa_sessions.insert(session_id, UiaaSession { session_metadata, info: info.clone() });
Ok(info)
}
/// Proceed with UIAA authentication given a client's authorization data.
@@ -225,7 +349,7 @@ impl Service {
}
let completed = {
let UiaaSession { info, identity } = session.get_mut();
let UiaaSession { session_metadata, info } = session.get_mut();
let auth_type = auth.auth_type().expect("auth type should be set");
@@ -258,12 +382,12 @@ impl Service {
// If the provided stage hasn't already been completed, check it for completion
if !completed_stages.contains(auth_type.as_str()) {
match self.check_stage(auth, identity.clone()).await {
| Ok((completed_stage, updated_identity)) => {
match self.check_stage(auth, session_metadata.clone()).await {
| Ok((completed_stage, updated_metadata)) => {
info.auth_error = None;
completed_stages.insert(completed_stage.to_string());
info.completed.push(completed_stage);
*identity = updated_identity;
*session_metadata = updated_metadata;
},
| Err(error) => {
info.auth_error = Some(error);
@@ -279,9 +403,9 @@ impl Service {
if completed {
// This session is complete, remove it and return success
let (_, UiaaSession { identity, .. }) = session.remove_entry();
let (_, UiaaSession { session_metadata, .. }) = session.remove_entry();
Ok(Ok(identity))
Ok(Ok(session_metadata.into_identity()))
} else {
// The client needs to try again, return the updated session
Ok(Err(session.get().info.clone()))
@@ -295,152 +419,174 @@ impl Service {
async fn check_stage(
&self,
auth: &AuthData,
mut identity: Identity,
) -> Result<(AuthType, Identity), StandardErrorBody> {
// Note: This function takes ownership of `identity` because mutations to the
// identity must not be applied unless checking the stage succeeds. The
// updated identity is returned as part of the Ok value, and
// `continue_session` handles saving it to `uiaa_sessions`.
mut session_metadata: UiaaSessionMetadata,
) -> Result<(AuthType, UiaaSessionMetadata), StandardErrorBody> {
// Note: This function takes ownership of `session_metadata` because mutations
// to the identity (if it's a legacy session) must not be applied unless
// checking the stage succeeds. The updated identity is returned as part of
// the Ok value, and `continue_session` handles saving it to `uiaa_sessions`.
//
// This also means it's fine to mutate `identity` at any point in this function,
// because those mutations won't be saved unless the function returns Ok.
match auth {
| AuthData::Dummy(_) => Ok(AuthType::Dummy),
| AuthData::EmailIdentity(EmailIdentity {
thirdparty_id_creds: ThirdpartyIdCredentials { client_secret, sid, .. },
..
}) => {
match self
.services
.threepid
.consume_valid_session(sid, client_secret)
.await
{
| Ok(email) => {
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_localpart(localpart)?;
}
let completed_auth_type = match &mut session_metadata {
| UiaaSessionMetadata::OAuth { localpart, ticket } => {
// m.oauth is the only valid stage for oauth sessions
assert!(
matches!(auth, AuthData::OAuth(_)),
"got non-oauth auth data for oauth session"
);
identity.try_set_email(email)?;
Ok(AuthType::EmailIdentity)
},
| Err(message) => Err(StandardErrorBody::new(
ErrorKind::ThreepidAuthFailed,
message.into_owned(),
)),
if self.services.oauth.try_consume_ticket(localpart, *ticket) {
Ok(AuthType::OAuth)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"No OAuth ticket available".to_owned(),
))
}
},
#[allow(clippy::useless_let_if_seq)]
| AuthData::Password(Password { identifier, password, .. }) => {
let user_id_or_localpart = match identifier {
| UserIdentifier::Matrix(MatrixUserIdentifier { user, .. }) =>
user.to_owned(),
| UserIdentifier::Email(EmailUserIdentifier { address, .. }) => {
let Ok(email) = Address::try_from(address.to_owned()) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"Email is malformed".to_owned(),
));
};
| UiaaSessionMetadata::Legacy { identity } => match auth {
| AuthData::Dummy(_) => Ok(AuthType::Dummy),
| AuthData::EmailIdentity(EmailIdentity {
thirdparty_id_creds: ThirdpartyIdCredentials { client_secret, sid, .. },
..
}) => {
match self
.services
.threepid
.get_valid_session(sid, client_secret)
.await
{
| Ok(session) => {
let email = session.consume();
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_localpart(localpart)?;
}
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_email(email)?;
localpart
} else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid identifier or password".to_owned(),
));
}
},
| _ =>
return Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Identifier type not recognized".to_owned(),
Ok(AuthType::EmailIdentity)
},
| Err(message) => Err(StandardErrorBody::new(
ErrorKind::ThreepidAuthFailed,
message.into_owned(),
)),
};
}
},
#[allow(clippy::useless_let_if_seq)]
| AuthData::Password(Password { identifier, password, .. }) => {
let user_id_or_localpart = match identifier {
| UserIdentifier::Matrix(MatrixUserIdentifier { user, .. }) =>
user.to_owned(),
| UserIdentifier::Email(EmailUserIdentifier { address, .. }) => {
let Ok(email) = Address::try_from(address.to_owned()) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"Email is malformed".to_owned(),
));
};
let Ok(user_id) = UserId::parse_with_server_name(
user_id_or_localpart,
self.services.globals.server_name(),
) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"User ID is malformed".to_owned(),
));
};
if let Some(localpart) =
self.services.threepid.get_localpart_for_email(&email).await
{
identity.try_set_email(email)?;
if self
.services
.users
.check_password(&user_id, password)
.await
.is_ok()
{
identity.try_set_localpart(user_id.localpart().to_owned())?;
localpart
} else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid identifier or password".to_owned(),
));
}
},
| _ =>
return Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Identifier type not recognized".to_owned(),
)),
};
Ok(AuthType::Password)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid identifier or password".to_owned(),
))
}
},
| AuthData::ReCaptcha(ReCaptcha { response, .. }) => {
let Some(ref private_site_key) = self.services.config.recaptcha_private_site_key
else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha is not configured".to_owned(),
));
};
let Ok(user_id) = UserId::parse_with_server_name(
user_id_or_localpart,
self.services.globals.server_name(),
) else {
return Err(StandardErrorBody::new(
ErrorKind::InvalidParam,
"User ID is malformed".to_owned(),
));
};
match recaptcha_verify::verify_v3(private_site_key, response, None).await {
| Ok(()) => Ok(AuthType::ReCaptcha),
| Err(e) => {
error!("ReCaptcha verification failed: {e:?}");
if self
.services
.users
.check_password(&user_id, password)
.await
.is_ok()
{
identity.try_set_localpart(user_id.localpart().to_owned())?;
Ok(AuthType::Password)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha verification failed".to_owned(),
"Invalid identifier or password".to_owned(),
))
},
}
},
| AuthData::RegistrationToken(RegistrationToken { token, .. }) => {
let token = token.trim().to_owned();
}
},
| AuthData::ReCaptcha(ReCaptcha { response, .. }) => {
let Some(ref private_site_key) =
self.services.config.recaptcha_private_site_key
else {
return Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha is not configured".to_owned(),
));
};
if let Some(valid_token) = self
.services
.registration_tokens
.validate_token(token)
.await
{
self.services
match recaptcha_verify::verify_v3(private_site_key, response, None).await {
| Ok(()) => Ok(AuthType::ReCaptcha),
| Err(e) => {
error!("ReCaptcha verification failed: {e:?}");
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"ReCaptcha verification failed".to_owned(),
))
},
}
},
| AuthData::RegistrationToken(RegistrationToken { token, .. }) => {
let token = token.trim().to_owned();
if let Some(valid_token) = self
.services
.registration_tokens
.mark_token_as_used(valid_token);
.validate_token(token)
.await
{
self.services
.registration_tokens
.mark_token_as_used(valid_token);
Ok(AuthType::RegistrationToken)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid registration token".to_owned(),
))
}
Ok(AuthType::RegistrationToken)
} else {
Err(StandardErrorBody::new(
ErrorKind::Forbidden,
"Invalid registration token".to_owned(),
))
}
},
| AuthData::Terms(_) => Ok(AuthType::Terms),
| _ => Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Unsupported stage type".into(),
)),
},
| AuthData::Terms(_) => Ok(AuthType::Terms),
| _ => Err(StandardErrorBody::new(
ErrorKind::Unrecognized,
"Unsupported stage type".into(),
)),
}
.map(|auth_type| (auth_type, identity))
}?;
Ok((completed_auth_type, session_metadata))
}
}
+1 -1
View File
@@ -54,6 +54,7 @@ pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) ->
user_id,
&request.device_id,
"",
None,
request.initial_device_display_name.clone(),
None,
)
@@ -138,7 +139,6 @@ pub async fn get_dehydrated_device_id(&self, user_id: &UserId) -> Result<OwnedDe
level = "debug",
skip_all,
fields(%user_id),
ret,
)]
pub async fn get_dehydrated_device(&self, user_id: &UserId) -> Result<DehydratedDevice> {
self.db
+334 -17
View File
@@ -1,13 +1,21 @@
pub(super) mod dehydrated_device;
use std::{collections::BTreeMap, mem, net::IpAddr, sync::Arc};
use std::{
collections::BTreeMap,
mem,
net::IpAddr,
sync::Arc,
time::{Duration, SystemTime},
};
use conduwuit::{
Err, Error, Result, Server, debug_error, debug_warn, err, trace,
Err, Error, Result, debug_error, debug_warn, err, info, trace,
utils::{self, ReadyExt, stream::TryIgnore, string::Unquoted},
warn,
};
use database::{Deserialized, Ignore, Interfix, Json, Map};
use futures::{Stream, StreamExt, TryFutureExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use lettre::Address;
use ruma::{
DeviceId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName,
OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedOneTimeKeyId, OwnedUserId, RoomId, UInt, UserId,
@@ -18,15 +26,24 @@ use ruma::{
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{
AnyToDeviceEvent, GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent,
push_rules::PushRulesEvent, room::message::RoomMessageEventContent,
},
push::Ruleset,
serde::Raw,
uint,
};
use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigEvent};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;
use crate::{Dep, account_data, admin, appservice, globals, rooms};
use crate::{
Dep, account_data, admin,
appservice::{self, RegistrationInfo},
config, firstrun, globals, oauth,
rooms::{self, alias, membership},
threepid,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSuspension {
@@ -41,6 +58,7 @@ pub struct UserSuspension {
/// A password hash. This is only for use when setting a user's password,
/// if the hash needs to be kept around for a while without keeping the password
/// in memory.
#[derive(Serialize, Deserialize)]
pub struct HashedPassword(String);
impl HashedPassword {
@@ -51,19 +69,30 @@ impl HashedPassword {
}
}
/// The status of an access token.
pub enum AccessTokenStatus {
Valid,
Expired,
}
pub struct Service {
services: Services,
db: Data,
}
struct Services {
server: Arc<Server>,
account_data: Dep<account_data::Service>,
admin: Dep<admin::Service>,
alias: Dep<alias::Service>,
appservice: Dep<appservice::Service>,
config: Dep<config::Service>,
firstrun: Dep<firstrun::Service>,
globals: Dep<globals::Service>,
membership: Dep<membership::Service>,
oauth: Dep<oauth::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
threepid: Dep<threepid::Service>,
}
struct Data {
@@ -75,6 +104,7 @@ struct Data {
logintoken_expiresatuserid: Arc<Map>,
todeviceid_events: Arc<Map>,
token_userdeviceid: Arc<Map>,
userdeviceid_tokenexpires: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
userdeviceid_token: Arc<Map>,
userfilterid_filter: Arc<Map>,
@@ -97,14 +127,19 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
server: args.server.clone(),
account_data: args.depend::<account_data::Service>("account_data"),
admin: args.depend::<admin::Service>("admin"),
alias: args.depend::<alias::Service>("alias"),
appservice: args.depend::<appservice::Service>("appservice"),
config: args.depend::<config::Service>("config"),
firstrun: args.depend::<firstrun::Service>("firstrun"),
globals: args.depend::<globals::Service>("globals"),
membership: args.depend::<membership::Service>("membership"),
oauth: args.depend::<oauth::Service>("oauth"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
threepid: args.depend::<threepid::Service>("threepid"),
},
db: Data {
keychangeid_userid: args.db["keychangeid_userid"].clone(),
@@ -131,6 +166,7 @@ impl crate::Service for Service {
userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(),
},
}))
}
@@ -192,12 +228,239 @@ impl Service {
Ok(())
}
// /// Create a new account for a local human or bot user.
// pub async fn create_local_account(
// &self,
// username: String,
// password:
// )
/// Create a new account for a local human or bot user.
pub async fn create_local_account(
&self,
user_id: &UserId,
password: HashedPassword,
email: Option<Address>,
) {
self.create(user_id, Some(password))
.await
.expect("should be able to save a new local user. what happened?");
// Set an initial display name
{
let mut displayname = user_id.localpart().to_owned();
let suffix = &self.services.config.new_user_displayname_suffix;
if !suffix.is_empty() {
displayname.push(' ');
displayname.push_str(suffix);
}
self.set_displayname(user_id, Some(displayname));
};
// Set default push rules
self.services
.account_data
.update(
None,
user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(PushRulesEvent::new(
Ruleset::server_default(user_id).into(),
))
.expect("should be able to serialize push rules"),
)
.await
.expect("should be able to update account data");
// If the user registered with an email, associate it with their account.
if let Some(email) = email {
// This may fail if the email is already in use, but we should have already
// checked that when we sent the validation email, so ignoring the error is
// acceptable here in the rare case that an email is sniped by another user
// between the validation email being sent and the account being created.
let _ = self
.services
.threepid
.associate_localpart_email(user_id.localpart(), &email)
.await;
}
// Attempt to empower the first user and disable first-run mode.
let was_first_user = self.services.firstrun.empower_first_user(user_id).await;
// If the registering user was not the first and we're suspending users on
// register, suspend them.
if !was_first_user && self.services.config.suspend_on_register {
// Note that we can still do auto joins for suspended users
self.suspend_account(user_id, &self.services.globals.server_user)
.await;
// And send an @room notice to the admin room, to prompt admins to review the
// new user and ideally unsuspend them if deemed appropriate.
if self.services.config.admin_room_notices {
self.services
.admin
.send_loud_message(RoomMessageEventContent::text_plain(format!(
"User {user_id} has been suspended as they are not the first user on \
this server. Please review and unsuspend them if appropriate."
)))
.await
.ok();
}
}
// Autojoin the user to the configured autojoin rooms
for room in &self.services.config.auto_join_rooms {
let Ok(room_id) = self.services.alias.resolve(room).await else {
error!(
"Failed to resolve room alias to room ID when attempting to auto join \
{room}, skipping"
);
continue;
};
if !self
.services
.state_cache
.server_in_room(self.services.globals.server_name(), &room_id)
.await
{
warn!(
"Skipping room {room} to automatically join as we have never joined before."
);
continue;
}
if let Some(room_server_name) = room.server_name() {
match self
.services
.membership
.join_room(
user_id,
&room_id,
Some("Automatically joining this room upon registration".to_owned()),
&[
self.services.globals.server_name().to_owned(),
room_server_name.to_owned(),
],
)
.boxed()
.await
{
| Err(e) => {
// don't return this error so we don't fail registrations
error!(
"Failed to automatically join room {room} for user {user_id}: {e}"
);
},
| _ => {
info!("Automatically joined room {room} for user {user_id}");
},
}
}
}
info!("Created new user account for {user_id}");
}
pub async fn determine_registration_user_id(
&self,
supplied_username: Option<String>,
email: Option<&Address>,
appservice_info: Option<&RegistrationInfo>,
) -> Result<OwnedUserId> {
const RANDOM_USER_ID_LENGTH: usize = 10;
let emergency_mode_enabled = self.services.config.emergency_password.is_some();
let supplied_username = supplied_username.or_else(|| {
// If the user didn't supply a username but did supply an email, use
// the email's user part to avoid falling back to a random username
email.map(|address| address.user().to_owned())
});
if let Some(supplied_username) = supplied_username {
// The user gets to pick their username. Do some validation to make sure it's
// acceptable.
// Don't allow registration with forbidden usernames.
if self
.services
.globals
.forbidden_usernames()
.is_match(&supplied_username)
&& !emergency_mode_enabled
{
return Err!(Request(Forbidden("Username is forbidden")));
}
// Create and validate the user ID
let user_id = match UserId::parse_with_server_name(
&supplied_username,
self.services.globals.server_name(),
) {
| Ok(user_id) => {
if let Err(e) = user_id.validate_strict() {
// Unless we are in emergency mode, we should follow synapse's behaviour
// on not allowing things like spaces and UTF-8 characters in
// usernames
if !emergency_mode_enabled {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} contains disallowed characters or \
spaces: {e}"
))));
}
}
// Don't allow registration with user IDs that aren't local
if !self.services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {supplied_username} is not local to this server"
)));
}
user_id
},
| Err(e) => {
return Err!(Request(InvalidUsername(debug_warn!(
"Username {supplied_username} is not valid: {e}"
))));
},
};
if self.exists(&user_id).await {
return Err!(Request(UserInUse("User ID is not available.")));
}
// Check that the user ID is/is not in an appservice's namespace
if let Some(appservice_info) = appservice_info {
if !appservice_info.is_user_match(&user_id) && !emergency_mode_enabled {
return Err!(Request(Exclusive(
"Username is not in this appservice's namespace."
)));
}
} else if self
.services
.appservice
.is_exclusive_user_id(&user_id)
.await && !emergency_mode_enabled
{
return Err!(Request(Exclusive("Username is reserved by an appservice.")));
}
Ok(user_id)
} else {
// The user didn't specify a username. Generate a username for
// them.
loop {
let user_id = UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
self.services.globals.server_name(),
)
.unwrap();
if !self.exists(&user_id).await {
break Ok(user_id);
}
}
}
}
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
@@ -337,8 +600,42 @@ impl Service {
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
/// Find out which user an access token belongs to.
pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
self.db.token_userdeviceid.get(token).await.deserialized()
pub async fn find_from_token(
&self,
token: &str,
) -> Option<(OwnedUserId, OwnedDeviceId, AccessTokenStatus)> {
let user = self
.db
.token_userdeviceid
.get(token)
.await
.deserialized()
.ok();
// Check if the token has expired
if let Some((user_id, device_id)) = user {
if let Some(expires) = self
.db
.userdeviceid_tokenexpires
.qry(&(&user_id, &device_id))
.await
.deserialized::<u64>()
.ok()
.map(Duration::from_secs)
{
let expires_at = SystemTime::UNIX_EPOCH
.checked_add(expires)
.expect("expiry time should not overflow SystemTime");
if SystemTime::now() > expires_at {
return Some((user_id, device_id, AccessTokenStatus::Expired));
}
}
Some((user_id, device_id, AccessTokenStatus::Valid))
} else {
None
}
}
/// Returns an iterator over all users on this homeserver.
@@ -434,6 +731,7 @@ impl Service {
user_id: &UserId,
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
initial_device_display_name: Option<String>,
client_ip: Option<String>,
) -> Result<()> {
@@ -451,7 +749,8 @@ impl Service {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.put(key, Json(device));
self.set_token(user_id, device_id, token).await
self.set_token(user_id, device_id, token, token_max_age)
.await
}
/// Removes a device from a user.
@@ -467,6 +766,7 @@ impl Service {
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
self.db.userdeviceid_token.del(userdeviceid);
self.db.token_userdeviceid.remove(&old_token);
self.db.userdeviceid_tokenexpires.del(userdeviceid);
}
// Remove todevice events
@@ -480,6 +780,9 @@ impl Service {
// TODO: Remove onetimekeys
// Remove OAuth session information
self.services.oauth.remove_session(user_id, device_id).await;
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.del(userdeviceid);
@@ -535,6 +838,7 @@ impl Service {
user_id: &UserId,
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
) -> Result<()> {
let key = (user_id, device_id);
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
@@ -561,6 +865,7 @@ impl Service {
// Remove old token
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
self.db.token_userdeviceid.remove(&old_token);
self.db.userdeviceid_tokenexpires.remove(&old_token);
// It will be removed from userdeviceid_token by the insert later
}
@@ -568,6 +873,18 @@ impl Service {
self.db.userdeviceid_token.put_raw(key, token);
self.db.token_userdeviceid.raw_put(token, key);
if let Some(max_age) = token_max_age {
let expires = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("system time should not be before the epoch")
.saturating_add(max_age)
.as_secs();
self.db.userdeviceid_tokenexpires.put(key, expires);
} else {
self.db.userdeviceid_tokenexpires.del(key);
}
Ok(())
}
@@ -1238,7 +1555,7 @@ impl Service {
pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
use std::num::Saturating as Sat;
let expires_in = self.services.server.config.openid_token_ttl;
let expires_in = self.services.config.openid_token_ttl;
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
let mut value = expires_at.0.to_be_bytes().to_vec();
@@ -1282,7 +1599,7 @@ impl Service {
pub fn create_login_token(&self, user_id: &UserId, token: &str) -> u64 {
use std::num::Saturating as Sat;
let expires_in = self.services.server.config.login_token_ttl;
let expires_in = self.services.config.login_token_ttl;
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in);
let value = (expires_at.0, user_id);
+11
View File
@@ -22,6 +22,8 @@ crate-type = [
conduwuit-build-metadata.workspace = true
conduwuit-service.workspace = true
conduwuit-core.workspace = true
conduwuit-database.workspace = true
conduwuit-api.workspace = true
async-trait.workspace = true
askama.workspace = true
axum.workspace = true
@@ -35,9 +37,18 @@ ruma.workspace = true
thiserror.workspace = true
tower-http.workspace = true
serde.workspace = true
serde_json.workspace = true
lettre.workspace = true
memory-serve = "2.1.0"
validator = { version = "0.20.0", features = ["derive"] }
tower-sec-fetch = { version = "0.1.2", features = ["tracing"] }
tower-sessions = { version = "0.15.0", default-features = false, features = ["axum-core"] }
tower-sessions-core = { version = "0.15.0", features = ["deletion-task"] }
serde_urlencoded.workspace = true
url.workspace = true
recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
form_urlencoded = "1.2.2"
[build-dependencies]
memory-serve = "2.1.0"
+48
View File
@@ -0,0 +1,48 @@
use axum::{
extract::{FromRequest, FromRequestParts, Request},
http::{Method, request::Parts},
};
use serde::de::DeserializeOwned;
use crate::WebError;
/// An extractor which deserializes a struct from a POST request's body.
/// For GET requests the struct will be None.
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub(crate) struct PostForm<T>(pub Option<T>);
impl<T, S> FromRequest<S> for PostForm<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = WebError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::POST {
let axum::Form(data) = axum::Form::from_request(req, state).await?;
Ok(Self(Some(data)))
} else {
Ok(Self(None))
}
}
}
/// An extractor which wraps another extractor and converts its errors into
/// `WebError`s.
pub(crate) struct Expect<E>(pub E);
impl<E, S, R> FromRequestParts<S> for Expect<E>
where
E: FromRequestParts<S, Rejection = R>,
WebError: From<R>,
S: Send + Sync,
{
type Rejection = WebError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Ok(Self(E::from_request_parts(parts, state).await?))
}
}
+61 -15
View File
@@ -1,25 +1,34 @@
use std::any::Any;
use std::{any::Any, sync::Once, time::Duration};
use askama::Template;
use axum::{
Router,
extract::rejection::{FormRejection, QueryRejection},
http::{HeaderValue, StatusCode, header},
response::{Html, IntoResponse, Response},
extract::rejection::{FormRejection, PathRejection, QueryRejection},
http::StatusCode,
middleware::from_fn_with_state,
response::{Html, IntoResponse, Redirect, Response},
};
use conduwuit_service::state;
use tower_http::{catch_panic::CatchPanicLayer, set_header::SetResponseHeaderLayer};
use conduwuit_service::{Services, state};
use tower_http::catch_panic::CatchPanicLayer;
use tower_sec_fetch::SecFetchLayer;
use tower_sessions::{ExpiredDeletion, SessionManagerLayer, cookie::SameSite};
use crate::pages::TemplateContext;
use crate::{
pages::TemplateContext,
session::{LoginQuery, store::RocksDbSessionStore},
};
mod extract;
mod pages;
mod session;
type State = state::State;
const CATASTROPHIC_FAILURE: &str = "cat-astrophic failure! we couldn't even render the error template. \
please contact the team @ https://continuwuity.org";
const ROUTE_PREFIX: &str = conduwuit_core::ROUTE_PREFIX;
#[derive(Debug, thiserror::Error)]
enum WebError {
#[error("Failed to validate form body: {0}")]
@@ -29,10 +38,16 @@ enum WebError {
#[error("{0}")]
FormRejection(#[from] FormRejection),
#[error("{0}")]
PathRejection(#[from] PathRejection),
#[error("{0}")]
BadRequest(String),
#[error("This page does not exist.")]
NotFound,
#[error("You are not allowed to request this page: {0}")]
Forbidden(String),
#[error("You must log in to access this page")]
LoginRequired(LoginQuery),
#[error("Failed to render template: {0}")]
Render(#[from] askama::Error),
@@ -52,12 +67,26 @@ impl IntoResponse for WebError {
context: TemplateContext,
}
if let Self::LoginRequired(query) = self {
return Redirect::to(&format!(
"{}/account/login?{}",
ROUTE_PREFIX,
serde_urlencoded::to_string(query).unwrap()
))
.into_response();
}
let status = match &self {
| Self::ValidationError(_)
| Self::BadRequest(_)
| Self::QueryRejection(_)
| Self::FormRejection(_) => StatusCode::BAD_REQUEST,
| Self::FormRejection(_)
| Self::InternalError(_) => StatusCode::BAD_REQUEST,
| Self::NotFound => StatusCode::NOT_FOUND,
| Self::Forbidden(_) => StatusCode::FORBIDDEN,
| Self::LoginRequired(_) => {
unreachable!("LoginRequired is handled earlier")
},
| _ => StatusCode::INTERNAL_SERVER_ERROR,
};
@@ -67,6 +96,7 @@ impl IntoResponse for WebError {
context: TemplateContext {
// Statically set false to prevent error pages from being indexed.
allow_indexing: false,
csp_nonce: String::new(),
},
};
@@ -78,21 +108,40 @@ impl IntoResponse for WebError {
}
}
pub fn build() -> Router<state::State> {
static STORE_CLEANUP_TASK: Once = Once::new();
pub fn build(services: &Services) -> Router<state::State> {
#[allow(clippy::wildcard_imports)]
use pages::*;
let store = RocksDbSessionStore::new(&services.db);
STORE_CLEANUP_TASK.call_once(|| {
services.server.runtime().spawn(
store
.clone()
.continuously_delete_expired(Duration::from_hours(1)),
);
});
Router::new()
.merge(index::build())
.nest(
"/_continuwuity/",
Router::new()
.merge(resources::build())
.merge(password_reset::build())
.nest("/about", about::build())
.nest("/account/", account::build())
.merge(debug::build())
.nest("/oauth2/", oauth::build())
.merge(resources::build())
.merge(threepid::build())
.fallback(async || WebError::NotFound),
)
.layer(
SessionManagerLayer::new(store)
.with_name("_c10y_session")
.with_same_site(SameSite::Lax),
)
.layer(CatchPanicLayer::custom(|panic: Box<dyn Any + Send + 'static>| {
let details = if let Some(s) = panic.downcast_ref::<String>() {
s.clone()
@@ -104,10 +153,7 @@ pub fn build() -> Router<state::State> {
WebError::Panic(details).into_response()
}))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'; img-src 'self' data:;"),
))
.layer(from_fn_with_state(services.config.clone(), template_context_middleware))
.layer(SecFetchLayer::new(|policy| {
policy.allow_safe_methods().reject_missing_metadata();
}))
+38
View File
@@ -0,0 +1,38 @@
use std::collections::BTreeMap;
use axum::{Extension, Router, extract::State, routing::get};
use conduwuit_core::config::TermsDocument;
use ruma::{
OwnedServerName,
api::client::discovery::discover_support::{Contact, ContactRole},
};
use url::Url;
use crate::{
pages::{Result, TemplateContext},
response, template,
};
pub(crate) fn build() -> Router<crate::State> { Router::new().route("/", get(get_about)) }
template! {
struct About use "about.html.j2" {
server_name: OwnedServerName,
support_page: Option<Url>,
contacts: Vec<Contact>,
terms: BTreeMap<String, TermsDocument>
}
}
async fn get_about(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
) -> Result {
response!(About::new(
context,
services.globals.server_name().to_owned(),
services.config.well_known.support_page.clone(),
services.admin.get_support_contacts().await,
services.config.registration_terms.documents.clone()
))
}
@@ -0,0 +1,47 @@
use axum::{Extension, Router, extract::State, routing::on};
use conduwuit_service::oauth::OAuthTicket;
use crate::{
extract::PostForm,
pages::{GET_POST, Result, TemplateContext, components::UserCard},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_cross_signing_reset))
}
template! {
struct CrossSigningReset use "cross_signing_reset.html.j2" {
user_card: UserCard,
body: CrossSigningResetBody
}
}
#[derive(Debug)]
enum CrossSigningResetBody {
Form,
Success,
}
async fn route_cross_signing_reset(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::CrossSigningReset)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if form.is_some() {
services
.oauth
.issue_ticket(user_id.localpart().to_owned(), OAuthTicket::CrossSigningReset);
response!(CrossSigningReset::new(context, user_card, CrossSigningResetBody::Success))
} else {
response!(CrossSigningReset::new(context, user_card, CrossSigningResetBody::Form))
}
}
+129
View File
@@ -0,0 +1,129 @@
use axum::{Extension, Router, extract::State, routing::on};
use conduwuit_api::client::full_user_deactivate;
use futures::StreamExt;
use ruma::{OwnedRoomId, OwnedUserId, UserId};
use tower_sessions::Session;
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
extract::PostForm,
form,
pages::{
GET_POST, Result, TemplateContext,
components::{UserCard, form::Form},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_deactivate))
}
template! {
struct Deactivate use "deactivate.html.j2" {
body: DeactivateBody
}
}
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum DeactivateBody {
Unavailable,
Form {
user_id: OwnedUserId,
user_card: UserCard,
form: Form<'static>,
},
Success,
}
form! {
struct DeactivateForm {
password: String where {
input_type: "password",
label: "Enter your password to confirm",
autocomplete: "current-password"
},
#[validate(required(message = "This checkbox must be checked"))]
confirm: Option<String> where {
input_type: "checkbox",
label: "I understand that deactivating my account cannot be undone."
}
submit: "Deactivate my account",
slowdown: true
}
}
async fn route_deactivate(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
session: Session,
PostForm(form): PostForm<DeactivateForm>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::Deactivate)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let body = {
if !services.config.allow_deactivation {
DeactivateBody::Unavailable
} else if let Some(form) = form {
if let Err(err) = validate_deactivate_form(&services, &user_id, form).await {
DeactivateBody::Form {
user_id,
user_card,
form: DeactivateForm::with_errors(context.clone(), err),
}
} else {
let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms
.state_cache
.rooms_joined(&user_id)
.collect()
.await;
full_user_deactivate(&services, &user_id, &all_joined_rooms).await?;
session.clear().await;
DeactivateBody::Success
}
} else {
DeactivateBody::Form {
user_id,
user_card,
form: DeactivateForm::build(context.clone()),
}
}
};
response!(Deactivate::new(context, body))
}
async fn validate_deactivate_form(
services: &crate::State,
user_id: &UserId,
form: DeactivateForm,
) -> Result<(), ValidationErrors> {
form.validate()?;
if services
.users
.check_password(user_id, &form.password)
.await
.is_err()
{
let mut errors = ValidationErrors::new();
errors.add(
"password",
ValidationError::new("wrong").with_message("Incorrect password".into()),
);
return Err(errors);
}
Ok(())
}
+126
View File
@@ -0,0 +1,126 @@
use axum::{
Extension, Router,
extract::{Path, State},
routing::{get, on},
};
use conduwuit_service::oauth::{SessionInfo, client_metadata::ClientMetadata};
use futures::StreamExt;
use ruma::OwnedDeviceId;
use serde::{Deserialize, Serialize};
use crate::{
WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
components::{ClientScopes, DeviceCard, DeviceCardStyle},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/{device}/", get(get_device_info))
.route("/{device}/remove", on(GET_POST, route_remove_device))
}
template! {
struct DeviceInfo use "device_info.html.j2" {
device_card: DeviceCard,
client_metadata: Option<(ClientMetadata, SessionInfo)>
}
}
async fn get_device_info(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
Expect(Path(query)): Expect<Path<DevicePath>>,
) -> Result {
let user_id = user.expect(LoginTarget::RemoveDevice(query.clone()))?;
let Ok(device) = services
.users
.get_device_metadata(&user_id, &query.device)
.await
else {
return response!(WebError::BadRequest("Unknown device".to_owned()));
};
let client_metadata = async {
let session_info = services
.oauth
.get_session_info_for_device(&user_id, &device.device_id)
.await?;
let client_metadata = services
.oauth
.get_client_metadata(&session_info.client_id)
.await?;
Some((client_metadata, session_info))
}
.await;
let device_card =
DeviceCard::for_device(&services, &user_id, device, DeviceCardStyle::Detailed).await;
response!(DeviceInfo::new(context, device_card, client_metadata))
}
template! {
struct RemoveDevice use "remove_device.html.j2" {
body: RemoveDeviceBody
}
}
#[derive(Debug)]
enum RemoveDeviceBody {
Form {
device_card: Box<DeviceCard>,
last_device: bool,
},
Success,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct DevicePath {
pub device: OwnedDeviceId,
}
async fn route_remove_device(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
Expect(Path(query)): Expect<Path<DevicePath>>,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::RemoveDevice(query.clone()))?;
let Ok(device) = services
.users
.get_device_metadata(&user_id, &query.device)
.await
else {
return response!(WebError::BadRequest("Unknown device".to_owned()));
};
if form.is_some() {
services
.users
.remove_device(&user_id, &device.device_id)
.await;
response!(RemoveDevice::new(context, RemoveDeviceBody::Success))
} else {
let device_card =
DeviceCard::for_device(&services, &user_id, device, DeviceCardStyle::Minimal).await;
let last_device = services.users.all_devices_metadata(&user_id).count().await <= 1;
response!(RemoveDevice::new(context, RemoveDeviceBody::Form {
device_card: Box::new(device_card),
last_device
}))
}
}
+210
View File
@@ -0,0 +1,210 @@
use axum::{
Extension, Router,
extract::{Query, State},
routing::{get, on, post},
};
use conduwuit_core::warn;
use conduwuit_service::{mailer::messages, threepid::session::ValidationSessions};
use lettre::{Address, message::Mailbox};
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId};
use serde::{Deserialize, Serialize};
use crate::{
WebError,
extract::{Expect, PostForm},
form,
pages::{
GET_POST, Result, TemplateContext,
account::ThreepidQuery,
components::{UserCard, form::Form},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/change/", on(GET_POST, route_change_email))
.route("/change/validate", get(get_change_email_validate))
.route("/change/delete", post(post_delete_email))
}
template! {
struct ChangeEmail use "change_email.html.j2" {
user_card: UserCard,
email: Option<String>,
form: Form<'static>,
may_remove: bool
}
}
form! {
struct ChangeEmailForm {
email: Address where {
input_type: "email",
label: "Email address"
}
submit: "Change email"
}
}
template! {
struct ChangeEmailValidate use "change_email_validate.html.j2" {
user_card: UserCard,
body: ChangeEmailValidateBody
}
}
template! {
struct DeleteEmail use "delete_email.html.j2" {
user_card: UserCard
}
}
#[derive(Debug)]
enum ChangeEmailValidateBody {
ValidationPending {
session_id: OwnedSessionId,
client_secret: OwnedClientSecret,
validation_error: bool,
},
Success,
}
async fn route_change_email(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
PostForm(form): PostForm<ChangeEmailForm>,
) -> Result {
let user_id = user.expect_recent(LoginTarget::ChangeEmail)?;
let Some(form) = form else {
return response!(ChangeEmail::new(
context.clone(),
UserCard::for_local_user(&services, user_id.clone()).await,
services
.threepid
.get_email_for_localpart(user_id.localpart())
.await
.map(|address| address.to_string()),
ChangeEmailForm::build(context),
services.threepid.email_requirement().may_remove(),
));
};
let client_secret = ClientSecret::new();
let session_id = {
let display_name = services.users.displayname(&user_id).await.ok();
match services
.threepid
.send_validation_email(
Mailbox::new(display_name, form.email.clone()),
|verification_link| messages::ChangeEmail {
server_name: services.globals.server_name().as_str(),
user_id: Some(&user_id),
verification_link,
},
&client_secret,
0,
)
.await
{
| Ok(session_id) => session_id,
| Err(err) => {
// If we couldn't send an email, generate a random session ID to not give that
// away
warn!(
"Failed to send email change message for {user_id} to {}: {err}",
form.email
);
ValidationSessions::generate_session_id()
},
}
};
response!(ChangeEmailValidate::new(
context,
UserCard::for_local_user(&services, user_id).await,
ChangeEmailValidateBody::ValidationPending {
session_id,
client_secret,
validation_error: false
}
))
}
#[derive(Deserialize, Serialize)]
struct ChangeEmailQuery {
#[serde(flatten)]
threepid: ThreepidQuery,
}
async fn get_change_email_validate(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(ChangeEmailQuery {
threepid: ThreepidQuery { client_secret, session_id },
})): Expect<Query<ChangeEmailQuery>>,
user: User,
) -> Result {
let user_id = user.expect(LoginTarget::ChangeEmail)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if !services.threepid.email_requirement().may_change() {
return Err(WebError::Forbidden("You may not change your email address.".to_owned()));
}
let Ok(session) = services
.threepid
.get_valid_session(&session_id, &client_secret)
.await
else {
return response!(ChangeEmailValidate::new(
context,
user_card,
ChangeEmailValidateBody::ValidationPending {
session_id,
client_secret,
validation_error: true
}
));
};
let new_email = session.consume();
if let Err(err) = services
.threepid
.associate_localpart_email(user_id.localpart(), &new_email)
.await
{
return response!(BadRequest(err.message()));
}
response!(ChangeEmailValidate::new(context, user_card, ChangeEmailValidateBody::Success))
}
async fn post_delete_email(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
) -> Result {
let user_id = user.expect(LoginTarget::ChangeEmail)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if !services.threepid.email_requirement().may_remove() {
return Err(WebError::Forbidden("You may not remove your email address.".to_owned()));
}
let _ = services
.threepid
.disassociate_localpart_email(user_id.localpart())
.await;
response!(DeleteEmail::new(context, user_card))
}
+155
View File
@@ -0,0 +1,155 @@
use std::time::SystemTime;
use axum::{
Extension, Router,
extract::{Query, RawQuery, State},
response::{IntoResponse, Redirect},
routing::{get, on},
};
use conduwuit_api::client::handle_login;
use ruma::{
OwnedUserId,
api::client::uiaa::{EmailUserIdentifier, MatrixUserIdentifier, UserIdentifier},
};
use serde::Deserialize;
use tower_sessions::Session;
use crate::{
ROUTE_PREFIX, WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
account::register::{TrustedFlowStatus, UntrustedFlowStatus, registration_flow_status},
components::UserCard,
},
response,
session::{LoginQuery, LoginTarget, User, UserSession},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/login", on(GET_POST, route_login))
.route("/logout", get(get_logout))
}
template! {
struct Login use "login.html.j2" {
body: LoginBody,
login_error: Option<String>
}
}
#[derive(Debug)]
enum LoginBody {
Unauthenticated {
server_name: String,
registration_available: bool,
next: Option<LoginTarget>,
},
Authenticated {
user_card: UserCard,
},
}
#[derive(Deserialize)]
struct LoginForm {
identifier: Option<String>,
password: String,
}
async fn route_login(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(LoginQuery { next, reauthenticate })): Expect<Query<LoginQuery>>,
session_store: Session,
user: User<true>,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let user_id = user.into_session().map(|session| session.user_id);
let body = match &user_id {
| None => {
let (trusted_flow_status, untrusted_flow_status) =
registration_flow_status(&services).await;
let registration_available =
matches!(trusted_flow_status, TrustedFlowStatus::Available)
|| matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. });
LoginBody::Unauthenticated {
server_name: services.globals.server_name().to_string(),
registration_available,
next: next.clone(),
}
},
| Some(user_id) => {
if !reauthenticate {
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
}
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
LoginBody::Authenticated { user_card }
},
};
let mut template = Login::new(context, body, None);
if let Some(form) = form {
let login_result = match (user_id, form.identifier) {
| (Some(user_id), _) => {
// The user is already authenticated, we need to check their password
services
.users
.check_password(&user_id, &form.password)
.await
},
| (None, Some(identifier)) => {
// The user isn't authenticated, we need to log them in
let identifier = if identifier.parse::<lettre::Address>().is_ok() {
UserIdentifier::Email(EmailUserIdentifier::new(identifier))
} else {
UserIdentifier::Matrix(MatrixUserIdentifier::new(identifier))
};
handle_login(&services, Some(&identifier), &form.password, None).await
},
| (None, None) => {
// The user isn't authenticated and didn't supply an identity
return response!(WebError::BadRequest("No identity provided".to_owned()));
},
};
let user_id = match login_result {
| Ok(user_id) => user_id,
| Err(err) => {
let error_message = if let conduwuit_core::Error::Request(_, message, _) = err {
message.into_owned()
} else {
"Internal login error".to_owned()
};
template.login_error = Some(error_message);
return response!(template);
},
};
let user_session = UserSession { user_id, last_login: SystemTime::now() };
session_store
.insert(User::KEY, user_session)
.await
.expect("should be able to serialize user session");
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
}
response!(template)
}
async fn get_logout(session: Session, RawQuery(query): RawQuery) -> impl IntoResponse {
let _ = session.remove::<OwnedUserId>(User::KEY).await;
Redirect::to(&format!("{}/account/login?{}", ROUTE_PREFIX, query.unwrap_or_default()))
}
+173
View File
@@ -0,0 +1,173 @@
use axum::{
Extension, Router,
extract::{Query, State},
response::Redirect,
routing::get,
};
use conduwuit_core::utils::{IterStream, ReadyExt, stream::TryExpect};
use conduwuit_service::threepid::EmailRequirement;
use futures::StreamExt;
use ruma::{
OwnedClientSecret, OwnedDeviceId, OwnedSessionId,
api::client::discovery::get_authorization_server_metadata::v1::AccountManagementAction,
};
use serde::{Deserialize, Serialize};
use crate::{
WebError,
extract::Expect,
pages::{
Result, TemplateContext,
components::{DeviceCard, DeviceCardStyle, UserCard},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) mod cross_signing_reset;
pub(crate) mod deactivate;
pub(crate) mod device;
pub(crate) mod email;
pub(crate) mod login;
pub(crate) mod password;
pub(crate) mod register;
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new()
.route("/", get(get_account))
.route("/deeplink", get(get_account_deeplink))
.merge(login::build())
.nest("/password/", password::build())
.nest("/email/", email::build())
.nest("/cross_signing_reset", cross_signing_reset::build())
.nest("/deactivate", deactivate::build())
.nest("/device/", device::build())
.nest("/register/", register::build())
}
#[derive(Deserialize, Serialize)]
struct ThreepidQuery {
client_secret: OwnedClientSecret,
session_id: OwnedSessionId,
}
template! {
struct Account use "account.html.j2" {
user_card: UserCard,
body: AccountBody
}
}
#[derive(Debug)]
enum AccountBody {
Unlocked {
suspended: bool,
email_requirement: EmailRequirement,
email: Option<String>,
devices: Vec<DeviceCard>,
},
Locked,
}
async fn get_account(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User<true>,
) -> Result {
let user_id = user.expect(LoginTarget::Account)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if services.users.is_locked(&user_id).await.unwrap() {
return response!(Account::new(context, user_card, AccountBody::Locked));
}
let email_requirement = services.threepid.email_requirement();
let email = services
.threepid
.get_email_for_localpart(user_id.localpart())
.await
.map(|address| address.to_string());
let dehydrated_device_id = services.users.get_dehydrated_device_id(&user_id).await.ok();
let mut devices: Vec<_> = services
.users
.all_device_ids(&user_id)
.then(async |device_id| {
services
.users
.get_device_metadata(&user_id, &device_id)
.await
})
.expect_ok()
.ready_filter(|device| {
dehydrated_device_id
.as_ref()
.is_none_or(|id| device.device_id != *id)
})
.collect()
.await;
devices.sort_unstable_by(|a, b| a.last_seen_ts.cmp(&b.last_seen_ts).reverse());
let device_cards = devices
.into_iter()
.stream()
.then(async |device| {
DeviceCard::for_device(&services, &user_id, device, DeviceCardStyle::Minimal).await
})
.collect()
.await;
let suspended = services.users.is_suspended(&user_id).await.unwrap();
response!(Account::new(context, user_card, AccountBody::Unlocked {
suspended,
email_requirement,
email,
devices: device_cards
}))
}
#[derive(Deserialize)]
struct AccountDeeplinkQuery {
action: Option<AccountManagementAction>,
device_id: Option<OwnedDeviceId>,
}
async fn get_account_deeplink(
Expect(Query(query)): Expect<Query<AccountDeeplinkQuery>>,
) -> Result {
let redirect_target = match query.action.unwrap_or(AccountManagementAction::Profile) {
| AccountManagementAction::AccountDeactivate => "deactivate".to_owned(),
| AccountManagementAction::CrossSigningReset => "cross_signing_reset".to_owned(),
| AccountManagementAction::DeviceDelete => {
let Some(device_id) = query.device_id else {
return response!(WebError::BadRequest(
"A device ID is required for this action".to_owned()
));
};
format!("device/{device_id}/delete")
},
| AccountManagementAction::DeviceView => {
let Some(device_id) = query.device_id else {
return response!(WebError::BadRequest(
"A device ID is required for this action".to_owned()
));
};
format!("device/{device_id}/")
},
| AccountManagementAction::DevicesList => "#devices".to_owned(),
| AccountManagementAction::Profile => String::new(),
| _ => return response!(WebError::BadRequest("Unknown action".to_owned())),
};
response!(Redirect::to(&format!("{}/account/{}", crate::ROUTE_PREFIX, redirect_target)))
}
+122
View File
@@ -0,0 +1,122 @@
use axum::{Extension, Router, extract::State, routing::on};
use conduwuit_service::users::HashedPassword;
use ruma::UserId;
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
extract::PostForm,
form,
pages::{
GET_POST, Result, TemplateContext,
components::{UserCard, form::Form},
},
response,
session::{LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/", on(GET_POST, route_change_password))
}
template! {
struct ChangePassword use "change_password.html.j2" {
user_card: UserCard,
body: ChangePasswordBody
}
}
#[derive(Debug)]
enum ChangePasswordBody {
Form(Form<'static>),
Success,
}
form! {
struct ChangePasswordForm {
#[validate(length(min = 1, message = "Current password cannot be empty"))]
current_password: String where {
input_type: "password",
label: "Current password",
autocomplete: "current-password"
},
#[validate(length(min = 1, message = "New password cannot be empty"))]
new_password: String where {
input_type: "password",
label: "New password",
autocomplete: "new-password"
},
#[validate(must_match(other = "new_password", message = "Passwords must match"))]
confirm_new_password: String where {
input_type: "password",
label: "Confirm new password",
autocomplete: "new-password"
}
submit: "Change password"
}
}
async fn route_change_password(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User,
PostForm(form): PostForm<ChangePasswordForm>,
) -> Result {
let user_id = user.expect(LoginTarget::ChangePassword)?;
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let body = if let Some(form) = form {
match change_password(&services, &user_id, form).await {
| Ok(()) => ChangePasswordBody::Success,
| Err(errors) =>
ChangePasswordBody::Form(ChangePasswordForm::with_errors(context.clone(), errors)),
}
} else {
ChangePasswordBody::Form(ChangePasswordForm::build(context.clone()))
};
response!(ChangePassword::new(context, user_card, body))
}
async fn change_password(
services: &crate::State,
user_id: &UserId,
form: ChangePasswordForm,
) -> Result<(), ValidationErrors> {
form.validate()?;
if services
.users
.check_password(user_id, &form.current_password)
.await
.is_err()
{
let mut errors = ValidationErrors::new();
errors.add(
"current_password",
ValidationError::new("wrong").with_message("Incorrect password".into()),
);
return Err(errors);
}
match HashedPassword::new(&form.new_password) {
| Ok(hash) => {
services.users.set_password(user_id, Some(hash));
},
| Err(err) => {
let mut errors = ValidationErrors::new();
errors.add(
"new_password",
ValidationError::new("malformed").with_message(err.message().into()),
);
return Err(errors);
},
}
Ok(())
}
+13
View File
@@ -0,0 +1,13 @@
use axum::Router;
mod change;
mod reset;
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new()
.nest("/change", change::build())
.nest("/reset/", reset::build())
}
+252
View File
@@ -0,0 +1,252 @@
use axum::{
Extension, Router,
extract::{Query, State},
routing::on,
};
use conduwuit_core::warn;
use conduwuit_service::{
mailer::messages, threepid::session::ValidationSessions, users::HashedPassword,
};
use lettre::{Address, message::Mailbox};
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, UserId};
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
WebError,
extract::{Expect, PostForm},
form,
pages::{
GET_POST, Result, TemplateContext,
account::ThreepidQuery,
components::{UserCard, form::Form},
},
response,
session::require_active,
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/", on(GET_POST, route_reset_password))
.route("/validate", on(GET_POST, route_reset_password_validate))
}
template! {
struct ResetPassword use "reset_password.html.j2" {
body: ResetPasswordBody
}
}
#[derive(Debug)]
enum ResetPasswordBody {
Form(Form<'static>),
Unavailable,
}
form! {
struct ResetPasswordRequestForm {
email: Address where {
input_type: "email",
label: "Email address"
}
submit: "Send email"
}
}
async fn route_reset_password(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
PostForm(form): PostForm<ResetPasswordRequestForm>,
) -> Result {
// Check if SMTP is configured
if services.mailer.mailer().is_none() {
return response!(ResetPassword::new(context, ResetPasswordBody::Unavailable));
}
let Some(form) = form else {
// For GET requests return the reset request form
return response!(ResetPassword::new(
context.clone(),
ResetPasswordBody::Form(ResetPasswordRequestForm::build(context))
));
};
let client_secret = ClientSecret::new();
let session_id = async {
let Some(localpart) = services.threepid.get_localpart_for_email(&form.email).await else {
warn!("No user is associated with the email address {}", form.email);
return None;
};
let user_id =
UserId::parse(format!("@{localpart}:{}", services.globals.server_name())).unwrap();
let display_name = services.users.displayname(&user_id).await.ok();
match services
.threepid
.send_validation_email(
Mailbox::new(display_name.clone(), form.email.clone()),
|verification_link| messages::PasswordReset {
display_name: display_name.as_deref(),
user_id: &user_id,
verification_link,
},
&client_secret,
0,
)
.await
{
| Ok(session_id) => Some(session_id),
| Err(err) => {
warn!("Failed to send reset email for {localpart} to {}: {err}", form.email);
None
},
}
}
.await
.unwrap_or_else(|| {
// If we couldn't send an email, generate a random session ID to not give that
// away
ValidationSessions::generate_session_id()
});
response!(ResetPasswordValidate::new(
context,
ResetPasswordValidateBody::ValidationPending {
client_secret,
session_id,
validation_error: false
}
))
}
template! {
struct ResetPasswordValidate use "reset_password_validate.html.j2" {
body: ResetPasswordValidateBody
}
}
#[derive(Debug)]
enum ResetPasswordValidateBody {
ValidationPending {
session_id: OwnedSessionId,
client_secret: OwnedClientSecret,
validation_error: bool,
},
ValidationSuccess {
user_card: UserCard,
form: Form<'static>,
},
ResetSuccess {
user_card: UserCard,
},
}
form! {
struct ResetPasswordForm {
#[validate(length(min = 1, message = "Password cannot be empty"))]
new_password: String where {
input_type: "password",
label: "New password",
autocomplete: "new-password"
},
#[validate(must_match(other = "new_password", message = "Passwords must match"))]
confirm_new_password: String where {
input_type: "password",
label: "Confirm new password",
autocomplete: "new-password"
}
submit: "Reset password"
}
}
#[derive(Deserialize, Serialize)]
struct ResetPasswordQuery {
#[serde(flatten)]
threepid: ThreepidQuery,
}
async fn route_reset_password_validate(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
Expect(Query(query)): Expect<Query<ResetPasswordQuery>>,
PostForm(form): PostForm<ResetPasswordForm>,
) -> Result {
let body = match services
.threepid
.get_valid_session(&query.threepid.session_id, &query.threepid.client_secret)
.await
{
| Ok(session) => {
let Some(localpart) = services
.threepid
.get_localpart_for_email(&session.email)
.await
else {
return Err(WebError::BadRequest("Inapplicable threepid session.".to_owned()));
};
let user_id =
UserId::parse(format!("@{localpart}:{}", services.globals.server_name()))
.unwrap();
if let Err(response) = require_active(&services, &user_id, true).await {
return Ok(response);
}
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
if let Some(form) = form {
if let Err(err) = form.validate() {
ResetPasswordValidateBody::ValidationSuccess {
user_card,
form: ResetPasswordForm::with_errors(context.clone(), err),
}
} else {
match HashedPassword::new(&form.new_password) {
| Ok(hash) => {
let _ = session.consume();
services.users.set_password(&user_id, Some(hash));
ResetPasswordValidateBody::ResetSuccess { user_card }
},
| Err(err) => {
let mut errors = ValidationErrors::new();
errors.add(
"new_password",
ValidationError::new("malformed")
.with_message(err.message().into()),
);
ResetPasswordValidateBody::ValidationSuccess {
user_card,
form: ResetPasswordForm::with_errors(context.clone(), errors),
}
},
}
}
} else {
ResetPasswordValidateBody::ValidationSuccess {
user_card,
form: ResetPasswordForm::build(context.clone()),
}
}
},
| Err(_) => ResetPasswordValidateBody::ValidationPending {
session_id: query.threepid.session_id,
client_secret: query.threepid.client_secret,
validation_error: true,
},
};
response!(ResetPasswordValidate::new(context, body))
}
+582
View File
@@ -0,0 +1,582 @@
use std::{collections::BTreeMap, time::SystemTime};
use axum::{
Extension, Router,
extract::{Query, State},
response::{Redirect, Response},
routing::{get, on},
};
use conduwuit_core::{config::TermsDocument, warn};
use conduwuit_service::{
mailer::messages, registration_tokens::ValidToken, users::HashedPassword,
};
use futures::{FutureExt, StreamExt};
use lettre::{Address, message::Mailbox};
use ruma::{ClientSecret, OwnedClientSecret, OwnedServerName, OwnedSessionId, OwnedUserId};
use serde::{Deserialize, Serialize, de::IgnoredAny};
use tower_sessions::Session;
use validator::{Validate, ValidationError, ValidationErrors};
use crate::{
WebError,
extract::{Expect, PostForm},
pages::{GET_POST, Result, TemplateContext, account::ThreepidQuery},
response,
session::{LoginTarget, User, UserSession},
template,
};
const COMPLETED_REGISTRATION_KEY: &str = "completed_registration";
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/", on(GET_POST, route_register))
.route("/validate", get(get_register_email_validate))
}
template! {
struct Register use "register.html.j2" {
server_name: OwnedServerName,
is_first_run: bool,
body: RegisterBody
}
}
#[derive(Debug)]
enum RegisterBody {
Unavailable,
UsernamePrompt {
allow_federation: bool,
trusted_flow_status: TrustedFlowStatus,
untrusted_flow_status: UntrustedFlowStatus,
username_error: Option<String>,
next: Option<LoginTarget>,
},
DetailsPrompt {
username: Option<String>,
require_email: bool,
flow: RegistrationFlowParameters,
terms: BTreeMap<String, TermsDocument>,
validation_errors: ValidationErrors,
},
}
#[derive(Debug)]
pub(super) enum TrustedFlowStatus {
Unavailable,
Available,
}
#[derive(Debug)]
pub(super) enum UntrustedFlowStatus {
Unavailable,
Available {
require_email: bool,
},
}
#[derive(Default, Deserialize, Serialize)]
pub(crate) struct RegisterQuery {
pub username: Option<String>,
pub token: Option<String>,
pub flow: Option<RequestedRegistrationFlow>,
#[serde(default)]
pub from_landing: bool,
#[serde(flatten)]
pub next: Option<LoginTarget>,
}
#[derive(Clone, Copy, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum RequestedRegistrationFlow {
Untrusted,
Trusted,
}
#[derive(Debug)]
enum RegistrationFlowParameters {
Untrusted {
recaptcha_sitekey: Option<String>,
},
Trusted {
registration_token: Option<String>,
},
}
#[derive(Deserialize, Validate)]
struct RegistrationForm {
flow: RequestedRegistrationFlow,
username: String,
email: Option<Address>,
#[validate(length(min = 1, message = "Password cannot be empty"))]
password: String,
#[validate(must_match(other = "password", message = "Passwords must match"))]
confirm_password: String,
registration_token: Option<String>,
#[serde(rename = "g-recaptcha-response")]
recaptcha_response: Option<String>,
}
#[derive(Deserialize, Serialize)]
struct CompletedRegistration {
user_id: OwnedUserId,
password_hash: HashedPassword,
registration_token: Option<ValidToken>,
next: Option<LoginTarget>,
}
async fn route_register(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
session_store: Session,
Expect(Query(query)): Expect<Query<RegisterQuery>>,
PostForm(form): PostForm<RegistrationForm>,
) -> Result {
if session_store
.get::<IgnoredAny>(User::KEY)
.await
.unwrap()
.is_some()
{
// Redirect already logged-in users to the account panel
return response!(Redirect::to(&LoginTarget::Account.target_path()));
}
let validation_errors = if let Some(form) = form {
match form.validate() {
| Ok(()) => {
match begin_registration(
&services,
context.clone(),
session_store,
form,
query.next.clone(),
)
.boxed()
.await?
{
| Ok(response) => return Ok(response),
| Err(err) => err,
}
},
| Err(err) => err,
}
} else {
ValidationErrors::new()
};
let (trusted_flow_status, untrusted_flow_status) = registration_flow_status(&services).await;
if matches!(trusted_flow_status, TrustedFlowStatus::Unavailable)
&& matches!(untrusted_flow_status, UntrustedFlowStatus::Unavailable)
{
return response!(Register::new(
context,
services.globals.server_name().to_owned(),
services.firstrun.is_first_run(),
RegisterBody::Unavailable
));
}
if query.username.is_some() && query.flow.is_none() {
return response!(WebError::BadRequest(
"A flow must be provided if a username is provided".to_owned()
));
}
if let Some(username) = &query.username
&& query.from_landing
{
// Check if the username is valid and available before showing the details form
// to keep the user from wasting their time
if let Err(err) = services
.users
.determine_registration_user_id(Some(username.to_owned()), None, None)
.await
{
return response!(Register::new(
context,
services.globals.server_name().to_owned(),
services.firstrun.is_first_run(),
RegisterBody::UsernamePrompt {
allow_federation: services.config.allow_federation,
trusted_flow_status,
untrusted_flow_status,
username_error: Some(err.message()),
next: query.next,
}
));
}
}
let body = {
let terms = services.config.registration_terms.documents.clone();
match (query.flow, query.token) {
| (Some(RequestedRegistrationFlow::Trusted), token) | (_, token @ Some(_)) =>
RegisterBody::DetailsPrompt {
username: query.username,
require_email: services
.config
.smtp
.as_ref()
.is_some_and(|smtp| smtp.require_email_for_token_registration),
flow: RegistrationFlowParameters::Trusted { registration_token: token },
terms,
validation_errors,
},
| (Some(RequestedRegistrationFlow::Untrusted), _) => RegisterBody::DetailsPrompt {
username: query.username,
require_email: services
.config
.smtp
.as_ref()
.is_some_and(|smtp| smtp.require_email_for_registration),
flow: RegistrationFlowParameters::Untrusted {
recaptcha_sitekey: services.config.recaptcha_site_key.clone(),
},
terms,
validation_errors,
},
| (None, None) => RegisterBody::UsernamePrompt {
allow_federation: services.config.allow_federation,
trusted_flow_status,
untrusted_flow_status,
username_error: None,
next: query.next,
},
}
};
response!(Register::new(
context,
services.globals.server_name().to_owned(),
services.firstrun.is_first_run(),
body
))
}
template! {
struct RegisterEmailValidate use "register_email_validate.html.j2" {
session_id: OwnedSessionId,
client_secret: OwnedClientSecret,
validation_error: bool
}
}
#[derive(Deserialize, Serialize)]
struct RegisterEmailValidateQuery {
#[serde(flatten)]
threepid: ThreepidQuery,
}
async fn get_register_email_validate(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
session_store: Session,
Expect(Query(RegisterEmailValidateQuery {
threepid: ThreepidQuery { client_secret, session_id },
})): Expect<Query<RegisterEmailValidateQuery>>,
) -> Result {
let Ok(session) = services
.threepid
.get_valid_session(&session_id, &client_secret)
.await
else {
return response!(RegisterEmailValidate::new(context, session_id, client_secret, true));
};
let Some(completed_registration) = session_store
.get::<CompletedRegistration>(COMPLETED_REGISTRATION_KEY)
.await
.expect("should be able to deserialize completed session")
else {
return response!(WebError::BadRequest(
"Inapplicable session. What are you doing here?".to_owned()
));
};
let email = session.consume();
response!(
complete_registration(&services, session_store, completed_registration, Some(email))
.await
)
}
async fn begin_registration(
services: &crate::State,
context: TemplateContext,
session_store: Session,
form: RegistrationForm,
next: Option<LoginTarget>,
) -> Result<Result<Response, ValidationErrors>> {
let open_registration = services
.config
.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse;
let mut errors = ValidationErrors::new();
let user_id = match services
.users
.determine_registration_user_id(Some(form.username), form.email.as_ref(), None)
.await
{
| Ok(user_id) => user_id,
| Err(err) => {
errors.add(
"username",
ValidationError::new("invalid").with_message(err.message().into()),
);
return Ok(Err(errors));
},
};
let password_hash = match HashedPassword::new(&form.password) {
| Ok(password) => password,
| Err(err) => {
errors.add(
"password",
ValidationError::new("invalid").with_message(err.message().into()),
);
return Ok(Err(errors));
},
};
let mut registration_token = None;
// Check flow-specific form fields
match form.flow {
| RequestedRegistrationFlow::Trusted => {
// If the form claims to be using the trusted flow, it has to have a
// registration token
let Some(valid_token) = async {
services
.registration_tokens
.validate_token(form.registration_token?)
.await
}
.await
else {
errors.add(
"registration_token",
ValidationError::new("invalid")
.with_message("Invalid registration token".into()),
);
return Ok(Err(errors));
};
registration_token = Some(valid_token);
},
| RequestedRegistrationFlow::Untrusted => {
// Don't check auth for the untrusted flow at all if open reg is enabled
if !open_registration {
// If the form claims to be using the untrusted flow, it _may_ need to have a
// reCAPTCHA response if reCAPTCHA is configured
if let Some(recaptcha_private_site_key) =
&services.config.recaptcha_private_site_key
{
let Some(recaptcha_response) = form.recaptcha_response else {
return Err(WebError::BadRequest(
"reCAPTCHA response expected".to_owned(),
));
};
if recaptcha_verify::verify_v3(
recaptcha_private_site_key,
&recaptcha_response,
None,
)
.await
.is_err()
{
errors.add(
"recaptcha",
ValidationError::new("missing")
.with_message("Please complete the CAPTCHA".into()),
);
return Ok(Err(errors));
}
}
}
},
}
let completed_registration = CompletedRegistration {
user_id,
password_hash,
registration_token,
next,
};
// Check if we need to send an email
let require_email = services
.config
.smtp
.as_ref()
.is_some_and(|smtp| match form.flow {
| RequestedRegistrationFlow::Trusted => smtp.require_email_for_token_registration,
| RequestedRegistrationFlow::Untrusted =>
!open_registration && smtp.require_email_for_registration,
});
if require_email {
// If an email is required we have to validate it before we can complete
// registration
let Some(address) = form.email else {
errors.add(
"email",
ValidationError::new("missing")
.with_message("Please provide an email address".into()),
);
return Ok(Err(errors));
};
if services
.threepid
.get_localpart_for_email(&address)
.await
.is_some()
{
errors.add(
"email",
ValidationError::new("in_use")
.with_message("This email address is already in use.".into()),
);
return Ok(Err(errors));
}
let client_secret = ClientSecret::new();
let session_id = {
match services
.threepid
.send_validation_email(
Mailbox::new(None, address.clone()),
|verification_link| messages::NewAccount {
server_name: services.globals.server_name().as_str(),
verification_link,
},
&client_secret,
0,
)
.await
{
| Ok(session_id) => session_id,
| Err(err) => {
warn!(
"Failed to send new account message for {} to {}: {err}",
&completed_registration.user_id, address,
);
errors.add(
"email",
ValidationError::new("invalid").with_message(
"Failed to send validation email. Is this address correct?".into(),
),
);
return Ok(Err(errors));
},
}
};
session_store
.insert(COMPLETED_REGISTRATION_KEY, completed_registration)
.await
.expect("should have been able to serialize completed registration");
Ok(response!(
RegisterEmailValidate::new(context, session_id, client_secret, false,)
))
} else {
// If email isn't required we can immediately complete registration
Ok(response!(
complete_registration(services, session_store, completed_registration, None).await
))
}
}
async fn complete_registration(
services: &crate::State,
session_store: Session,
CompletedRegistration {
user_id,
password_hash,
registration_token,
next,
}: CompletedRegistration,
email: Option<Address>,
) -> Redirect {
services
.users
.create_local_account(&user_id, password_hash, email)
.await;
if let Some(registration_token) = registration_token {
services
.registration_tokens
.mark_token_as_used(registration_token);
}
let user_session = UserSession { user_id, last_login: SystemTime::now() };
session_store
.insert(User::KEY, user_session)
.await
.expect("should be able to serialize user session");
Redirect::to(&next.unwrap_or_default().target_path())
}
pub(super) async fn registration_flow_status(
services: &crate::State,
) -> (TrustedFlowStatus, UntrustedFlowStatus) {
// Allow registration if it's enabled in the config file or if this is the first
// run (so the first user account can be created)
let allow_registration =
services.config.allow_registration || services.firstrun.is_first_run();
// Trusted flow is only available if any registration tokens exist
let trusted_flow_status = {
if !allow_registration {
TrustedFlowStatus::Unavailable
} else if services
.registration_tokens
.iterate_tokens()
.next()
.await
.is_some()
{
TrustedFlowStatus::Available
} else {
TrustedFlowStatus::Unavailable
}
};
// Untrusted flow is available if email is required for registration,
// or reCAPTCHA is configured, or open registration is enabled
let untrusted_flow_status = {
let require_email = services
.config
.smtp
.as_ref()
.is_some_and(|smtp| smtp.require_email_for_registration);
if !allow_registration || services.firstrun.is_first_run() {
UntrustedFlowStatus::Unavailable
} else if services.config.recaptcha_private_site_key.is_some() || require_email {
UntrustedFlowStatus::Available { require_email }
} else if services
.config
.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
{
UntrustedFlowStatus::Available { require_email: false }
} else {
UntrustedFlowStatus::Unavailable
}
};
(trusted_flow_status, untrusted_flow_status)
}
+49 -8
View File
@@ -1,13 +1,34 @@
use askama::{Template, filters::HtmlSafe};
use validator::ValidationErrors;
use crate::pages::TemplateContext;
/// A reusable form component with field validation.
#[derive(Debug, Template)]
#[template(path = "_components/form.html.j2")]
pub(crate) struct Form<'a> {
pub inputs: Vec<FormInput<'a>>,
context: TemplateContext,
inputs: Vec<FormInput<'a>>,
submit_label: &'a str,
slowdown: bool,
pub validation_errors: Option<ValidationErrors>,
pub submit_label: &'a str,
}
impl<'a> Form<'a> {
pub(crate) fn new(
context: TemplateContext,
inputs: Vec<FormInput<'a>>,
submit_label: &'a str,
slowdown: bool,
) -> Self {
Self {
context,
inputs,
submit_label,
slowdown,
validation_errors: None,
}
}
}
impl HtmlSafe for Form<'_> {}
@@ -50,6 +71,16 @@ impl Default for FormInput<'_> {
}
}
#[macro_export]
macro_rules! default {
($value:expr) => {
$value
};
() => {
Default::default()
};
}
/// Generate a deserializable struct which may be turned into a [`Form`]
/// for inclusion in another template.
#[macro_export]
@@ -63,6 +94,7 @@ macro_rules! form {
),*
submit: $submit_label:expr
$(, slowdown: $slowdown:expr)?
}
) => {
#[derive(Debug, serde::Deserialize, validator::Validate)]
@@ -77,9 +109,10 @@ macro_rules! form {
impl $struct_name {
/// Generate a [`Form`] which matches the shape of this struct.
#[allow(clippy::needless_update)]
fn build(validation_errors: Option<validator::ValidationErrors>) -> $crate::pages::components::form::Form<'static> {
$crate::pages::components::form::Form {
inputs: vec![
fn build(context: TemplateContext) -> $crate::pages::components::form::Form<'static> {
$crate::pages::components::form::Form::new(
context,
vec![
$(
$crate::pages::components::form::FormInput {
id: stringify!($name),
@@ -89,9 +122,17 @@ macro_rules! form {
},
)*
],
validation_errors,
submit_label: $submit_label,
}
$submit_label,
$crate::default!($($slowdown)?)
)
}
/// Generate a [`Form`] with validation errors.
#[allow(unused)]
fn with_errors(context: TemplateContext, errors: validator::ValidationErrors) -> $crate::pages::components::form::Form<'static> {
let mut form = Self::build(context);
form.validation_errors = Some(errors);
form
}
}
};
+143 -29
View File
@@ -1,37 +1,33 @@
use std::{collections::BTreeSet, time::SystemTime};
use askama::{Template, filters::HtmlSafe};
use base64::Engine;
use conduwuit_core::result::FlatOk;
use conduwuit_service::{Services, media::mxc::Mxc};
use ruma::UserId;
use conduwuit_core::{result::FlatOk, utils};
use conduwuit_service::{
Services,
media::mxc::Mxc,
oauth::{client_metadata::ClientMetadata, grant::Scope},
};
use ruma::{OwnedDeviceId, OwnedUserId, UserId, api::client::device::Device};
pub(super) mod form;
#[derive(Debug)]
pub(super) enum AvatarType<'a> {
pub(super) enum AvatarType {
Initial(char),
Image(&'a str),
Image(String),
}
#[derive(Debug, Template)]
#[template(path = "_components/avatar.html.j2")]
pub(super) struct Avatar<'a> {
pub(super) avatar_type: AvatarType<'a>,
pub(super) struct Avatar {
pub(super) avatar_type: AvatarType,
}
impl HtmlSafe for Avatar<'_> {}
impl HtmlSafe for Avatar {}
#[derive(Debug, Template)]
#[template(path = "_components/user_card.html.j2")]
pub(super) struct UserCard<'a> {
pub user_id: &'a UserId,
pub display_name: Option<String>,
pub avatar_src: Option<String>,
}
impl HtmlSafe for UserCard<'_> {}
impl<'a> UserCard<'a> {
pub(super) async fn for_local_user(services: &Services, user_id: &'a UserId) -> Self {
impl Avatar {
pub(super) async fn for_local_user(services: &Services, user_id: &UserId) -> Self {
let display_name = services.users.displayname(user_id).await.ok();
let avatar_src = async {
@@ -54,22 +50,140 @@ impl<'a> UserCard<'a> {
}
.await;
Self { user_id, display_name, avatar_src }
}
fn avatar(&'a self) -> Avatar<'a> {
let avatar_type = if let Some(ref avatar_src) = self.avatar_src {
let avatar_type = if let Some(avatar_src) = avatar_src {
AvatarType::Image(avatar_src)
} else if let Some(initial) = self
.display_name
} else if let Some(initial) = display_name
.as_ref()
.and_then(|display_name| display_name.chars().next())
{
AvatarType::Initial(initial)
} else {
AvatarType::Initial(self.user_id.localpart().chars().next().unwrap())
AvatarType::Initial(user_id.localpart().chars().next().unwrap())
};
Avatar { avatar_type }
Self { avatar_type }
}
pub(super) fn for_device(
oauth_metadata: Option<&ClientMetadata>,
display_name: Option<&str>,
) -> Self {
let avatar_src = oauth_metadata
.and_then(|metadata| metadata.logo_uri.as_ref())
.map(|uri| uri.as_str().to_owned());
let avatar_type = if let Some(avatar_src) = avatar_src {
AvatarType::Image(avatar_src)
} else if let Some(initial) = display_name.and_then(|name| name.chars().next()) {
if oauth_metadata.is_some() {
AvatarType::Initial(initial)
} else {
AvatarType::Initial('❖')
}
} else {
AvatarType::Initial('?')
};
Self { avatar_type }
}
}
#[derive(Debug, Template)]
#[template(path = "_components/user_card.html.j2")]
pub(super) struct UserCard {
pub user_id: OwnedUserId,
pub display_name: Option<String>,
pub avatar: Avatar,
}
impl HtmlSafe for UserCard {}
impl UserCard {
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
let display_name = services.users.displayname(&user_id).await.ok();
let avatar = Avatar::for_local_user(services, &user_id).await;
Self { user_id, display_name, avatar }
}
}
#[derive(Debug, Template)]
#[template(path = "_components/device_card.html.j2")]
pub(super) struct DeviceCard {
pub device_id: OwnedDeviceId,
pub display_name: Option<String>,
pub avatar: Avatar,
pub last_active: String,
pub oauth_metadata: Option<ClientMetadata>,
pub style: DeviceCardStyle,
}
impl HtmlSafe for DeviceCard {}
#[derive(Debug, PartialEq, Eq)]
pub(super) enum DeviceCardStyle {
Minimal,
Detailed,
}
impl DeviceCard {
pub(super) async fn for_device(
services: &Services,
user_id: &UserId,
device: Device,
style: DeviceCardStyle,
) -> Self {
let oauth_metadata = async {
let session_info = services
.oauth
.get_session_info_for_device(user_id, &device.device_id)
.await?;
Some(
services
.oauth
.get_client_metadata(&session_info.client_id)
.await
.expect("client should exist"),
)
}
.await;
let display_name = oauth_metadata
.as_ref()
.and_then(|metadata| metadata.client_name.clone())
.or_else(|| device.display_name.clone());
let avatar = Avatar::for_device(oauth_metadata.as_ref(), display_name.as_deref());
let last_active = device.last_seen_ts.map_or_else(
|| "unknown".to_owned(),
|last_seen_ts| {
last_seen_ts
.to_system_time()
.and_then(|t| SystemTime::now().duration_since(t).ok())
.map_or_else(
|| "now".to_owned(),
|duration| format!("{} ago", utils::time::pretty(duration)),
)
},
);
Self {
device_id: device.device_id,
display_name,
avatar,
last_active,
oauth_metadata,
style,
}
}
}
#[derive(Debug, Template)]
#[template(path = "_components/client_scopes.html.j2")]
pub(super) struct ClientScopes {
pub scopes: BTreeSet<Scope>,
}
impl HtmlSafe for ClientScopes {}
+28 -14
View File
@@ -1,25 +1,39 @@
use axum::{Router, extract::State, response::IntoResponse, routing::get};
use axum::{Extension, Router, extract::State, routing::get};
use crate::{WebError, template};
use crate::{
pages::{Result, TemplateContext},
response, template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/", get(index))
.route("/_continuwuity/", get(index))
.route("/", get(get_index))
.route(&format!("{}/", crate::ROUTE_PREFIX), get(get_index))
.route(&format!("{}/_book", crate::ROUTE_PREFIX), get(get_book))
}
async fn index(State(services): State<crate::State>) -> Result<impl IntoResponse, WebError> {
template! {
struct Index<'a> use "index.html.j2" {
server_name: &'a str,
first_run: bool
}
template! {
struct Index<'a> use "index.html.j2" {
server_name: &'a str,
first_run: bool
}
}
Ok(Index::new(
&services,
async fn get_index(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
) -> Result {
response!(Index::new(
context,
services.globals.server_name().as_str(),
services.firstrun.is_first_run(),
)
.into_response())
))
}
template! {
struct Book use "book.html.j2" {}
}
async fn get_book(Extension(context): Extension<TemplateContext>) -> Result {
response!(Book::new(context))
}
+76 -11
View File
@@ -1,21 +1,71 @@
use std::sync::Arc;
use axum::{
extract::{Request, State},
http::{HeaderValue, header},
middleware::Next,
response::Response,
routing::MethodFilter,
};
use conduwuit_core::utils;
use crate::WebError;
pub(super) mod about;
pub(super) mod account;
mod components;
pub(super) mod debug;
pub(super) mod index;
pub(super) mod password_reset;
pub(super) mod oauth;
pub(super) mod resources;
pub(super) mod threepid;
#[derive(Debug)]
type Result<T = Response, E = WebError> = std::result::Result<T, E>;
const GET_POST: MethodFilter = MethodFilter::GET.or(MethodFilter::POST);
#[derive(Debug, Clone)]
pub(crate) struct TemplateContext {
pub allow_indexing: bool,
pub csp_nonce: String,
}
impl From<&crate::State> for TemplateContext {
fn from(state: &crate::State) -> Self {
Self {
allow_indexing: state.config.allow_web_indexing,
}
}
const CSP_NONCE_LENGTH: usize = 32;
pub(super) async fn template_context_middleware(
State(config): State<Arc<conduwuit_service::config::Service>>,
mut request: Request,
next: Next,
) -> Response {
let csp_nonce = utils::random_string(CSP_NONCE_LENGTH);
let context = TemplateContext {
allow_indexing: config.allow_web_indexing,
csp_nonce: csp_nonce.clone(),
};
assert!(
request.extensions_mut().insert(context).is_none(),
"template context should only be inserted once"
);
let mut response = next.run(request).await;
let child_src = if config.recaptcha_site_key.is_some() {
"www.google.com"
} else {
"'none'"
};
response.headers_mut().insert(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_str(&format!(
"default-src 'none'; style-src 'self'; img-src 'self' https: data:; script-src \
'nonce-{csp_nonce}'; child-src {child_src};"
))
.expect("should be able to build CSP header"),
);
response
}
#[macro_export]
@@ -27,15 +77,17 @@ macro_rules! template {
) => {
#[derive(Debug, askama::Template)]
#[template(path = $path)]
#[allow(unused)]
struct $name$(<$lifetime>)? {
context: $crate::pages::TemplateContext,
context: $crate::pages::TemplateContext,
$($field_name: $field_type,)*
}
impl$(<$lifetime>)? $name$(<$lifetime>)? {
fn new(state: &$crate::State, $($field_name: $field_type,)*) -> Self {
#[allow(clippy::too_many_arguments)]
fn new(context: $crate::pages::TemplateContext, $($field_name: $field_type,)*) -> Self {
Self {
context: state.into(),
context,
$($field_name,)*
}
}
@@ -54,3 +106,16 @@ macro_rules! template {
}
};
}
#[macro_export]
macro_rules! response {
(BadRequest($body:expr)) => {
response!((axum::http::StatusCode::BAD_REQUEST, $body))
};
($body:expr) => {{
use axum::response::IntoResponse;
Ok($body.into_response())
}};
}
+144
View File
@@ -0,0 +1,144 @@
use axum::{
Extension, Router,
extract::{Query, State},
response::Redirect,
routing::on,
};
use conduwuit_service::oauth::grant::{AuthorizationCodeQuery, Prompt};
use ruma::OwnedUserId;
use url::Url;
use crate::{
ROUTE_PREFIX, WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
account::register::{RegisterQuery, RequestedRegistrationFlow},
components::{Avatar, AvatarType, ClientScopes},
},
response,
session::{LoginQuery, LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/authorization_code", on(GET_POST, route_authorization_code))
}
template! {
struct Grant use "grant.html.j2" {
logout_query: String,
user_id: OwnedUserId,
user_avatar: Avatar,
client_uri: Url,
client_name: String,
client_avatar: Avatar,
policy_uri: Option<Url>,
tos_uri: Option<Url>,
scopes: ClientScopes
}
}
async fn route_authorization_code(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
user: User<true>,
Expect(Query(query)): Expect<Query<AuthorizationCodeQuery>>,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = if let Some(user) = user.into_session() {
user.user_id
} else {
let is_first_run = services.firstrun.is_first_run();
let next = LoginTarget::AuthorizationCode(query.clone());
let uri = if query
.prompt
.is_some_and(|prompt| matches!(prompt, Prompt::Create))
|| is_first_run
{
format!(
"{}/account/register/?{}",
ROUTE_PREFIX,
serde_urlencoded::to_string(RegisterQuery {
next: Some(next),
flow: if is_first_run {
Some(RequestedRegistrationFlow::Trusted)
} else {
None
},
..Default::default()
})
.unwrap()
)
} else {
format!(
"{}/account/login?{}",
ROUTE_PREFIX,
serde_urlencoded::to_string(LoginQuery {
next: Some(next),
..Default::default()
})
.unwrap()
)
};
return response!(Redirect::to(&uri));
};
if form.is_some() {
let redirect_uri = services
.oauth
.request_authorization_code(user_id, query)
.await
.map_err(WebError::BadRequest)?;
return response!(Redirect::to(&redirect_uri));
}
let Some(client) = services.oauth.get_client_metadata(&query.client_id).await else {
return Err(WebError::BadRequest("Invalid client ID".to_owned()));
};
let scopes = query.scope.to_scopes().map_err(WebError::BadRequest)?;
let client_name = if let Some(name) = &client.client_name {
name
} else {
"Unknown application"
}
.to_owned();
let client_avatar = {
let avatar_type = if let Some(logo) = &client.logo_uri {
AvatarType::Image(logo.to_string())
} else if let Some(name) = &client.client_name
&& let Some(char) = name.chars().next()
{
AvatarType::Initial(char)
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
};
let user_avatar = Avatar::for_local_user(&services, &user_id).await;
response!(Grant::new(
context,
serde_urlencoded::to_string(LoginQuery {
next: Some(LoginTarget::AuthorizationCode(query)),
reauthenticate: false,
})
.unwrap(),
user_id,
user_avatar,
client.client_uri.clone(),
client_name,
client_avatar,
client.policy_uri.clone(),
client.tos_uri.clone(),
ClientScopes { scopes },
))
}
+10
View File
@@ -0,0 +1,10 @@
use axum::Router;
mod grant;
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new().nest("/grant/", grant::build())
}
-119
View File
@@ -1,119 +0,0 @@
use axum::{
Router,
extract::{
Query, State,
rejection::{FormRejection, QueryRejection},
},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
};
use serde::Deserialize;
use validator::Validate;
use crate::{
WebError, form,
pages::components::{UserCard, form::Form},
template,
};
const INVALID_TOKEN_ERROR: &str = "Invalid reset token. Your reset link may have expired.";
template! {
struct PasswordReset<'a> use "password_reset.html.j2" {
user_card: UserCard<'a>,
body: PasswordResetBody
}
}
#[derive(Debug)]
enum PasswordResetBody {
Form(Form<'static>),
Success,
}
form! {
struct PasswordResetForm {
#[validate(length(min = 1, message = "Password cannot be empty"))]
new_password: String where {
input_type: "password",
label: "New password",
autocomplete: "new-password"
},
#[validate(must_match(other = "new_password", message = "Passwords must match"))]
confirm_new_password: String where {
input_type: "password",
label: "Confirm new password",
autocomplete: "new-password"
}
submit: "Reset Password"
}
}
pub(crate) fn build() -> Router<crate::State> {
Router::new()
.route("/account/reset_password", get(get_password_reset).post(post_password_reset))
}
#[derive(Deserialize)]
struct PasswordResetQuery {
token: String,
}
async fn password_reset_form(
services: crate::State,
query: PasswordResetQuery,
reset_form: Form<'static>,
) -> Result<impl IntoResponse, WebError> {
let Some(token) = services.password_reset.check_token(&query.token).await else {
return Err(WebError::BadRequest(INVALID_TOKEN_ERROR.to_owned()));
};
let user_card = UserCard::for_local_user(&services, &token.info.user).await;
Ok(PasswordReset::new(&services, user_card, PasswordResetBody::Form(reset_form))
.into_response())
}
async fn get_password_reset(
State(services): State<crate::State>,
query: Result<Query<PasswordResetQuery>, QueryRejection>,
) -> Result<impl IntoResponse, WebError> {
let Query(query) = query?;
password_reset_form(services, query, PasswordResetForm::build(None)).await
}
async fn post_password_reset(
State(services): State<crate::State>,
query: Result<Query<PasswordResetQuery>, QueryRejection>,
form: Result<axum::Form<PasswordResetForm>, FormRejection>,
) -> Result<Response, WebError> {
let Query(query) = query?;
let axum::Form(form) = form?;
match form.validate() {
| Ok(()) => {
let Some(token) = services.password_reset.check_token(&query.token).await else {
return Err(WebError::BadRequest(INVALID_TOKEN_ERROR.to_owned()));
};
let user_id = token.info.user.clone();
services
.password_reset
.consume_token(token, &form.new_password)
.await?;
let user_card = UserCard::for_local_user(&services, &user_id).await;
Ok(PasswordReset::new(&services, user_card, PasswordResetBody::Success)
.into_response())
},
| Err(err) => Ok((
StatusCode::BAD_REQUEST,
password_reset_form(services, query, PasswordResetForm::build(Some(err))).await,
)
.into_response()),
}
}
+22
View File
@@ -0,0 +1,22 @@
html {
height: 100svh;
--bg: oklch(0.44 0.177 353.06);
background-color: var(--bg);
background-image: linear-gradient(180deg, var(--bg) 55%, black 100%);;
}
main {
margin-top: 20vh;
margin-left: 5%;
margin-right: 5%;
color: white;
font-family: monospace;
font-size: 1.3em;
line-height: 1.2;
}
em {
font-size: 1.35em;
font-weight: bold;
white-space: nowrap;
}
+165 -25
View File
@@ -9,6 +9,7 @@
--panel-bg: oklch(0.91 0.042 317.27);
--c1: oklch(0.44 0.177 353.06);
--c2: oklch(0.59 0.158 150.88);
--avatar-color: var(--c2);
--name-lightness: 0.45;
--background-lightness: 0.9;
@@ -26,8 +27,8 @@
@media (prefers-color-scheme: dark) {
color-scheme: dark;
--text-color: #fff;
--secondary: #888;
--text-color: #f5ebeb;
--secondary: #999;
--bg: oklch(0.15 0.042 317.27);
--panel-bg: oklch(0.24 0.03 317.27);
@@ -54,10 +55,13 @@
}
body {
display: grid;
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-around;
margin: 0;
padding: 0;
place-items: center;
min-height: 100vh;
color: var(--text-color);
@@ -73,6 +77,7 @@ html {
footer {
padding-inline: 0.25rem;
margin-top: 1rem;
height: max(fit-content, 2rem);
.logo {
@@ -83,12 +88,33 @@ footer {
p {
margin: 1rem 0;
a {
white-space: nowrap;
}
}
ul {
margin: 1rem 0;
}
section {
margin: 1rem 0;
}
em {
color: oklch(from var(--c2) var(--name-lightness) c h);
font-weight: bold;
font-style: normal;
&.negative {
color: red;
}
}
hr {
border-width: 1px;
border-color: var(--secondary);
}
small {
@@ -103,39 +129,59 @@ small.error {
margin-bottom: 0.5rem;
}
.panel {
--preferred-width: 12rem + 40dvw;
--maximum-width: 48rem;
width: min(clamp(24rem, var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
border-radius: var(--border-radius-lg);
background-color: var(--panel-bg);
padding-inline: 1.5rem;
padding-block: 1rem;
box-shadow: 0 0.25em 0.375em hsla(0, 0%, 0%, 0.1);
img.matrix-icon {
@media (prefers-color-scheme: dark) {
filter: invert();
}
}
&.narrow {
--preferred-width: 12rem + 20dvw;
--maximum-width: 36rem;
h1.with-matrix-icon {
display: flex;
align-items: center;
input, button {
width: 100%;
a:last-of-type {
margin-left: auto;
img {
height: 1em;
}
}
}
label {
display: block;
h1 a.back {
font-size: initial;
font-weight: initial;
}
input, button {
label {
display: block;
margin-bottom: 0.2em;
}
a, a:visited {
color: oklch(from var(--c1) var(--name-lightness) c h);
}
code {
color: oklch(from var(--secondary) var(--name-lightness) c h);
}
pre {
background-color: oklch(from var(--panel-bg) calc(l - 0.05) c h);
border-radius: var(--border-radius-sm);
padding: 8px;
}
input, button, a.button {
display: inline-block;
padding: 0.5em;
margin-bottom: 0.5em;
font-size: inherit;
font-family: inherit;
color: white;
line-height: normal;
color: var(--text-color);
text-decoration: none;
background-color: transparent;
border: none;
@@ -151,14 +197,36 @@ input {
}
}
button {
input[type="checkbox"] {
display: inline;
margin: 0;
width: auto !important;
}
button, a.button {
color: white;
background-color: var(--c1);
transition: opacity .2s;
text-align: center;
margin: 0.5rem 0;
&:enabled:hover {
opacity: 0.8;
cursor: pointer;
}
&:disabled {
color: lightgray;
background-color: gray;
}
&:not(:disabled) {
transition: linear color, background-color 0.1s;
}
&:visited {
color: white;
}
}
h1 {
@@ -166,6 +234,77 @@ h1 {
margin-bottom: 0.67em;
}
ul.bullet-separated {
display: inline-block;
margin: 0;
padding: 0;
li {
display: inline;
flex: 1;
list-style-type: none;
&:not(:first-child)::before {
content: "• ";
}
}
}
.fullwidth {
width: 100%;
margin-bottom: 0 !important;
}
.select-all {
user-select: all;
}
.panel {
--preferred-width: 12rem + 40dvw;
--maximum-width: 48rem;
--minimum-width: 32rem;
width: min(clamp(var(--minimum-width), var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
border-radius: var(--border-radius-lg);
background-color: var(--panel-bg);
padding-inline: 1.5rem;
padding-block: 1rem;
margin-top: 1em;
margin-bottom: auto;
box-shadow: 0 0.25em 0.375em hsla(0, 0%, 0%, 0.1);
&.middle {
margin-top: 0;
margin-bottom: 0;
}
&.narrow {
--preferred-width: 12rem + 20dvw;
--maximum-width: 36rem;
input, button, a.button {
width: 100%;
}
}
&:not(.narrow) form p {
margin-bottom: 0;
}
}
.project-name {
font-weight: bold;
text-decoration: none !important;
background: linear-gradient(
130deg,
oklch(from var(--c1) var(--name-lightness) c h),
oklch(from var(--c2) var(--name-lightness) c h)
);
background-clip: text;
color: transparent;
filter: brightness(1.2);
}
@media (max-width: 425px) {
main {
padding-block-start: 2rem;
@@ -175,11 +314,12 @@ h1 {
.panel {
border-radius: 0;
width: 100%;
margin-top: 0;
}
}
@media (max-width: 799px) {
input, button {
input, button, a.button {
width: 100%;
}
}
+27 -12
View File
@@ -11,12 +11,17 @@
font-size: calc(var(--avatar-size) * 0.5);
font-weight: 700;
line-height: calc(var(--avatar-size) - 2px);
user-select: none;
color: oklch(from var(--c1) calc(l + 0.2) c h);
background-color: var(--c1);
color: oklch(from var(--avatar-color) calc(l + 0.2) c h);
background-color: var(--avatar-color);
}
.user-card {
.red-avatar {
--avatar-color: var(--c1);
}
.card {
display: flex;
flex-direction: row;
align-items: center;
@@ -29,16 +34,26 @@
.info {
flex: 1 1;
p {
margin: 0;
.name {
font-weight: bold;
}
&.display-name {
font-weight: 700;
}
&:nth-of-type(2) {
color: var(--secondary);
}
.id {
color: var(--secondary);
font-weight: normal;
}
}
&.danger {
display: block;
background-color: oklch(from red 0.2 c h);
border: 1px dashed red;
}
}
.card-list {
display: flex;
flex-direction: column;
gap: 10px;
margin-top: 10px;
}
+8 -1
View File
@@ -2,12 +2,19 @@
font-family: monospace;
font-size: x-small;
font-weight: 700;
transform: translate(1rem, 1.6rem);
transform: translate(0rem, 2rem);
color: var(--secondary);
user-select: none;
margin: 0;
padding: 0;
background: none;
}
h1 {
display: flex;
align-items: center;
}
code {
white-space: pre-wrap;
}
+22
View File
@@ -0,0 +1,22 @@
.avatars {
justify-content: center;
display: flex;
flex-direction: row;
.separator {
align-self: center;
margin-inline: 1em;
color: var(--secondary);
font-size: x-large;
font-weight: bold;
user-select: none;
}
}
.identity {
margin-block: 1em;
color: var(--secondary);
font-size: small;
font-style: italic;
text-align: center;
}
+2 -10
View File
@@ -1,11 +1,3 @@
.project-name {
text-decoration: none;
background: linear-gradient(
130deg,
oklch(from var(--c1) var(--name-lightness) c h),
oklch(from var(--c2) var(--name-lightness) c h)
);
background-clip: text;
color: transparent;
filter: brightness(1.2);
.button {
margin: 0 !important;
}

Some files were not shown because too many files have changed in this diff Show More