feat: Implement oauth auth code and refresh token flows

This commit is contained in:
Ginger
2026-04-30 08:54:55 -04:00
parent f269fb5cfc
commit 13917bb5c3
37 changed files with 1057 additions and 157 deletions
+2 -1
View File
@@ -43,7 +43,8 @@ validator = { version = "0.20.0", features = ["derive"] }
tower-sec-fetch = { version = "0.1.2", features = ["tracing"] }
tower-sessions = { version = "0.15.0", default-features = false, features = ["axum-core"] }
tower-sessions-core = { version = "0.15.0", features = ["deletion-task"] }
serde_urlencoded = "0.7.1"
serde_urlencoded.workspace = true
url.workspace = true
[build-dependencies]
memory-serve = "2.1.0"
+2 -1
View File
@@ -127,6 +127,7 @@ pub fn build(services: &Services) -> Router<state::State> {
Router::new()
.nest("/account/", account::build())
.merge(debug::build())
.nest("/oauth2/", oauth::build())
.merge(resources::build())
.merge(threepid::build())
.fallback(async || WebError::NotFound),
@@ -145,7 +146,7 @@ pub fn build(services: &Services) -> Router<state::State> {
}))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'; img-src 'self' data:;"),
HeaderValue::from_static("default-src 'self'; img-src 'self' https: data:;"),
))
.layer(SecFetchLayer::new(|policy| {
policy.allow_safe_methods().reject_missing_metadata();
+12 -12
View File
@@ -2,7 +2,7 @@ use std::time::SystemTime;
use axum::{
Router,
extract::{Query, State},
extract::{Query, RawQuery, State},
response::{IntoResponse, Redirect},
routing::{get, on},
};
@@ -15,11 +15,11 @@ use serde::Deserialize;
use tower_sessions::Session;
use crate::{
WebError,
ROUTE_PREFIX, WebError,
extract::{Expect, PostForm},
pages::{GET_POST, Result, components::UserCard},
response,
session::{LoginQuery, User, UserSession},
session::{LoginQuery, LoginTarget, User, UserSession},
template,
};
@@ -32,6 +32,7 @@ pub(crate) fn build() -> Router<crate::State> {
template! {
struct Login use "login.html.j2" {
body: LoginBody,
has_next: bool,
login_error: Option<String>
}
}
@@ -54,11 +55,12 @@ struct LoginForm {
async fn route_login(
State(services): State<crate::State>,
Expect(Query(query)): Expect<Query<LoginQuery>>,
Expect(Query(LoginQuery { next, reauthenticate })): Expect<Query<LoginQuery>>,
session_store: Session,
user: User,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let next = next.unwrap_or_default();
let user_id = user.into_session().map(|session| session.user_id);
let body = match &user_id {
@@ -66,8 +68,8 @@ async fn route_login(
server_name: services.globals.server_name().to_string(),
},
| Some(user_id) => {
if !query.reauthenticate {
return response!(Redirect::to(&query.next.target_path()));
if !reauthenticate {
return response!(Redirect::to(&next.target_path()));
}
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
@@ -76,7 +78,7 @@ async fn route_login(
},
};
let mut template = Login::new(&services, body, None);
let mut template = Login::new(&services, body, next != LoginTarget::Account, None);
if let Some(form) = form {
let login_result = match (user_id, form.identifier) {
@@ -86,8 +88,6 @@ async fn route_login(
},
| (None, Some(identifier)) => {
// The user isn't authenticated, we need to log them in
// Yes, this does parse the email twice (handle_login does it again). I don't
// think this really needs to be optimized.
let identifier = if identifier.parse::<lettre::Address>().is_ok() {
UserIdentifier::Email(EmailUserIdentifier::new(identifier))
} else {
@@ -123,14 +123,14 @@ async fn route_login(
.await
.expect("should be able to serialize user session");
return response!(Redirect::to(&query.next.target_path()));
return response!(Redirect::to(&next.target_path()));
}
response!(template)
}
async fn get_logout(session: Session) -> impl IntoResponse {
async fn get_logout(session: Session, RawQuery(query): RawQuery) -> impl IntoResponse {
let _ = session.remove::<OwnedUserId>(User::KEY).await;
Redirect::to("/_continuwuity/account/")
Redirect::to(&format!("{}/account/login?{}", ROUTE_PREFIX, query.unwrap_or_default()))
}
+3 -1
View File
@@ -61,12 +61,14 @@ async fn get_account(
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let devices = services
let mut devices: Vec<_> = services
.users
.all_device_ids(&user_id)
.then(async |device_id| DeviceCard::for_device(&services, &user_id, device_id).await)
.collect()
.await;
devices.sort_unstable_by(|a, b| a.last_seen_ts.cmp(&b.last_seen_ts).reverse());
response!(Account::new(&services, user_card, email_requirement, email, devices))
}
+75 -69
View File
@@ -4,36 +4,26 @@ use askama::{Template, filters::HtmlSafe};
use base64::Engine;
use conduwuit_core::{result::FlatOk, utils};
use conduwuit_service::{Services, media::mxc::Mxc, oauth::client_metadata::ClientMetadata};
use ruma::{OwnedDeviceId, OwnedUserId, UserId};
use ruma::{MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId};
pub(super) mod form;
#[derive(Debug)]
pub(super) enum AvatarType<'a> {
pub(super) enum AvatarType {
Initial(char),
Image(&'a str),
Image(String),
}
#[derive(Debug, Template)]
#[template(path = "_components/avatar.html.j2")]
pub(super) struct Avatar<'a> {
pub(super) avatar_type: AvatarType<'a>,
pub(super) struct Avatar {
pub(super) avatar_type: AvatarType,
}
impl HtmlSafe for Avatar<'_> {}
impl HtmlSafe for Avatar {}
#[derive(Debug, Template)]
#[template(path = "_components/user_card.html.j2")]
pub(super) struct UserCard {
pub user_id: OwnedUserId,
pub display_name: Option<String>,
pub avatar_src: Option<String>,
}
impl HtmlSafe for UserCard {}
impl UserCard {
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
impl Avatar {
pub(super) async fn for_local_user(services: &Services, user_id: &UserId) -> Self {
let display_name = services.users.displayname(&user_id).await.ok();
let avatar_src = async {
@@ -56,33 +46,48 @@ impl UserCard {
}
.await;
Self { user_id, display_name, avatar_src }
}
fn avatar(&self) -> Avatar<'_> {
let avatar_type = if let Some(ref avatar_src) = self.avatar_src {
let avatar_type = if let Some(avatar_src) = avatar_src {
AvatarType::Image(avatar_src)
} else if let Some(initial) = self
.display_name
} else if let Some(initial) = display_name
.as_ref()
.and_then(|display_name| display_name.chars().next())
{
AvatarType::Initial(initial)
} else {
AvatarType::Initial(self.user_id.localpart().chars().next().unwrap())
AvatarType::Initial(user_id.localpart().chars().next().unwrap())
};
Avatar { avatar_type }
}
}
#[derive(Debug, Template)]
#[template(path = "_components/user_card.html.j2")]
pub(super) struct UserCard {
pub user_id: OwnedUserId,
pub display_name: Option<String>,
pub avatar: Avatar,
}
impl HtmlSafe for UserCard {}
impl UserCard {
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
let display_name = services.users.displayname(&user_id).await.ok();
let avatar = Avatar::for_local_user(services, &user_id).await;
Self { user_id, display_name, avatar }
}
}
#[derive(Debug, Template)]
#[template(path = "_components/device_card.html.j2")]
pub(super) struct DeviceCard {
pub device_id: OwnedDeviceId,
pub display_name: Option<String>,
pub avatar_src: Option<String>,
pub avatar: Avatar,
pub last_active: String,
pub last_seen_ts: Option<u64>,
pub oauth_metadata: Option<ClientMetadata>,
}
@@ -101,12 +106,15 @@ impl DeviceCard {
.ok();
let oauth_metadata = async {
let client_id = services.oauth.get_client_id_for_device(&device_id).await?;
let client_id = services
.oauth
.get_client_id_for_device(user_id, &device_id)
.await?;
Some(
services
.oauth
.get_client_registration(&client_id)
.get_client_metadata(&client_id)
.await
.expect("client should exist"),
)
@@ -122,53 +130,51 @@ impl DeviceCard {
.and_then(|device| device.display_name.clone())
});
let avatar_src = oauth_metadata
.as_ref()
.and_then(|metadata| metadata.logo_uri.as_ref())
.map(|uri| uri.as_str().to_owned());
let avatar = {
let avatar_src = oauth_metadata
.as_ref()
.and_then(|metadata| metadata.logo_uri.as_ref())
.map(|uri| uri.as_str().to_owned());
let last_active = device
.as_ref()
.and_then(|device| device.last_seen_ts)
.map_or_else(
|| "unknown".to_owned(),
|active| {
active
.to_system_time()
.and_then(|t| SystemTime::now().duration_since(t).ok())
.map_or_else(
|| "now".to_owned(),
|duration| format!("{} ago", utils::time::pretty(duration)),
)
},
);
let avatar_type = if let Some(avatar_src) = avatar_src {
AvatarType::Image(avatar_src)
} else if let Some(initial) =
display_name.as_ref().and_then(|name| name.chars().next())
{
if oauth_metadata.is_some() {
AvatarType::Initial(initial)
} else {
AvatarType::Initial('❖')
}
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
};
let last_seen_ts = device.as_ref().and_then(|device| device.last_seen_ts);
let last_active = last_seen_ts.map_or_else(
|| "unknown".to_owned(),
|last_seen_ts| {
last_seen_ts
.to_system_time()
.and_then(|t| SystemTime::now().duration_since(t).ok())
.map_or_else(
|| "now".to_owned(),
|duration| format!("{} ago", utils::time::pretty(duration)),
)
},
);
Self {
device_id,
display_name,
avatar_src,
avatar,
last_active,
last_seen_ts: last_seen_ts.map(|last_seen_ts| last_seen_ts.as_secs().into()),
oauth_metadata,
}
}
fn avatar(&self) -> Avatar<'_> {
let avatar_type = if let Some(avatar_src) = &self.avatar_src {
AvatarType::Image(avatar_src.as_str())
} else if let Some(initial) = self
.display_name
.as_ref()
.and_then(|name| name.chars().next())
{
if self.oauth_metadata.is_some() {
AvatarType::Initial(initial)
} else {
AvatarType::Initial('❖')
}
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
}
}
+1
View File
@@ -6,6 +6,7 @@ pub(super) mod account;
mod components;
pub(super) mod debug;
pub(super) mod index;
pub(super) mod oauth;
pub(super) mod resources;
pub(super) mod threepid;
+113
View File
@@ -0,0 +1,113 @@
use std::collections::BTreeSet;
use axum::{
Router,
extract::{Query, State},
response::{IntoResponse, Redirect},
routing::on,
};
use conduwuit_service::{
oauth::{
client_metadata::{self, ClientMetadata},
grant::{AuthorizationCodeQuery, Scope},
},
rooms::user,
};
use ruma::{OwnedDeviceId, OwnedUserId};
use serde::Deserialize;
use url::Url;
use crate::{
WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result,
components::{Avatar, AvatarType},
},
response,
session::{LoginQuery, LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/authorization_code", on(GET_POST, route_authorization_code))
}
template! {
struct Grant use "grant.html.j2" {
logout_query: String,
user_id: OwnedUserId,
user_avatar: Avatar,
client_uri: Url,
client_name: String,
client_avatar: Avatar,
policy_uri: Option<Url>,
tos_uri: Option<Url>,
scopes: BTreeSet<Scope>
}
}
async fn route_authorization_code(
State(services): State<crate::State>,
user: User,
Expect(Query(query)): Expect<Query<AuthorizationCodeQuery>>,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?;
if form.is_some() {
let redirect_uri = services
.oauth
.request_authorization_code(user_id, query)
.await
.map_err(WebError::BadRequest)?;
return response!(Redirect::to(&redirect_uri));
}
let Some(client) = services.oauth.get_client_metadata(&query.client_id).await else {
return Err(WebError::BadRequest("Invalid client ID".to_owned()));
};
let scopes = query.scope.to_scopes().map_err(WebError::BadRequest)?;
let client_name = if let Some(name) = &client.client_name {
name
} else {
"Unknown application"
}
.to_owned();
let client_avatar = {
let avatar_type = if let Some(logo) = &client.logo_uri {
AvatarType::Image(logo.to_string())
} else if let Some(name) = &client.client_name
&& let Some(char) = name.chars().next()
{
AvatarType::Initial(char)
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
};
let user_avatar = Avatar::for_local_user(&services, &user_id).await;
response!(Grant::new(
&services,
serde_urlencoded::to_string(LoginQuery {
next: Some(LoginTarget::AuthorizationCode(query)),
reauthenticate: false,
})
.unwrap(),
user_id,
user_avatar,
client.client_uri.clone(),
client_name,
client_avatar,
client.policy_uri.clone(),
client.tos_uri.clone(),
scopes,
))
}
+10
View File
@@ -0,0 +1,10 @@
use axum::Router;
mod grant;
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new().nest("/grant/", grant::build())
}
+6 -1
View File
@@ -123,8 +123,9 @@ small.error {
.panel {
--preferred-width: 12rem + 40dvw;
--maximum-width: 48rem;
--minimum-width: 32rem;
width: min(clamp(24rem, var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
width: min(clamp(var(--minimum-width), var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
border-radius: var(--border-radius-lg);
background-color: var(--panel-bg);
padding-inline: 1.5rem;
@@ -184,6 +185,10 @@ a, a:visited {
color: oklch(from var(--c1) var(--name-lightness) c h);
}
code {
color: oklch(from var(--secondary) var(--name-lightness) c h);
}
input, button, a.button {
display: inline-block;
padding: 0.5em;
+1 -1
View File
@@ -17,7 +17,7 @@
background-color: var(--avatar-color);
}
.green-avatar {
.red-avatar {
--avatar-color: var(--c1);
}
+22
View File
@@ -0,0 +1,22 @@
.avatars {
justify-content: center;
display: flex;
flex-direction: row;
.separator {
align-self: center;
margin-inline: 1em;
color: var(--secondary);
font-size: x-large;
font-weight: bold;
user-select: none;
}
}
.identity {
margin-block: 1em;
color: var(--secondary);
font-size: small;
font-style: italic;
text-align: center;
}
@@ -1,19 +1,25 @@
<div class="card">
{{ avatar() }}
{{ avatar }}
<div class="info">
<p class="name">
{% if let Some(display_name) = display_name %}
{{ display_name }}
{% if let Some(metadata) = oauth_metadata %}
<a href="{{ metadata.client_uri }}">{{ display_name }}</a>
{% else %}
{{ display_name }}
{% endif %}
{% else %}
Unknown device
{% endif %}
&nbsp;<span class="id">{{ device_id }}</span>
<span class="id">
&bullet;&nbsp;{{ device_id }}
{% if oauth_metadata.is_none() %}
(legacy)
{% endif %}
</span>
</p>
<p>
Last active: {{ last_active }}
{% if let Some(metadata) = oauth_metadata %}
&nbsp;&bullet;&nbsp;<a href="{{ metadata.client_uri }}">Client information</a>
{% endif %}
</p>
</div>
</div>
@@ -1,5 +1,5 @@
<div class="card green-avatar">
{{ avatar() }}
<div class="card red-avatar">
{{ avatar }}
<div class="info">
{% if let Some(display_name) = display_name %}
<p class="name">{{ display_name }}</p>
+64
View File
@@ -0,0 +1,64 @@
{% extends "_layout.html.j2" %}
{%- block head -%}
<link rel="stylesheet" href="{{ crate::ROUTE_PREFIX }}/resources/grant.css">
{%- endblock -%}
{%- block title -%}
Authorize client
{%- endblock -%}
{%- block content -%}
<div class="panel narrow">
<h1>Authorize {{ client_name }}</h1>
<div class="avatars">
<div class="red-avatar">
{{ user_avatar }}
</div>
<div class="separator" aria-hidden>
</div>
{{ client_avatar }}
</div>
<div class="identity">
Signed in as <code>{{ user_id }}</code>. <a href="{{ crate::ROUTE_PREFIX }}/account/logout?{{ logout_query }}">Switch accounts</a>
</div>
<p>
<b>{{ client_name }}</b> (<a href="{{ client_uri }}">{{ client_uri.domain().unwrap() }}</a>) would like
your permission to:
<ul>
{% for scope in scopes %}
{% match scope %}
{% when Scope::ClientApi %}
<li>Interact with Matrix on your behalf</li>
{% when Scope::Device(_) %}
<li>Connect to your Matrix account</li>
{% endmatch %}
{% endfor %}
</ul>
</p>
{% match (&policy_uri, &tos_uri) %}
{% when (Some(policy_uri), Some(tos_uri)) %}
<p>
{{ client_name }}'s <a href="{{ policy_uri }}">policies</a>
and <a href="{{ tos_uri }}">terms of service</a> apply.
</p>
{% when (Some(policy_uri), None) %}
<p>
{{ client_name }}'s <a href="{{ policy_uri }}">policies</a> apply.
</p>
{% when (None, Some(tos_uri)) %}
<p>
{{ client_name }}'s <a href="{{ tos_uri }}">terms of service</a> apply.
</p>
{% when (None, None) %}
<p>
Make sure you trust {{ client_name }} with access to your data.
</p>
{% endmatch %}
<form method="post">
<button type="submit">Continue</button>
</form>
</div>
{%- endblock -%}
+5 -1
View File
@@ -13,7 +13,11 @@ Log in
{% match body %}
{% when LoginBody::Unauthenticated { server_name } %}
<h1 class="with-matrix-icon">
Log in to Matrix
{% if has_next %}
Log in to continue
{% else %}
Log in to Matrix
{% endif %}
<a href="https://matrix.org" target="_blank" noreferer>
<img class="matrix-icon" alt="Matrix logo" aria-ignore src="{{ crate::ROUTE_PREFIX }}/resources/matrix-icon.svg">
</a>
+37 -12
View File
@@ -1,8 +1,14 @@
use std::time::{Duration, SystemTime};
use std::{
borrow::Cow,
collections::HashMap,
mem::discriminant,
time::{Duration, SystemTime},
};
use axum::{extract::FromRequestParts, http::request::Parts};
use conduwuit_service::oauth::grant::AuthorizationCodeQuery;
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use tower_sessions::Session;
use crate::{ROUTE_PREFIX, WebError};
@@ -12,7 +18,7 @@ pub(crate) mod store;
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct LoginQuery {
#[serde(flatten)]
pub next: LoginTarget,
pub next: Option<LoginTarget>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub reauthenticate: bool,
}
@@ -20,6 +26,7 @@ pub(crate) struct LoginQuery {
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(tag = "next", rename_all = "snake_case")]
pub(crate) enum LoginTarget {
AuthorizationCode(AuthorizationCodeQuery),
#[default]
Account,
ChangePassword,
@@ -28,14 +35,23 @@ pub(crate) enum LoginTarget {
Deactivate,
}
impl PartialEq for LoginTarget {
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
}
impl LoginTarget {
pub(crate) fn target_path(&self) -> String {
let path = match self {
| Self::Account => "account/",
| Self::ChangePassword => "account/password/change",
| Self::ChangeEmail => "account/email/change/",
| Self::CrossSigningReset => "account/cross_signing_reset",
| Self::Deactivate => "account/deactivate",
let path: Cow<'_, str> = match self {
| Self::AuthorizationCode(code) => format!(
"oauth2/grant/authorization_code?{}",
serde_urlencoded::to_string(code).unwrap()
)
.into(),
| Self::Account => "account/".into(),
| Self::ChangePassword => "account/password/change".into(),
| Self::ChangeEmail => "account/email/change/".into(),
| Self::CrossSigningReset => "account/cross_signing_reset".into(),
| Self::Deactivate => "account/deactivate".into(),
};
format!("{ROUTE_PREFIX}/{path}")
@@ -80,7 +96,10 @@ impl User {
if let Some(session) = self.0 {
Ok(session.user_id)
} else {
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false }))
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: false,
}))
}
}
@@ -91,10 +110,16 @@ impl User {
if session.is_recent() {
Ok(session.user_id)
} else {
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: true }))
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: true,
}))
}
} else {
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false }))
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: false,
}))
}
}
}