feat: Implement support for prompt=create in the authorization code flow

This commit is contained in:
Ginger
2026-05-07 10:04:46 -04:00
parent 851d6e219f
commit 3c07857e1f
11 changed files with 112 additions and 40 deletions
Generated
+1
View File
@@ -1113,6 +1113,7 @@ dependencies = [
"conduwuit_core",
"conduwuit_database",
"conduwuit_service",
"form_urlencoded",
"futures",
"lettre",
"memory-serve",
+11
View File
@@ -18,6 +18,8 @@ pub struct AuthorizationCodeQuery {
pub response_mode: ResponseMode,
pub code_challenge: String,
pub code_challenge_method: CodeChallengeMethod,
#[serde(default)]
pub prompt: Option<Prompt>,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
@@ -36,6 +38,15 @@ pub enum CodeChallengeMethod {
S256,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Prompt {
Create,
#[serde(other)]
Unknown,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)]
pub enum Scope {
Device(OwnedDeviceId),
+1
View File
@@ -245,6 +245,7 @@ impl Service {
let suffix = &self.services.config.new_user_displayname_suffix;
if !suffix.is_empty() {
displayname.push(' ');
displayname.push_str(suffix);
}
+1
View File
@@ -48,6 +48,7 @@ serde_urlencoded.workspace = true
url.workspace = true
recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
form_urlencoded = "1.2.2"
[build-dependencies]
memory-serve = "2.1.0"
+10 -12
View File
@@ -36,7 +36,6 @@ pub(crate) fn build() -> Router<crate::State> {
template! {
struct Login use "login.html.j2" {
body: LoginBody,
has_next: bool,
login_error: Option<String>
}
}
@@ -46,6 +45,7 @@ enum LoginBody {
Unauthenticated {
server_name: String,
registration_available: bool,
next: Option<LoginTarget>,
},
Authenticated {
user_card: UserCard,
@@ -66,7 +66,6 @@ async fn route_login(
user: User,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let next = next.unwrap_or_default();
let user_id = user.into_session().map(|session| session.user_id);
let body = match &user_id {
@@ -74,20 +73,19 @@ async fn route_login(
let (trusted_flow_status, untrusted_flow_status) =
registration_flow_status(&services).await;
let registration_available =
matches!(trusted_flow_status, TrustedFlowStatus::Available)
|| matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. });
LoginBody::Unauthenticated {
server_name: services.globals.server_name().to_string(),
registration_available: matches!(
trusted_flow_status,
TrustedFlowStatus::Available
) || matches!(
untrusted_flow_status,
UntrustedFlowStatus::Available { .. }
),
registration_available,
next: next.clone(),
}
},
| Some(user_id) => {
if !reauthenticate {
return response!(Redirect::to(&next.target_path()));
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
}
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
@@ -96,7 +94,7 @@ async fn route_login(
},
};
let mut template = Login::new(context, body, next != LoginTarget::Account, None);
let mut template = Login::new(context, body, None);
if let Some(form) = form {
let login_result = match (user_id, form.identifier) {
@@ -144,7 +142,7 @@ async fn route_login(
.await
.expect("should be able to serialize user session");
return response!(Redirect::to(&next.target_path()));
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
}
response!(template)
+37 -17
View File
@@ -49,6 +49,7 @@ enum RegisterBody {
trusted_flow_status: TrustedFlowStatus,
untrusted_flow_status: UntrustedFlowStatus,
username_error: Option<String>,
next: Option<LoginTarget>,
},
DetailsPrompt {
username: Option<String>,
@@ -73,18 +74,20 @@ pub(super) enum UntrustedFlowStatus {
},
}
#[derive(Deserialize)]
struct RegistrationQuery {
username: Option<String>,
token: Option<String>,
flow: Option<RequestedRegistrationFlow>,
#[derive(Default, Deserialize, Serialize)]
pub(crate) struct RegisterQuery {
pub username: Option<String>,
pub token: Option<String>,
pub flow: Option<RequestedRegistrationFlow>,
#[serde(default)]
from_landing: bool,
pub from_landing: bool,
#[serde(flatten)]
pub next: Option<LoginTarget>,
}
#[derive(Clone, Copy, Deserialize)]
#[derive(Clone, Copy, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum RequestedRegistrationFlow {
pub(crate) enum RequestedRegistrationFlow {
Untrusted,
Trusted,
}
@@ -118,13 +121,14 @@ struct CompletedRegistration {
user_id: OwnedUserId,
password_hash: HashedPassword,
registration_token: Option<ValidToken>,
next: Option<LoginTarget>,
}
async fn route_register(
State(services): State<crate::State>,
Extension(context): Extension<TemplateContext>,
session_store: Session,
Expect(Query(query)): Expect<Query<RegistrationQuery>>,
Expect(Query(query)): Expect<Query<RegisterQuery>>,
PostForm(form): PostForm<RegistrationForm>,
) -> Result {
if session_store
@@ -140,7 +144,15 @@ async fn route_register(
let validation_errors = if let Some(form) = form {
match form.validate() {
| Ok(()) => {
match begin_registration(&services, context.clone(), session_store, form).await? {
match begin_registration(
&services,
context.clone(),
session_store,
form,
query.next.clone(),
)
.await?
{
| Ok(response) => return Ok(response),
| Err(err) => err,
}
@@ -188,6 +200,7 @@ async fn route_register(
trusted_flow_status,
untrusted_flow_status,
username_error: Some(err.message()),
next: query.next,
}
));
}
@@ -227,6 +240,7 @@ async fn route_register(
trusted_flow_status,
untrusted_flow_status,
username_error: None,
next: query.next,
},
}
};
@@ -276,9 +290,10 @@ async fn get_register_email_validate(
let email = session.consume();
complete_registration(&services, session_store, completed_registration, Some(email)).await;
response!(Redirect::to(&LoginTarget::Account.target_path()))
response!(
complete_registration(&services, session_store, completed_registration, Some(email))
.await
)
}
async fn begin_registration(
@@ -286,6 +301,7 @@ async fn begin_registration(
context: TemplateContext,
session_store: Session,
form: RegistrationForm,
next: Option<LoginTarget>,
) -> Result<Result<Response, ValidationErrors>> {
let open_registration = services
.config
@@ -383,6 +399,7 @@ async fn begin_registration(
user_id,
password_hash,
registration_token,
next,
};
// Check if we need to send an email
@@ -466,9 +483,9 @@ async fn begin_registration(
))
} else {
// If email isn't required we can immediately complete registration
complete_registration(services, session_store, completed_registration, None).await;
Ok(response!(Redirect::to(&LoginTarget::Account.target_path())))
Ok(response!(
complete_registration(services, session_store, completed_registration, None).await
))
}
}
@@ -479,9 +496,10 @@ async fn complete_registration(
user_id,
password_hash,
registration_token,
next,
}: CompletedRegistration,
email: Option<Address>,
) {
) -> Redirect {
services
.users
.create_local_account(&user_id, password_hash, email)
@@ -499,6 +517,8 @@ async fn complete_registration(
.insert(User::KEY, user_session)
.await
.expect("should be able to serialize user session");
Redirect::to(&next.unwrap_or_default().target_path())
}
pub(super) async fn registration_flow_status(
+1 -1
View File
@@ -58,7 +58,7 @@ pub(super) async fn template_context_middleware(
response.headers_mut().insert(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_str(&format!(
"default-src 'none'; style-src 'self'; img-src 'self' 'https' data:; script-src \
"default-src 'none'; style-src 'self'; img-src 'self' https: data:; script-src \
'nonce-{csp_nonce}'; child-src {child_src};"
))
.expect("should be able to build CSP header"),
+35 -3
View File
@@ -4,15 +4,16 @@ use axum::{
response::Redirect,
routing::on,
};
use conduwuit_service::oauth::grant::AuthorizationCodeQuery;
use conduwuit_service::oauth::grant::{AuthorizationCodeQuery, Prompt};
use ruma::OwnedUserId;
use url::Url;
use crate::{
WebError,
ROUTE_PREFIX, WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result, TemplateContext,
account::register::RegisterQuery,
components::{Avatar, AvatarType, ClientScopes},
},
response,
@@ -45,7 +46,38 @@ async fn route_authorization_code(
Expect(Query(query)): Expect<Query<AuthorizationCodeQuery>>,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?;
let user_id = if let Some(user) = user.into_session() {
user.user_id
} else {
let next = LoginTarget::AuthorizationCode(query.clone());
let uri = if query
.prompt
.is_some_and(|prompt| matches!(prompt, Prompt::Create))
{
format!(
"{}/account/register/?{}",
ROUTE_PREFIX,
serde_urlencoded::to_string(RegisterQuery {
next: Some(next),
..Default::default()
})
.unwrap()
)
} else {
format!(
"{}/account/login?{}",
ROUTE_PREFIX,
serde_urlencoded::to_string(LoginQuery {
next: Some(next),
..Default::default()
})
.unwrap()
)
};
return response!(Redirect::to(&uri));
};
if form.is_some() {
let redirect_uri = services
+4 -3
View File
@@ -11,9 +11,9 @@ Log in
{%- block content -%}
<div class="panel narrow">
{% match body %}
{% when LoginBody::Unauthenticated { server_name, registration_available } %}
{% when LoginBody::Unauthenticated { server_name, registration_available, next } %}
<h1 class="with-matrix-icon">
{% if has_next %}
{% if next.is_some() %}
Log in to continue
{% else %}
Log in to Matrix
@@ -39,7 +39,8 @@ Log in
</form>
<div class="centered-links">
{% if registration_available %}
<a href="{{ crate::ROUTE_PREFIX }}/account/register/">Sign up</a>
{% let query = next.as_ref().map(serde_urlencoded::to_string).transpose().unwrap().unwrap_or_default() %}
<a href="{{ crate::ROUTE_PREFIX }}/account/register/?{{ query }}">Sign up</a>
{% endif %}
<a href="{{ crate::ROUTE_PREFIX }}/account/password/reset/">Forgot your password?</a>
</div>
+9 -2
View File
@@ -12,7 +12,7 @@ Sign up
{%- block content -%}
<div class="panel narrow">
<h1 class="with-matrix-icon">
{% if false %}
{% if let RegisterBody::UsernamePrompt { next, .. } = body && next.is_some() %}
Sign up to continue
{% else %}
Sign up
@@ -26,7 +26,7 @@ Sign up
<p>
This server is not currently accepting new accounts.
</p>
{% when RegisterBody::UsernamePrompt { allow_federation, untrusted_flow_status, trusted_flow_status, username_error } %}
{% when RegisterBody::UsernamePrompt { allow_federation, untrusted_flow_status, trusted_flow_status, username_error, next } %}
<p>
You're about to register a new Matrix account on <em>{{ server_name }}</em>.
</p>
@@ -42,6 +42,13 @@ Sign up
</p>
<form method="get">
<input type="hidden" name="from_landing" value="true">
{% if let Some(next) = next %}
{# urlencoded roundtrip moment #}
{% let next = serde_urlencoded::to_string(&next).unwrap() %}
{% for (key, value) in form_urlencoded::parse(next.as_bytes()) %}
<input type="hidden" name="{{ key }}" value="{{ value }}">
{% endfor %}
{% endif %}
<p>
<label for="username">Username</label>
<span class="username-input">
+2 -2
View File
@@ -14,7 +14,7 @@ use crate::{ROUTE_PREFIX, WebError, pages::account::device::DevicePath};
pub(crate) mod store;
#[derive(Debug, Deserialize, Serialize)]
#[derive(Default, Debug, Deserialize, Serialize)]
pub(crate) struct LoginQuery {
#[serde(flatten)]
pub next: Option<LoginTarget>,
@@ -22,7 +22,7 @@ pub(crate) struct LoginQuery {
pub reauthenticate: bool,
}
#[derive(Debug, Default, Deserialize, Serialize)]
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(tag = "next", rename_all = "snake_case")]
pub(crate) enum LoginTarget {
AuthorizationCode(AuthorizationCodeQuery),