From 3c07857e1f41712cb1efd738cb82535e8ec06cb3 Mon Sep 17 00:00:00 2001 From: Ginger Date: Thu, 7 May 2026 10:04:46 -0400 Subject: [PATCH] feat: Implement support for `prompt=create` in the authorization code flow --- Cargo.lock | 1 + src/service/oauth/grant.rs | 11 +++++ src/service/users/mod.rs | 1 + src/web/Cargo.toml | 1 + src/web/pages/account/login.rs | 22 +++++----- src/web/pages/account/register.rs | 54 ++++++++++++++++-------- src/web/pages/mod.rs | 2 +- src/web/pages/oauth/grant.rs | 38 +++++++++++++++-- src/web/pages/templates/login.html.j2 | 7 +-- src/web/pages/templates/register.html.j2 | 11 ++++- src/web/session/mod.rs | 4 +- 11 files changed, 112 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dfbe1f03f..e65a474e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1113,6 +1113,7 @@ dependencies = [ "conduwuit_core", "conduwuit_database", "conduwuit_service", + "form_urlencoded", "futures", "lettre", "memory-serve", diff --git a/src/service/oauth/grant.rs b/src/service/oauth/grant.rs index 06ad5ad47..156b0f851 100644 --- a/src/service/oauth/grant.rs +++ b/src/service/oauth/grant.rs @@ -18,6 +18,8 @@ pub struct AuthorizationCodeQuery { pub response_mode: ResponseMode, pub code_challenge: String, pub code_challenge_method: CodeChallengeMethod, + #[serde(default)] + pub prompt: Option, } #[derive(Debug, Clone, Default, Deserialize, Serialize)] @@ -36,6 +38,15 @@ pub enum CodeChallengeMethod { S256, } +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum Prompt { + Create, + #[serde(other)] + Unknown, +} + #[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)] pub enum Scope { Device(OwnedDeviceId), diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 6a224be35..c93b9e82a 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -245,6 +245,7 @@ impl Service { let suffix = &self.services.config.new_user_displayname_suffix; if !suffix.is_empty() { + displayname.push(' '); displayname.push_str(suffix); } diff --git a/src/web/Cargo.toml b/src/web/Cargo.toml index 010b50bad..d70335ebf 100644 --- a/src/web/Cargo.toml +++ b/src/web/Cargo.toml @@ -48,6 +48,7 @@ serde_urlencoded.workspace = true url.workspace = true recaptcha-verify = { version = "0.2.0", default-features = false } reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated +form_urlencoded = "1.2.2" [build-dependencies] memory-serve = "2.1.0" diff --git a/src/web/pages/account/login.rs b/src/web/pages/account/login.rs index 29d2afeed..2c555a8e7 100644 --- a/src/web/pages/account/login.rs +++ b/src/web/pages/account/login.rs @@ -36,7 +36,6 @@ pub(crate) fn build() -> Router { template! { struct Login use "login.html.j2" { body: LoginBody, - has_next: bool, login_error: Option } } @@ -46,6 +45,7 @@ enum LoginBody { Unauthenticated { server_name: String, registration_available: bool, + next: Option, }, Authenticated { user_card: UserCard, @@ -66,7 +66,6 @@ async fn route_login( user: User, PostForm(form): PostForm, ) -> Result { - let next = next.unwrap_or_default(); let user_id = user.into_session().map(|session| session.user_id); let body = match &user_id { @@ -74,20 +73,19 @@ async fn route_login( let (trusted_flow_status, untrusted_flow_status) = registration_flow_status(&services).await; + let registration_available = + matches!(trusted_flow_status, TrustedFlowStatus::Available) + || matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. }); + LoginBody::Unauthenticated { server_name: services.globals.server_name().to_string(), - registration_available: matches!( - trusted_flow_status, - TrustedFlowStatus::Available - ) || matches!( - untrusted_flow_status, - UntrustedFlowStatus::Available { .. } - ), + registration_available, + next: next.clone(), } }, | Some(user_id) => { if !reauthenticate { - return response!(Redirect::to(&next.target_path())); + return response!(Redirect::to(&next.unwrap_or_default().target_path())); } let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await; @@ -96,7 +94,7 @@ async fn route_login( }, }; - let mut template = Login::new(context, body, next != LoginTarget::Account, None); + let mut template = Login::new(context, body, None); if let Some(form) = form { let login_result = match (user_id, form.identifier) { @@ -144,7 +142,7 @@ async fn route_login( .await .expect("should be able to serialize user session"); - return response!(Redirect::to(&next.target_path())); + return response!(Redirect::to(&next.unwrap_or_default().target_path())); } response!(template) diff --git a/src/web/pages/account/register.rs b/src/web/pages/account/register.rs index 2128398fc..e53f8be0c 100644 --- a/src/web/pages/account/register.rs +++ b/src/web/pages/account/register.rs @@ -49,6 +49,7 @@ enum RegisterBody { trusted_flow_status: TrustedFlowStatus, untrusted_flow_status: UntrustedFlowStatus, username_error: Option, + next: Option, }, DetailsPrompt { username: Option, @@ -73,18 +74,20 @@ pub(super) enum UntrustedFlowStatus { }, } -#[derive(Deserialize)] -struct RegistrationQuery { - username: Option, - token: Option, - flow: Option, +#[derive(Default, Deserialize, Serialize)] +pub(crate) struct RegisterQuery { + pub username: Option, + pub token: Option, + pub flow: Option, #[serde(default)] - from_landing: bool, + pub from_landing: bool, + #[serde(flatten)] + pub next: Option, } -#[derive(Clone, Copy, Deserialize)] +#[derive(Clone, Copy, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] -enum RequestedRegistrationFlow { +pub(crate) enum RequestedRegistrationFlow { Untrusted, Trusted, } @@ -118,13 +121,14 @@ struct CompletedRegistration { user_id: OwnedUserId, password_hash: HashedPassword, registration_token: Option, + next: Option, } async fn route_register( State(services): State, Extension(context): Extension, session_store: Session, - Expect(Query(query)): Expect>, + Expect(Query(query)): Expect>, PostForm(form): PostForm, ) -> Result { if session_store @@ -140,7 +144,15 @@ async fn route_register( let validation_errors = if let Some(form) = form { match form.validate() { | Ok(()) => { - match begin_registration(&services, context.clone(), session_store, form).await? { + match begin_registration( + &services, + context.clone(), + session_store, + form, + query.next.clone(), + ) + .await? + { | Ok(response) => return Ok(response), | Err(err) => err, } @@ -188,6 +200,7 @@ async fn route_register( trusted_flow_status, untrusted_flow_status, username_error: Some(err.message()), + next: query.next, } )); } @@ -227,6 +240,7 @@ async fn route_register( trusted_flow_status, untrusted_flow_status, username_error: None, + next: query.next, }, } }; @@ -276,9 +290,10 @@ async fn get_register_email_validate( let email = session.consume(); - complete_registration(&services, session_store, completed_registration, Some(email)).await; - - response!(Redirect::to(&LoginTarget::Account.target_path())) + response!( + complete_registration(&services, session_store, completed_registration, Some(email)) + .await + ) } async fn begin_registration( @@ -286,6 +301,7 @@ async fn begin_registration( context: TemplateContext, session_store: Session, form: RegistrationForm, + next: Option, ) -> Result> { let open_registration = services .config @@ -383,6 +399,7 @@ async fn begin_registration( user_id, password_hash, registration_token, + next, }; // Check if we need to send an email @@ -466,9 +483,9 @@ async fn begin_registration( )) } else { // If email isn't required we can immediately complete registration - complete_registration(services, session_store, completed_registration, None).await; - - Ok(response!(Redirect::to(&LoginTarget::Account.target_path()))) + Ok(response!( + complete_registration(services, session_store, completed_registration, None).await + )) } } @@ -479,9 +496,10 @@ async fn complete_registration( user_id, password_hash, registration_token, + next, }: CompletedRegistration, email: Option
, -) { +) -> Redirect { services .users .create_local_account(&user_id, password_hash, email) @@ -499,6 +517,8 @@ async fn complete_registration( .insert(User::KEY, user_session) .await .expect("should be able to serialize user session"); + + Redirect::to(&next.unwrap_or_default().target_path()) } pub(super) async fn registration_flow_status( diff --git a/src/web/pages/mod.rs b/src/web/pages/mod.rs index 7a75e5bb1..b7d5eed86 100644 --- a/src/web/pages/mod.rs +++ b/src/web/pages/mod.rs @@ -58,7 +58,7 @@ pub(super) async fn template_context_middleware( response.headers_mut().insert( header::CONTENT_SECURITY_POLICY, HeaderValue::from_str(&format!( - "default-src 'none'; style-src 'self'; img-src 'self' 'https' data:; script-src \ + "default-src 'none'; style-src 'self'; img-src 'self' https: data:; script-src \ 'nonce-{csp_nonce}'; child-src {child_src};" )) .expect("should be able to build CSP header"), diff --git a/src/web/pages/oauth/grant.rs b/src/web/pages/oauth/grant.rs index 15f0e6b91..2154e32d4 100644 --- a/src/web/pages/oauth/grant.rs +++ b/src/web/pages/oauth/grant.rs @@ -4,15 +4,16 @@ use axum::{ response::Redirect, routing::on, }; -use conduwuit_service::oauth::grant::AuthorizationCodeQuery; +use conduwuit_service::oauth::grant::{AuthorizationCodeQuery, Prompt}; use ruma::OwnedUserId; use url::Url; use crate::{ - WebError, + ROUTE_PREFIX, WebError, extract::{Expect, PostForm}, pages::{ GET_POST, Result, TemplateContext, + account::register::RegisterQuery, components::{Avatar, AvatarType, ClientScopes}, }, response, @@ -45,7 +46,38 @@ async fn route_authorization_code( Expect(Query(query)): Expect>, PostForm(form): PostForm<()>, ) -> Result { - let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?; + let user_id = if let Some(user) = user.into_session() { + user.user_id + } else { + let next = LoginTarget::AuthorizationCode(query.clone()); + + let uri = if query + .prompt + .is_some_and(|prompt| matches!(prompt, Prompt::Create)) + { + format!( + "{}/account/register/?{}", + ROUTE_PREFIX, + serde_urlencoded::to_string(RegisterQuery { + next: Some(next), + ..Default::default() + }) + .unwrap() + ) + } else { + format!( + "{}/account/login?{}", + ROUTE_PREFIX, + serde_urlencoded::to_string(LoginQuery { + next: Some(next), + ..Default::default() + }) + .unwrap() + ) + }; + + return response!(Redirect::to(&uri)); + }; if form.is_some() { let redirect_uri = services diff --git a/src/web/pages/templates/login.html.j2 b/src/web/pages/templates/login.html.j2 index f590907ba..dc2674926 100644 --- a/src/web/pages/templates/login.html.j2 +++ b/src/web/pages/templates/login.html.j2 @@ -11,9 +11,9 @@ Log in {%- block content -%}
{% match body %} - {% when LoginBody::Unauthenticated { server_name, registration_available } %} + {% when LoginBody::Unauthenticated { server_name, registration_available, next } %}

- {% if has_next %} + {% if next.is_some() %} Log in to continue {% else %} Log in to Matrix @@ -39,7 +39,8 @@ Log in diff --git a/src/web/pages/templates/register.html.j2 b/src/web/pages/templates/register.html.j2 index bdd393125..eabadd7b1 100644 --- a/src/web/pages/templates/register.html.j2 +++ b/src/web/pages/templates/register.html.j2 @@ -12,7 +12,7 @@ Sign up {%- block content -%}

- {% if false %} + {% if let RegisterBody::UsernamePrompt { next, .. } = body && next.is_some() %} Sign up to continue {% else %} Sign up @@ -26,7 +26,7 @@ Sign up

This server is not currently accepting new accounts.

- {% when RegisterBody::UsernamePrompt { allow_federation, untrusted_flow_status, trusted_flow_status, username_error } %} + {% when RegisterBody::UsernamePrompt { allow_federation, untrusted_flow_status, trusted_flow_status, username_error, next } %}

You're about to register a new Matrix account on {{ server_name }}.

@@ -42,6 +42,13 @@ Sign up

+ {% if let Some(next) = next %} + {# urlencoded roundtrip moment #} + {% let next = serde_urlencoded::to_string(&next).unwrap() %} + {% for (key, value) in form_urlencoded::parse(next.as_bytes()) %} + + {% endfor %} + {% endif %}

diff --git a/src/web/session/mod.rs b/src/web/session/mod.rs index b2d07e2e9..357f683fc 100644 --- a/src/web/session/mod.rs +++ b/src/web/session/mod.rs @@ -14,7 +14,7 @@ use crate::{ROUTE_PREFIX, WebError, pages::account::device::DevicePath}; pub(crate) mod store; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Default, Debug, Deserialize, Serialize)] pub(crate) struct LoginQuery { #[serde(flatten)] pub next: Option, @@ -22,7 +22,7 @@ pub(crate) struct LoginQuery { pub reauthenticate: bool, } -#[derive(Debug, Default, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] #[serde(tag = "next", rename_all = "snake_case")] pub(crate) enum LoginTarget { AuthorizationCode(AuthorizationCodeQuery),