mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
feat: Implement support for prompt=create in the authorization code flow
This commit is contained in:
Generated
+1
@@ -1113,6 +1113,7 @@ dependencies = [
|
||||
"conduwuit_core",
|
||||
"conduwuit_database",
|
||||
"conduwuit_service",
|
||||
"form_urlencoded",
|
||||
"futures",
|
||||
"lettre",
|
||||
"memory-serve",
|
||||
|
||||
@@ -18,6 +18,8 @@ pub struct AuthorizationCodeQuery {
|
||||
pub response_mode: ResponseMode,
|
||||
pub code_challenge: String,
|
||||
pub code_challenge_method: CodeChallengeMethod,
|
||||
#[serde(default)]
|
||||
pub prompt: Option<Prompt>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
@@ -36,6 +38,15 @@ pub enum CodeChallengeMethod {
|
||||
S256,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum Prompt {
|
||||
Create,
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)]
|
||||
pub enum Scope {
|
||||
Device(OwnedDeviceId),
|
||||
|
||||
@@ -245,6 +245,7 @@ impl Service {
|
||||
|
||||
let suffix = &self.services.config.new_user_displayname_suffix;
|
||||
if !suffix.is_empty() {
|
||||
displayname.push(' ');
|
||||
displayname.push_str(suffix);
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ serde_urlencoded.workspace = true
|
||||
url.workspace = true
|
||||
recaptcha-verify = { version = "0.2.0", default-features = false }
|
||||
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
|
||||
form_urlencoded = "1.2.2"
|
||||
|
||||
[build-dependencies]
|
||||
memory-serve = "2.1.0"
|
||||
|
||||
@@ -36,7 +36,6 @@ pub(crate) fn build() -> Router<crate::State> {
|
||||
template! {
|
||||
struct Login use "login.html.j2" {
|
||||
body: LoginBody,
|
||||
has_next: bool,
|
||||
login_error: Option<String>
|
||||
}
|
||||
}
|
||||
@@ -46,6 +45,7 @@ enum LoginBody {
|
||||
Unauthenticated {
|
||||
server_name: String,
|
||||
registration_available: bool,
|
||||
next: Option<LoginTarget>,
|
||||
},
|
||||
Authenticated {
|
||||
user_card: UserCard,
|
||||
@@ -66,7 +66,6 @@ async fn route_login(
|
||||
user: User,
|
||||
PostForm(form): PostForm<LoginForm>,
|
||||
) -> Result {
|
||||
let next = next.unwrap_or_default();
|
||||
let user_id = user.into_session().map(|session| session.user_id);
|
||||
|
||||
let body = match &user_id {
|
||||
@@ -74,20 +73,19 @@ async fn route_login(
|
||||
let (trusted_flow_status, untrusted_flow_status) =
|
||||
registration_flow_status(&services).await;
|
||||
|
||||
let registration_available =
|
||||
matches!(trusted_flow_status, TrustedFlowStatus::Available)
|
||||
|| matches!(untrusted_flow_status, UntrustedFlowStatus::Available { .. });
|
||||
|
||||
LoginBody::Unauthenticated {
|
||||
server_name: services.globals.server_name().to_string(),
|
||||
registration_available: matches!(
|
||||
trusted_flow_status,
|
||||
TrustedFlowStatus::Available
|
||||
) || matches!(
|
||||
untrusted_flow_status,
|
||||
UntrustedFlowStatus::Available { .. }
|
||||
),
|
||||
registration_available,
|
||||
next: next.clone(),
|
||||
}
|
||||
},
|
||||
| Some(user_id) => {
|
||||
if !reauthenticate {
|
||||
return response!(Redirect::to(&next.target_path()));
|
||||
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
|
||||
}
|
||||
|
||||
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
|
||||
@@ -96,7 +94,7 @@ async fn route_login(
|
||||
},
|
||||
};
|
||||
|
||||
let mut template = Login::new(context, body, next != LoginTarget::Account, None);
|
||||
let mut template = Login::new(context, body, None);
|
||||
|
||||
if let Some(form) = form {
|
||||
let login_result = match (user_id, form.identifier) {
|
||||
@@ -144,7 +142,7 @@ async fn route_login(
|
||||
.await
|
||||
.expect("should be able to serialize user session");
|
||||
|
||||
return response!(Redirect::to(&next.target_path()));
|
||||
return response!(Redirect::to(&next.unwrap_or_default().target_path()));
|
||||
}
|
||||
|
||||
response!(template)
|
||||
|
||||
@@ -49,6 +49,7 @@ enum RegisterBody {
|
||||
trusted_flow_status: TrustedFlowStatus,
|
||||
untrusted_flow_status: UntrustedFlowStatus,
|
||||
username_error: Option<String>,
|
||||
next: Option<LoginTarget>,
|
||||
},
|
||||
DetailsPrompt {
|
||||
username: Option<String>,
|
||||
@@ -73,18 +74,20 @@ pub(super) enum UntrustedFlowStatus {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RegistrationQuery {
|
||||
username: Option<String>,
|
||||
token: Option<String>,
|
||||
flow: Option<RequestedRegistrationFlow>,
|
||||
#[derive(Default, Deserialize, Serialize)]
|
||||
pub(crate) struct RegisterQuery {
|
||||
pub username: Option<String>,
|
||||
pub token: Option<String>,
|
||||
pub flow: Option<RequestedRegistrationFlow>,
|
||||
#[serde(default)]
|
||||
from_landing: bool,
|
||||
pub from_landing: bool,
|
||||
#[serde(flatten)]
|
||||
pub next: Option<LoginTarget>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Deserialize)]
|
||||
#[derive(Clone, Copy, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum RequestedRegistrationFlow {
|
||||
pub(crate) enum RequestedRegistrationFlow {
|
||||
Untrusted,
|
||||
Trusted,
|
||||
}
|
||||
@@ -118,13 +121,14 @@ struct CompletedRegistration {
|
||||
user_id: OwnedUserId,
|
||||
password_hash: HashedPassword,
|
||||
registration_token: Option<ValidToken>,
|
||||
next: Option<LoginTarget>,
|
||||
}
|
||||
|
||||
async fn route_register(
|
||||
State(services): State<crate::State>,
|
||||
Extension(context): Extension<TemplateContext>,
|
||||
session_store: Session,
|
||||
Expect(Query(query)): Expect<Query<RegistrationQuery>>,
|
||||
Expect(Query(query)): Expect<Query<RegisterQuery>>,
|
||||
PostForm(form): PostForm<RegistrationForm>,
|
||||
) -> Result {
|
||||
if session_store
|
||||
@@ -140,7 +144,15 @@ async fn route_register(
|
||||
let validation_errors = if let Some(form) = form {
|
||||
match form.validate() {
|
||||
| Ok(()) => {
|
||||
match begin_registration(&services, context.clone(), session_store, form).await? {
|
||||
match begin_registration(
|
||||
&services,
|
||||
context.clone(),
|
||||
session_store,
|
||||
form,
|
||||
query.next.clone(),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
| Ok(response) => return Ok(response),
|
||||
| Err(err) => err,
|
||||
}
|
||||
@@ -188,6 +200,7 @@ async fn route_register(
|
||||
trusted_flow_status,
|
||||
untrusted_flow_status,
|
||||
username_error: Some(err.message()),
|
||||
next: query.next,
|
||||
}
|
||||
));
|
||||
}
|
||||
@@ -227,6 +240,7 @@ async fn route_register(
|
||||
trusted_flow_status,
|
||||
untrusted_flow_status,
|
||||
username_error: None,
|
||||
next: query.next,
|
||||
},
|
||||
}
|
||||
};
|
||||
@@ -276,9 +290,10 @@ async fn get_register_email_validate(
|
||||
|
||||
let email = session.consume();
|
||||
|
||||
complete_registration(&services, session_store, completed_registration, Some(email)).await;
|
||||
|
||||
response!(Redirect::to(&LoginTarget::Account.target_path()))
|
||||
response!(
|
||||
complete_registration(&services, session_store, completed_registration, Some(email))
|
||||
.await
|
||||
)
|
||||
}
|
||||
|
||||
async fn begin_registration(
|
||||
@@ -286,6 +301,7 @@ async fn begin_registration(
|
||||
context: TemplateContext,
|
||||
session_store: Session,
|
||||
form: RegistrationForm,
|
||||
next: Option<LoginTarget>,
|
||||
) -> Result<Result<Response, ValidationErrors>> {
|
||||
let open_registration = services
|
||||
.config
|
||||
@@ -383,6 +399,7 @@ async fn begin_registration(
|
||||
user_id,
|
||||
password_hash,
|
||||
registration_token,
|
||||
next,
|
||||
};
|
||||
|
||||
// Check if we need to send an email
|
||||
@@ -466,9 +483,9 @@ async fn begin_registration(
|
||||
))
|
||||
} else {
|
||||
// If email isn't required we can immediately complete registration
|
||||
complete_registration(services, session_store, completed_registration, None).await;
|
||||
|
||||
Ok(response!(Redirect::to(&LoginTarget::Account.target_path())))
|
||||
Ok(response!(
|
||||
complete_registration(services, session_store, completed_registration, None).await
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -479,9 +496,10 @@ async fn complete_registration(
|
||||
user_id,
|
||||
password_hash,
|
||||
registration_token,
|
||||
next,
|
||||
}: CompletedRegistration,
|
||||
email: Option<Address>,
|
||||
) {
|
||||
) -> Redirect {
|
||||
services
|
||||
.users
|
||||
.create_local_account(&user_id, password_hash, email)
|
||||
@@ -499,6 +517,8 @@ async fn complete_registration(
|
||||
.insert(User::KEY, user_session)
|
||||
.await
|
||||
.expect("should be able to serialize user session");
|
||||
|
||||
Redirect::to(&next.unwrap_or_default().target_path())
|
||||
}
|
||||
|
||||
pub(super) async fn registration_flow_status(
|
||||
|
||||
@@ -58,7 +58,7 @@ pub(super) async fn template_context_middleware(
|
||||
response.headers_mut().insert(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
HeaderValue::from_str(&format!(
|
||||
"default-src 'none'; style-src 'self'; img-src 'self' 'https' data:; script-src \
|
||||
"default-src 'none'; style-src 'self'; img-src 'self' https: data:; script-src \
|
||||
'nonce-{csp_nonce}'; child-src {child_src};"
|
||||
))
|
||||
.expect("should be able to build CSP header"),
|
||||
|
||||
@@ -4,15 +4,16 @@ use axum::{
|
||||
response::Redirect,
|
||||
routing::on,
|
||||
};
|
||||
use conduwuit_service::oauth::grant::AuthorizationCodeQuery;
|
||||
use conduwuit_service::oauth::grant::{AuthorizationCodeQuery, Prompt};
|
||||
use ruma::OwnedUserId;
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
WebError,
|
||||
ROUTE_PREFIX, WebError,
|
||||
extract::{Expect, PostForm},
|
||||
pages::{
|
||||
GET_POST, Result, TemplateContext,
|
||||
account::register::RegisterQuery,
|
||||
components::{Avatar, AvatarType, ClientScopes},
|
||||
},
|
||||
response,
|
||||
@@ -45,7 +46,38 @@ async fn route_authorization_code(
|
||||
Expect(Query(query)): Expect<Query<AuthorizationCodeQuery>>,
|
||||
PostForm(form): PostForm<()>,
|
||||
) -> Result {
|
||||
let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?;
|
||||
let user_id = if let Some(user) = user.into_session() {
|
||||
user.user_id
|
||||
} else {
|
||||
let next = LoginTarget::AuthorizationCode(query.clone());
|
||||
|
||||
let uri = if query
|
||||
.prompt
|
||||
.is_some_and(|prompt| matches!(prompt, Prompt::Create))
|
||||
{
|
||||
format!(
|
||||
"{}/account/register/?{}",
|
||||
ROUTE_PREFIX,
|
||||
serde_urlencoded::to_string(RegisterQuery {
|
||||
next: Some(next),
|
||||
..Default::default()
|
||||
})
|
||||
.unwrap()
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}/account/login?{}",
|
||||
ROUTE_PREFIX,
|
||||
serde_urlencoded::to_string(LoginQuery {
|
||||
next: Some(next),
|
||||
..Default::default()
|
||||
})
|
||||
.unwrap()
|
||||
)
|
||||
};
|
||||
|
||||
return response!(Redirect::to(&uri));
|
||||
};
|
||||
|
||||
if form.is_some() {
|
||||
let redirect_uri = services
|
||||
|
||||
@@ -11,9 +11,9 @@ Log in
|
||||
{%- block content -%}
|
||||
<div class="panel narrow">
|
||||
{% match body %}
|
||||
{% when LoginBody::Unauthenticated { server_name, registration_available } %}
|
||||
{% when LoginBody::Unauthenticated { server_name, registration_available, next } %}
|
||||
<h1 class="with-matrix-icon">
|
||||
{% if has_next %}
|
||||
{% if next.is_some() %}
|
||||
Log in to continue
|
||||
{% else %}
|
||||
Log in to Matrix
|
||||
@@ -39,7 +39,8 @@ Log in
|
||||
</form>
|
||||
<div class="centered-links">
|
||||
{% if registration_available %}
|
||||
<a href="{{ crate::ROUTE_PREFIX }}/account/register/">Sign up</a>
|
||||
{% let query = next.as_ref().map(serde_urlencoded::to_string).transpose().unwrap().unwrap_or_default() %}
|
||||
<a href="{{ crate::ROUTE_PREFIX }}/account/register/?{{ query }}">Sign up</a>
|
||||
{% endif %}
|
||||
<a href="{{ crate::ROUTE_PREFIX }}/account/password/reset/">Forgot your password?</a>
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,7 @@ Sign up
|
||||
{%- block content -%}
|
||||
<div class="panel narrow">
|
||||
<h1 class="with-matrix-icon">
|
||||
{% if false %}
|
||||
{% if let RegisterBody::UsernamePrompt { next, .. } = body && next.is_some() %}
|
||||
Sign up to continue
|
||||
{% else %}
|
||||
Sign up
|
||||
@@ -26,7 +26,7 @@ Sign up
|
||||
<p>
|
||||
This server is not currently accepting new accounts.
|
||||
</p>
|
||||
{% when RegisterBody::UsernamePrompt { allow_federation, untrusted_flow_status, trusted_flow_status, username_error } %}
|
||||
{% when RegisterBody::UsernamePrompt { allow_federation, untrusted_flow_status, trusted_flow_status, username_error, next } %}
|
||||
<p>
|
||||
You're about to register a new Matrix account on <em>{{ server_name }}</em>.
|
||||
</p>
|
||||
@@ -42,6 +42,13 @@ Sign up
|
||||
</p>
|
||||
<form method="get">
|
||||
<input type="hidden" name="from_landing" value="true">
|
||||
{% if let Some(next) = next %}
|
||||
{# urlencoded roundtrip moment #}
|
||||
{% let next = serde_urlencoded::to_string(&next).unwrap() %}
|
||||
{% for (key, value) in form_urlencoded::parse(next.as_bytes()) %}
|
||||
<input type="hidden" name="{{ key }}" value="{{ value }}">
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
<p>
|
||||
<label for="username">Username</label>
|
||||
<span class="username-input">
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::{ROUTE_PREFIX, WebError, pages::account::device::DevicePath};
|
||||
|
||||
pub(crate) mod store;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[derive(Default, Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct LoginQuery {
|
||||
#[serde(flatten)]
|
||||
pub next: Option<LoginTarget>,
|
||||
@@ -22,7 +22,7 @@ pub(crate) struct LoginQuery {
|
||||
pub reauthenticate: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
#[serde(tag = "next", rename_all = "snake_case")]
|
||||
pub(crate) enum LoginTarget {
|
||||
AuthorizationCode(AuthorizationCodeQuery),
|
||||
|
||||
Reference in New Issue
Block a user