feat: Implement oauth auth code and refresh token flows

This commit is contained in:
Ginger
2026-04-30 08:54:55 -04:00
parent f269fb5cfc
commit 13917bb5c3
37 changed files with 1057 additions and 157 deletions
Generated
+2
View File
@@ -1088,6 +1088,7 @@ dependencies = [
"serde",
"serde-saphyr",
"serde_json",
"serde_urlencoded",
"sha2 0.11.0",
"termimad",
"tokio",
@@ -1125,6 +1126,7 @@ dependencies = [
"tower-sessions",
"tower-sessions-core",
"tracing",
"url",
"validator",
]
+3
View File
@@ -559,6 +559,9 @@ features = ["std"]
[workspace.dependencies.nonzero_ext]
version = "0.3.0"
[workspace.dependencies.serde_urlencoded]
version = "0.7.1"
#
# Patches
#
+1
View File
@@ -179,6 +179,7 @@ pub(crate) async fn register_route(
&user_id,
&device_id,
&new_token,
None,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
+1
View File
@@ -94,6 +94,7 @@ pub(crate) async fn update_device_route(
&device_id,
&appservice.registration.as_token,
None,
None,
Some(client.to_string()),
)
.await?;
+15 -1
View File
@@ -1,17 +1,31 @@
mod register_client;
mod server_metadata;
mod token;
use axum::{
Json, Router,
extract::State,
routing::method_routing::{get, post},
};
use serde_json::json;
pub(crate) use server_metadata::*;
pub(crate) const BASE_PATH: &str = "/_continuwuity/oauth2/";
const BASE_PATH: &str = "/_continuwuity/oauth2/";
pub(crate) fn router() -> Router<crate::State> {
Router::new().nest(BASE_PATH, oauth_router())
// TODO(unspecced): used by old versions of the matrix-js-sdk
// .route("/.well-known/openid-configuration", get(
// async |State(services): State<crate::State>| {
// Json(authorization_server_metadata(&services).await)
// }
// ))
}
fn oauth_router() -> Router<crate::State> {
Router::new()
.route("/client/register", post(register_client::register_client_route))
// TODO(unspecced): used by old versions of the matrix-js-sdk
.route("/client/keys.json", get(async || Json(json!({"keys": []}))))
.route("/grant/token", post(token::token_route))
}
+10 -6
View File
@@ -1,7 +1,8 @@
use axum::extract::State;
use conduwuit::Result;
use ruma::{api::client::discovery::get_authorization_server_metadata, serde::Raw};
use serde_json::json;
use serde_json::{Value, json};
use service::Services;
use crate::Ruma;
@@ -9,13 +10,19 @@ pub(crate) async fn get_authorization_server_metadata_route(
State(services): State<crate::State>,
_body: Ruma<get_authorization_server_metadata::v1::Request>,
) -> Result<get_authorization_server_metadata::v1::Response> {
let metadata = Raw::new(&authorization_server_metadata(&services).await).unwrap();
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
}
pub(crate) async fn authorization_server_metadata(services: &Services) -> Value {
let endpoint_base = services
.config
.get_client_domain()
.join(super::BASE_PATH)
.unwrap();
let metadata = Raw::new(&json!({
json!({
"authorization_endpoint": endpoint_base.join("grant/authorization_code").unwrap(),
"code_challenge_methods_supported": ["S256"],
"grant_types_supported": ["authorization_code", "refresh_token"],
@@ -27,8 +34,5 @@ pub(crate) async fn get_authorization_server_metadata_route(
"response_types_supported": ["code"],
"revocation_endpoint": endpoint_base.join("client/revoke").unwrap(),
"token_endpoint": endpoint_base.join("grant/token").unwrap(),
}))
.unwrap();
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
})
}
+13
View File
@@ -0,0 +1,13 @@
use axum::{Form, Json, extract::State, response::IntoResponse};
use http::StatusCode;
use service::oauth::grant::TokenRequest;
pub(crate) async fn token_route(
State(services): State<crate::State>,
Form(request): Form<TokenRequest>,
) -> impl IntoResponse {
match services.oauth.issue_token(request).await {
| Ok(response) => Ok(Json(response).into_response()),
| Err(err) => Err((StatusCode::BAD_REQUEST, err.message())),
}
}
+2 -1
View File
@@ -202,7 +202,7 @@ pub(crate) async fn login_route(
if device_exists {
services
.users
.set_token(&user_id, &device_id, &token)
.set_token(&user_id, &device_id, &token, None)
.await?;
} else {
services
@@ -211,6 +211,7 @@ pub(crate) async fn login_route(
&user_id,
&device_id,
&token,
None,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
-1
View File
@@ -69,7 +69,6 @@ pub(crate) async fn sync_events_v5_route(
ClientIp(client_ip): ClientIp,
body: Ruma<sync_events::v5::Request>,
) -> Result<sync_events::v5::Response> {
debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted");
let sender_user = body.identity.sender_user();
let sender_device = body.identity.expect_sender_device()?;
+1 -1
View File
@@ -187,7 +187,7 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::get_rtc_transports)
.ruma_route(&client::room_initial_sync_route)
.ruma_route(&client::get_authorization_server_metadata_route)
.nest(client::oauth::BASE_PATH, client::oauth::router())
.merge(client::oauth::router())
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
.ruma_route(&admin::rooms::ban::ban_room)
+1 -1
View File
@@ -153,7 +153,7 @@ impl CheckAuth for AccessToken {
query: AuthQueryParams,
route: TypeId,
) -> Result<Self::Identity> {
if let Ok((sender_user, sender_device)) = services.users.find_from_token(&output).await {
if let Some((sender_user, sender_device)) = services.users.find_from_token(&output).await {
// Locked users can only use /logout and /logout/all
if services
.users
+16 -10
View File
@@ -61,17 +61,23 @@ pub fn format(ts: SystemTime, str: &str) -> String {
pub fn pretty(d: Duration) -> String {
use Unit::*;
let fmt = |w, f, u| format!("{w}.{f} {u}");
let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u);
let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u);
let fmt = |w, u| {
if w == 1 {
format!("{w} {u}")
} else {
format!("{w} {u}s")
}
};
let gen64 = |w, u| fmt(w, u);
let gen128 = |w, u| gen64(u64::try_from(w).expect("u128 to u64"), u);
match whole_and_frac(d) {
| (Days(whole), frac) => gen64(whole, frac, "days"),
| (Hours(whole), frac) => gen64(whole, frac, "hours"),
| (Mins(whole), frac) => gen64(whole, frac, "minutes"),
| (Secs(whole), frac) => gen64(whole, frac, "seconds"),
| (Millis(whole), frac) => gen128(whole, frac, "milliseconds"),
| (Micros(whole), frac) => gen128(whole, frac, "microseconds"),
| (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"),
| (Days(whole), _) => gen64(whole, "day"),
| (Hours(whole), _) => gen64(whole, "hour"),
| (Mins(whole), _) => gen64(whole, "minute"),
| (Secs(whole), _) => gen64(whole, "second"),
| (Millis(whole), _) => gen128(whole, "millisecond"),
| (Micros(whole), _) => gen128(whole, "microsecond"),
| (Nanos(whole), _) => gen128(whole, "nanosecond"),
}
}
+12
View File
@@ -161,6 +161,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "referencedevents",
..descriptor::RANDOM
},
Descriptor {
name: "refreshtoken_refreshtokeninfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "registrationtoken_info",
..descriptor::RANDOM_SMALL
@@ -375,6 +379,14 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userdevicetxnid_response",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_oauthsessioninfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_tokenexpires",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userfilterid_filter",
..descriptor::RANDOM_SMALL
+1
View File
@@ -119,6 +119,7 @@ recaptcha-verify = { version = "0.2.0", default-features = false }
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
yansi.workspace = true
lettre.workspace = true
serde_urlencoded.workspace = true
[target.'cfg(all(unix, target_os = "linux"))'.dependencies]
sd-notify.workspace = true
+2 -2
View File
@@ -67,7 +67,7 @@ impl crate::Service for Service {
for (id, registration) in appservices {
// During startup, resolve any token collisions in favour of appservices
// by logging out conflicting user devices
if let Ok((user_id, device_id)) = self
if let Some((user_id, device_id)) = self
.services
.users
.find_from_token(&registration.as_token)
@@ -158,7 +158,7 @@ impl Service {
.users
.find_from_token(&registration.as_token)
.await
.is_ok()
.is_some()
{
return Err(err!(Request(InvalidParam(
"Cannot register appservice: The provided token is already in use by a user \
+2 -1
View File
@@ -38,7 +38,7 @@ pub struct ClientMetadata {
}
impl ClientMetadata {
const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
pub(super) const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
pub(super) fn validate(&self) -> Result<(), &'static str> {
let Some(client_domain) = self.client_uri.domain() else {
@@ -137,6 +137,7 @@ pub enum GrantType {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseType {
Code,
}
+150
View File
@@ -0,0 +1,150 @@
use std::{
collections::{BTreeSet, HashSet},
fmt::Debug,
hash::Hash,
mem::discriminant,
};
use regex::Regex;
use ruma::OwnedDeviceId;
use serde::{Deserialize, Serialize};
use url::Url;
use super::client_metadata::ResponseType;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AuthorizationCodeQuery {
pub response_type: ResponseType,
pub client_id: String,
pub redirect_uri: Url,
pub scope: RawScopes,
pub state: String,
#[serde(default)]
pub response_mode: ResponseMode,
pub code_challenge: String,
pub code_challenge_method: CodeChallengeMethod,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseMode {
#[default]
// default for `code` response type, see https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#:~:text=Client%2E-,For,encoding%2E,-See
Query,
Fragment,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[non_exhaustive]
pub enum CodeChallengeMethod {
S256,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)]
pub enum Scope {
Device(OwnedDeviceId),
ClientApi,
}
impl PartialEq for Scope {
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
}
impl Eq for Scope {}
impl Hash for Scope {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { discriminant(self).hash(state); }
}
impl std::fmt::Display for Scope {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let urn = match self {
| Self::ClientApi => "urn:matrix:client:api:*".to_owned(),
| Self::Device(device_id) => format!("urn:matrix:client:device:{device_id}"),
};
f.write_str(&urn)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RawScopes(String);
impl RawScopes {
pub fn to_scopes(&self) -> Result<BTreeSet<Scope>, String> {
let client_api_token_regex =
Regex::new(r"urn:matrix:(client|org.matrix.msc2967.client):api:\*").unwrap();
let device_token_regex = Regex::new(
r"urn:matrix:(client|org.matrix.msc2967.client):device:([a-zA-Z0-9-._~]{5,})",
)
.unwrap();
let mut scopes = BTreeSet::new();
for token in self.0.split(' ') {
let scope_was_new = {
if client_api_token_regex.is_match(token) {
scopes.insert(Scope::ClientApi)
} else if let Some(captures) = device_token_regex.captures(token) {
scopes.insert(Scope::Device(captures.get(2).unwrap().as_str().into()))
} else if token == "openid" {
// TODO(unspecced): Element sets this scope but doesn't use it for anything
true
} else {
return Err(format!("Invalid scope: {token}"));
}
};
if !scope_was_new {
return Err("Scope was specified more than once".to_owned());
}
}
Ok(scopes)
}
}
#[derive(Serialize)]
pub struct AuthorizationCodeResponse {
pub state: String,
pub code: String,
}
#[derive(Deserialize)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum TokenRequest {
AuthorizationCode {
code: String,
redirect_uri: Url,
client_id: String,
code_verifier: String,
},
RefreshToken {
client_id: String,
refresh_token: String,
},
}
impl TokenRequest {
pub fn client_id(&self) -> &str {
match self {
| Self::AuthorizationCode { client_id, .. }
| Self::RefreshToken { client_id, .. } => client_id,
}
}
}
#[derive(Serialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: TokenType,
pub expires_in: u64,
pub refresh_token: String,
pub scope: String,
}
#[derive(Serialize)]
pub enum TokenType {
Bearer,
}
+370 -9
View File
@@ -1,30 +1,93 @@
use std::{
collections::HashMap,
collections::{BTreeMap, BTreeSet, HashMap},
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use base64::Engine;
use conduwuit::{Result, utils::hash::sha256};
use conduwuit::{
Err, Result, err, info,
utils::{self, hash::sha256},
};
use database::{Deserialized, Json, Map};
use ruma::DeviceId;
use itertools::Itertools;
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use url::Url;
use crate::{Dep, config, oauth::client_metadata::ClientMetadata};
use crate::{
Dep, config,
oauth::{
client_metadata::{ApplicationType, ClientMetadata, ResponseType},
grant::{
AuthorizationCodeQuery, AuthorizationCodeResponse, CodeChallengeMethod, ResponseMode,
Scope, TokenRequest, TokenResponse, TokenType,
},
},
users,
};
pub mod client_metadata;
pub mod grant;
pub struct Service {
services: Services,
db: Data,
tickets: Mutex<HashMap<String, HashMap<OAuthTicket, SystemTime>>>,
pending_code_grants: tokio::sync::Mutex<HashMap<String, PendingCodeGrant>>,
}
struct Data {
clientid_clientmetadata: Arc<Map>,
userdeviceid_oauthsessioninfo: Arc<Map>,
refreshtoken_refreshtokeninfo: Arc<Map>,
}
struct Services {
config: Dep<config::Service>,
users: Dep<users::Service>,
}
#[derive(Deserialize, Serialize)]
struct SessionInfo {
client_id: String,
current_refresh_token: String,
scopes: BTreeSet<Scope>,
}
#[derive(Deserialize, Serialize)]
struct RefreshTokenInfo {
client_id: String,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
}
struct PendingCodeGrant {
authorizing_user: OwnedUserId,
requested_scopes: BTreeSet<Scope>,
client_name: Option<String>,
expected_client_id: String,
expected_redirect_uri: Url,
code_challenge: String,
requested_at: SystemTime,
}
impl PendingCodeGrant {
const MAX_AGE: Duration = Duration::from_mins(1);
const RANDOM_CODE_LENGTH: usize = 32;
#[must_use]
pub fn generate_code() -> String { utils::random_string(Self::RANDOM_CODE_LENGTH) }
#[must_use]
pub fn is_valid_for(&self, client_id: &str) -> bool {
let now = SystemTime::now();
self.expected_client_id == client_id
&& now
.duration_since(self.requested_at)
.is_ok_and(|age| age < Self::MAX_AGE)
}
}
/// A time-limited grant for a client to perform some sensitive action.
@@ -49,11 +112,15 @@ impl crate::Service for Service {
Ok(Arc::new(Self {
services: Services {
config: args.depend::<config::Service>("config"),
users: args.depend::<users::Service>("users"),
},
db: Data {
clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(),
userdeviceid_oauthsessioninfo: args.db["userdeviceid_oauthsessioninfo"].clone(),
refreshtoken_refreshtokeninfo: args.db["refreshtoken_refreshtokeninfo"].clone(),
},
tickets: Mutex::default(),
pending_code_grants: tokio::sync::Mutex::default(),
}))
}
@@ -61,6 +128,11 @@ impl crate::Service for Service {
}
impl Service {
const ACCESS_TOKEN_MAX_AGE: Duration = Duration::from_hours(1);
const RANDOM_TOKEN_LENGTH: usize = 32;
fn generate_token() -> String { utils::random_string(Self::RANDOM_TOKEN_LENGTH) }
pub async fn register_client(
&self,
metadata: &ClientMetadata,
@@ -85,7 +157,7 @@ impl Service {
Ok(client_id)
}
pub async fn get_client_registration(&self, client_id: &str) -> Option<ClientMetadata> {
pub async fn get_client_metadata(&self, client_id: &str) -> Option<ClientMetadata> {
self.db
.clientid_clientmetadata
.get(client_id)
@@ -94,15 +166,304 @@ impl Service {
.ok()
}
pub async fn get_client_id_for_device(&self, _device_id: &DeviceId) -> Option<String> {
None // TODO
pub async fn get_client_id_for_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Option<String> {
self.db
.userdeviceid_oauthsessioninfo
.qry(&(user_id, device_id))
.await
.deserialized::<SessionInfo>()
.ok()
.map(|session| session.client_id)
}
pub async fn request_authorization_code(
&self,
authorizing_user: OwnedUserId,
query: AuthorizationCodeQuery,
) -> Result<String, String> {
let Some(client_metadata) = self.get_client_metadata(&query.client_id).await else {
return Err("Invalid client ID".to_owned());
};
if !(client_metadata
.response_types
.contains(&query.response_type)
&& matches!(query.response_type, ResponseType::Code))
{
return Err("Invalid response type".to_owned());
}
if !matches!(query.code_challenge_method, CodeChallengeMethod::S256) {
return Err("Invalid code challenge type".to_owned());
}
{
let mut stripped_uri = query.redirect_uri.clone();
if client_metadata.application_type == ApplicationType::Native
&& query
.redirect_uri
.host_str()
.is_some_and(|host| ClientMetadata::ACCEPTABLE_LOCALHOSTS.contains(&host))
{
// Remove the port from localhost redirect URIs for native applications when
// checking if it's valid
stripped_uri.set_port(None).unwrap();
}
if !client_metadata.redirect_uris.contains(&stripped_uri) {
return Err("Invalid redirect URI".to_owned());
}
}
let requested_scopes = query.scope.to_scopes()?;
let redirect_uri_query_separator = match query.response_mode {
| ResponseMode::Fragment => '#',
| ResponseMode::Query => '?',
};
let code = PendingCodeGrant::generate_code();
info!(
client_id = &query.client_id,
client_name = &client_metadata.client_name,
?requested_scopes,
?authorizing_user,
"Issuing oauth authorization code"
);
let redirect_uri = format!(
"{}{}{}",
query.redirect_uri,
redirect_uri_query_separator,
serde_urlencoded::to_string(AuthorizationCodeResponse {
state: query.state,
code: code.clone(),
})
.unwrap(),
);
let pending_grant = PendingCodeGrant {
authorizing_user,
requested_scopes,
client_name: client_metadata.client_name,
expected_client_id: query.client_id,
expected_redirect_uri: query.redirect_uri,
code_challenge: query.code_challenge,
requested_at: SystemTime::now(),
};
self.pending_code_grants
.lock()
.await
.insert(code, pending_grant);
Ok(redirect_uri)
}
pub async fn issue_token(&self, request: TokenRequest) -> Result<TokenResponse> {
match request {
| TokenRequest::AuthorizationCode {
code,
redirect_uri,
client_id,
code_verifier,
} => {
let mut pending_grants = self.pending_code_grants.lock().await;
let Some(pending_grant) = pending_grants
.remove(&code)
.filter(|grant| grant.is_valid_for(&client_id))
else {
return Err!("Invalid code");
};
if redirect_uri != pending_grant.expected_redirect_uri {
return Err!("Unexpected redirect uri");
}
let expected_code_challenge =
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sha256::hash(&code_verifier));
if expected_code_challenge != pending_grant.code_challenge {
return Err!("Invalid code challenge");
}
self.create_session(
pending_grant.authorizing_user,
pending_grant.requested_scopes,
pending_grant.client_name,
client_id,
)
.await
},
| TokenRequest::RefreshToken { client_id, refresh_token } =>
self.refresh_session(client_id, refresh_token).await,
}
}
async fn create_session(
&self,
authorizing_user: OwnedUserId,
requested_scopes: BTreeSet<Scope>,
client_name: Option<String>,
client_id: String,
) -> Result<TokenResponse> {
let access_token = Self::generate_token();
let refresh_token = Self::generate_token();
let device_id = requested_scopes
.iter()
.find_map(|scope| {
if let Scope::Device(device_id) = scope {
Some(device_id)
} else {
None
}
})
.ok_or_else(|| err!("No device ID scope supplied"))?;
self.services
.users
.create_device(
&authorizing_user,
device_id,
&access_token,
Some(Self::ACCESS_TOKEN_MAX_AGE),
client_name,
None,
)
.await?;
self.db.userdeviceid_oauthsessioninfo.put(
(&authorizing_user, device_id),
Json(SessionInfo {
client_id: client_id.clone(),
current_refresh_token: refresh_token.clone(),
scopes: requested_scopes.clone(),
}),
);
self.db.refreshtoken_refreshtokeninfo.raw_put(
&refresh_token,
Json(RefreshTokenInfo {
client_id: client_id.clone(),
user_id: authorizing_user.clone(),
device_id: device_id.to_owned(),
}),
);
info!(
?client_id,
?authorizing_user,
?device_id,
?requested_scopes,
"Created new oauth session"
);
Ok(TokenResponse {
access_token,
token_type: TokenType::Bearer,
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
scope: requested_scopes.iter().join(" "),
refresh_token,
})
}
async fn refresh_session(
&self,
client_id: String,
refresh_token: String,
) -> Result<TokenResponse> {
let Some(refresh_token_info) = self
.db
.refreshtoken_refreshtokeninfo
.get(&refresh_token)
.await
.deserialized::<RefreshTokenInfo>()
.ok()
else {
return Err!("Invalid refresh token");
};
assert_eq!(&client_id, &refresh_token_info.client_id, "refresh token client id mismatch");
let mut session_info = self
.db
.userdeviceid_oauthsessioninfo
.qry(&(&refresh_token_info.user_id, &refresh_token_info.device_id))
.await
.deserialized::<SessionInfo>()
.expect("session info should exist");
assert_eq!(&client_id, &session_info.client_id, "session info client id mismatch");
let new_access_token = Self::generate_token();
let new_refresh_token = Self::generate_token();
let scope = session_info.scopes.iter().join(" ");
session_info
.current_refresh_token
.clone_from(&new_refresh_token);
self.services
.users
.set_token(
&refresh_token_info.user_id,
&refresh_token_info.device_id,
&new_access_token,
Some(Self::ACCESS_TOKEN_MAX_AGE),
)
.await?;
self.db.userdeviceid_oauthsessioninfo.put(
(&refresh_token_info.user_id, &refresh_token_info.device_id),
Json(session_info),
);
self.db.refreshtoken_refreshtokeninfo.remove(&refresh_token);
drop(refresh_token);
self.db
.refreshtoken_refreshtokeninfo
.raw_put(&new_refresh_token, Json(refresh_token_info));
Ok(TokenResponse {
access_token: new_access_token,
token_type: TokenType::Bearer,
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
scope,
refresh_token: new_refresh_token,
})
}
pub async fn remove_session(&self, user_id: &UserId, device_id: &DeviceId) {
let session_info = self
.db
.userdeviceid_oauthsessioninfo
.qry(&(user_id, device_id))
.await
.deserialized::<SessionInfo>()
.ok();
if let Some(session_info) = session_info {
self.db
.refreshtoken_refreshtokeninfo
.remove(&session_info.current_refresh_token);
self.db
.userdeviceid_oauthsessioninfo
.del(&(user_id, device_id));
info!(?user_id, ?device_id, "Removed OAuth session");
}
}
/// Issue a ticket for `localpart` to perform some action.
pub fn issue_ticket(&self, localpart: String, ticket: OAuthTicket) {
self.tickets
.lock()
.expect("should be able to lock tickets")
.unwrap()
.entry(localpart)
.or_default()
.insert(ticket, SystemTime::now());
@@ -114,7 +475,7 @@ impl Service {
self.tickets
.lock()
.expect("should be able to lock tickets")
.unwrap()
.get_mut(localpart)
.and_then(|tickets| tickets.remove(&ticket))
.is_some_and(|issued| {
+23 -12
View File
@@ -253,17 +253,17 @@ impl Service {
let mut info = assign!(UiaaInfo::new(flows), { params: Some(params), session: Some(session_id.clone()) });
let session_metadata = if let Some(initiator) = initiator {
let is_oauth = OptionFuture::from(
initiator.device_id.map(async |device_id| {
self
.services
.oauth
.get_client_id_for_device(device_id)
.await
})
)
.await
.is_some();
let is_oauth = if let Some(device_id) = initiator.device_id {
self
.services
.oauth
.get_client_id_for_device(initiator.user_id, device_id)
.await
.is_some()
} else {
// Appservices never have oauth sessions
false
};
if is_oauth {
if let Some(oauth_ticket) = initiator.oauth_ticket {
@@ -279,7 +279,18 @@ impl Service {
.unwrap();
info.flows = vec![AuthFlow::new(vec![AuthType::OAuth])];
info.params = Some(to_raw_value(&json!({"url": ticket_url})).unwrap());
info.params = Some(
to_raw_value(&json!({
AuthType::OAuth.as_str(): {
"url": ticket_url,
},
// TODO(compat): This is necessary for older versions of matrix-rust-sdk
"org.matrix.cross_signing_reset": {
"url": ticket_url,
}
}))
.unwrap(),
);
UiaaSessionMetadata::OAuth {
localpart: initiator.user_id.localpart().to_owned(),
+1
View File
@@ -54,6 +54,7 @@ pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) ->
user_id,
&request.device_id,
"",
None,
request.initial_device_display_name.clone(),
None,
)
+64 -5
View File
@@ -1,6 +1,12 @@
pub(super) mod dehydrated_device;
use std::{collections::BTreeMap, mem, net::IpAddr, sync::Arc};
use std::{
collections::BTreeMap,
mem,
net::IpAddr,
sync::Arc,
time::{Duration, SystemTime},
};
use conduwuit::{
Err, Error, Result, Server, debug_error, debug_warn, err, trace,
@@ -26,7 +32,7 @@ use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigE
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::{Dep, account_data, admin, appservice, globals, rooms};
use crate::{Dep, account_data, admin, appservice, globals, oauth, rooms};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSuspension {
@@ -62,6 +68,7 @@ struct Services {
admin: Dep<admin::Service>,
appservice: Dep<appservice::Service>,
globals: Dep<globals::Service>,
oauth: Dep<oauth::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
@@ -75,6 +82,7 @@ struct Data {
logintoken_expiresatuserid: Arc<Map>,
todeviceid_events: Arc<Map>,
token_userdeviceid: Arc<Map>,
userdeviceid_tokenexpires: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
userdeviceid_token: Arc<Map>,
userfilterid_filter: Arc<Map>,
@@ -102,6 +110,7 @@ impl crate::Service for Service {
admin: args.depend::<admin::Service>("admin"),
appservice: args.depend::<appservice::Service>("appservice"),
globals: args.depend::<globals::Service>("globals"),
oauth: args.depend::<oauth::Service>("oauth"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
@@ -131,6 +140,7 @@ impl crate::Service for Service {
userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(),
},
}))
}
@@ -337,8 +347,37 @@ impl Service {
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
/// Find out which user an access token belongs to.
pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
self.db.token_userdeviceid.get(token).await.deserialized()
pub async fn find_from_token(&self, token: &str) -> Option<(OwnedUserId, OwnedDeviceId)> {
let user = self
.db
.token_userdeviceid
.get(token)
.await
.deserialized()
.ok();
// Check if the token has expired
if let Some(user) = &user {
if let Some(expires) = self
.db
.userdeviceid_tokenexpires
.qry(user)
.await
.deserialized::<u64>()
.ok()
.map(Duration::from_secs)
{
let expires_at = SystemTime::UNIX_EPOCH
.checked_add(expires)
.expect("expiry time should not overflow SystemTime");
if SystemTime::now() > expires_at {
return None;
}
}
}
user
}
/// Returns an iterator over all users on this homeserver.
@@ -434,6 +473,7 @@ impl Service {
user_id: &UserId,
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
initial_device_display_name: Option<String>,
client_ip: Option<String>,
) -> Result<()> {
@@ -451,7 +491,8 @@ impl Service {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.put(key, Json(device));
self.set_token(user_id, device_id, token).await
self.set_token(user_id, device_id, token, token_max_age)
.await
}
/// Removes a device from a user.
@@ -467,6 +508,7 @@ impl Service {
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
self.db.userdeviceid_token.del(userdeviceid);
self.db.token_userdeviceid.remove(&old_token);
self.db.userdeviceid_tokenexpires.del(userdeviceid);
}
// Remove todevice events
@@ -480,6 +522,9 @@ impl Service {
// TODO: Remove onetimekeys
// Remove OAuth session information
self.services.oauth.remove_session(user_id, device_id).await;
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.del(userdeviceid);
@@ -535,6 +580,7 @@ impl Service {
user_id: &UserId,
device_id: &DeviceId,
token: &str,
token_max_age: Option<Duration>,
) -> Result<()> {
let key = (user_id, device_id);
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
@@ -561,6 +607,7 @@ impl Service {
// Remove old token
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
self.db.token_userdeviceid.remove(&old_token);
self.db.userdeviceid_tokenexpires.remove(&old_token);
// It will be removed from userdeviceid_token by the insert later
}
@@ -568,6 +615,18 @@ impl Service {
self.db.userdeviceid_token.put_raw(key, token);
self.db.token_userdeviceid.raw_put(token, key);
if let Some(max_age) = token_max_age {
let expires = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("system time should not be before the epoch")
.saturating_add(max_age)
.as_secs();
self.db.userdeviceid_tokenexpires.put(key, expires);
} else {
self.db.userdeviceid_tokenexpires.del(key);
}
Ok(())
}
+2 -1
View File
@@ -43,7 +43,8 @@ validator = { version = "0.20.0", features = ["derive"] }
tower-sec-fetch = { version = "0.1.2", features = ["tracing"] }
tower-sessions = { version = "0.15.0", default-features = false, features = ["axum-core"] }
tower-sessions-core = { version = "0.15.0", features = ["deletion-task"] }
serde_urlencoded = "0.7.1"
serde_urlencoded.workspace = true
url.workspace = true
[build-dependencies]
memory-serve = "2.1.0"
+2 -1
View File
@@ -127,6 +127,7 @@ pub fn build(services: &Services) -> Router<state::State> {
Router::new()
.nest("/account/", account::build())
.merge(debug::build())
.nest("/oauth2/", oauth::build())
.merge(resources::build())
.merge(threepid::build())
.fallback(async || WebError::NotFound),
@@ -145,7 +146,7 @@ pub fn build(services: &Services) -> Router<state::State> {
}))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'; img-src 'self' data:;"),
HeaderValue::from_static("default-src 'self'; img-src 'self' https: data:;"),
))
.layer(SecFetchLayer::new(|policy| {
policy.allow_safe_methods().reject_missing_metadata();
+12 -12
View File
@@ -2,7 +2,7 @@ use std::time::SystemTime;
use axum::{
Router,
extract::{Query, State},
extract::{Query, RawQuery, State},
response::{IntoResponse, Redirect},
routing::{get, on},
};
@@ -15,11 +15,11 @@ use serde::Deserialize;
use tower_sessions::Session;
use crate::{
WebError,
ROUTE_PREFIX, WebError,
extract::{Expect, PostForm},
pages::{GET_POST, Result, components::UserCard},
response,
session::{LoginQuery, User, UserSession},
session::{LoginQuery, LoginTarget, User, UserSession},
template,
};
@@ -32,6 +32,7 @@ pub(crate) fn build() -> Router<crate::State> {
template! {
struct Login use "login.html.j2" {
body: LoginBody,
has_next: bool,
login_error: Option<String>
}
}
@@ -54,11 +55,12 @@ struct LoginForm {
async fn route_login(
State(services): State<crate::State>,
Expect(Query(query)): Expect<Query<LoginQuery>>,
Expect(Query(LoginQuery { next, reauthenticate })): Expect<Query<LoginQuery>>,
session_store: Session,
user: User,
PostForm(form): PostForm<LoginForm>,
) -> Result {
let next = next.unwrap_or_default();
let user_id = user.into_session().map(|session| session.user_id);
let body = match &user_id {
@@ -66,8 +68,8 @@ async fn route_login(
server_name: services.globals.server_name().to_string(),
},
| Some(user_id) => {
if !query.reauthenticate {
return response!(Redirect::to(&query.next.target_path()));
if !reauthenticate {
return response!(Redirect::to(&next.target_path()));
}
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
@@ -76,7 +78,7 @@ async fn route_login(
},
};
let mut template = Login::new(&services, body, None);
let mut template = Login::new(&services, body, next != LoginTarget::Account, None);
if let Some(form) = form {
let login_result = match (user_id, form.identifier) {
@@ -86,8 +88,6 @@ async fn route_login(
},
| (None, Some(identifier)) => {
// The user isn't authenticated, we need to log them in
// Yes, this does parse the email twice (handle_login does it again). I don't
// think this really needs to be optimized.
let identifier = if identifier.parse::<lettre::Address>().is_ok() {
UserIdentifier::Email(EmailUserIdentifier::new(identifier))
} else {
@@ -123,14 +123,14 @@ async fn route_login(
.await
.expect("should be able to serialize user session");
return response!(Redirect::to(&query.next.target_path()));
return response!(Redirect::to(&next.target_path()));
}
response!(template)
}
async fn get_logout(session: Session) -> impl IntoResponse {
async fn get_logout(session: Session, RawQuery(query): RawQuery) -> impl IntoResponse {
let _ = session.remove::<OwnedUserId>(User::KEY).await;
Redirect::to("/_continuwuity/account/")
Redirect::to(&format!("{}/account/login?{}", ROUTE_PREFIX, query.unwrap_or_default()))
}
+3 -1
View File
@@ -61,12 +61,14 @@ async fn get_account(
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
let devices = services
let mut devices: Vec<_> = services
.users
.all_device_ids(&user_id)
.then(async |device_id| DeviceCard::for_device(&services, &user_id, device_id).await)
.collect()
.await;
devices.sort_unstable_by(|a, b| a.last_seen_ts.cmp(&b.last_seen_ts).reverse());
response!(Account::new(&services, user_card, email_requirement, email, devices))
}
+75 -69
View File
@@ -4,36 +4,26 @@ use askama::{Template, filters::HtmlSafe};
use base64::Engine;
use conduwuit_core::{result::FlatOk, utils};
use conduwuit_service::{Services, media::mxc::Mxc, oauth::client_metadata::ClientMetadata};
use ruma::{OwnedDeviceId, OwnedUserId, UserId};
use ruma::{MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId};
pub(super) mod form;
#[derive(Debug)]
pub(super) enum AvatarType<'a> {
pub(super) enum AvatarType {
Initial(char),
Image(&'a str),
Image(String),
}
#[derive(Debug, Template)]
#[template(path = "_components/avatar.html.j2")]
pub(super) struct Avatar<'a> {
pub(super) avatar_type: AvatarType<'a>,
pub(super) struct Avatar {
pub(super) avatar_type: AvatarType,
}
impl HtmlSafe for Avatar<'_> {}
impl HtmlSafe for Avatar {}
#[derive(Debug, Template)]
#[template(path = "_components/user_card.html.j2")]
pub(super) struct UserCard {
pub user_id: OwnedUserId,
pub display_name: Option<String>,
pub avatar_src: Option<String>,
}
impl HtmlSafe for UserCard {}
impl UserCard {
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
impl Avatar {
pub(super) async fn for_local_user(services: &Services, user_id: &UserId) -> Self {
let display_name = services.users.displayname(&user_id).await.ok();
let avatar_src = async {
@@ -56,33 +46,48 @@ impl UserCard {
}
.await;
Self { user_id, display_name, avatar_src }
}
fn avatar(&self) -> Avatar<'_> {
let avatar_type = if let Some(ref avatar_src) = self.avatar_src {
let avatar_type = if let Some(avatar_src) = avatar_src {
AvatarType::Image(avatar_src)
} else if let Some(initial) = self
.display_name
} else if let Some(initial) = display_name
.as_ref()
.and_then(|display_name| display_name.chars().next())
{
AvatarType::Initial(initial)
} else {
AvatarType::Initial(self.user_id.localpart().chars().next().unwrap())
AvatarType::Initial(user_id.localpart().chars().next().unwrap())
};
Avatar { avatar_type }
}
}
#[derive(Debug, Template)]
#[template(path = "_components/user_card.html.j2")]
pub(super) struct UserCard {
pub user_id: OwnedUserId,
pub display_name: Option<String>,
pub avatar: Avatar,
}
impl HtmlSafe for UserCard {}
impl UserCard {
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
let display_name = services.users.displayname(&user_id).await.ok();
let avatar = Avatar::for_local_user(services, &user_id).await;
Self { user_id, display_name, avatar }
}
}
#[derive(Debug, Template)]
#[template(path = "_components/device_card.html.j2")]
pub(super) struct DeviceCard {
pub device_id: OwnedDeviceId,
pub display_name: Option<String>,
pub avatar_src: Option<String>,
pub avatar: Avatar,
pub last_active: String,
pub last_seen_ts: Option<u64>,
pub oauth_metadata: Option<ClientMetadata>,
}
@@ -101,12 +106,15 @@ impl DeviceCard {
.ok();
let oauth_metadata = async {
let client_id = services.oauth.get_client_id_for_device(&device_id).await?;
let client_id = services
.oauth
.get_client_id_for_device(user_id, &device_id)
.await?;
Some(
services
.oauth
.get_client_registration(&client_id)
.get_client_metadata(&client_id)
.await
.expect("client should exist"),
)
@@ -122,53 +130,51 @@ impl DeviceCard {
.and_then(|device| device.display_name.clone())
});
let avatar_src = oauth_metadata
.as_ref()
.and_then(|metadata| metadata.logo_uri.as_ref())
.map(|uri| uri.as_str().to_owned());
let avatar = {
let avatar_src = oauth_metadata
.as_ref()
.and_then(|metadata| metadata.logo_uri.as_ref())
.map(|uri| uri.as_str().to_owned());
let last_active = device
.as_ref()
.and_then(|device| device.last_seen_ts)
.map_or_else(
|| "unknown".to_owned(),
|active| {
active
.to_system_time()
.and_then(|t| SystemTime::now().duration_since(t).ok())
.map_or_else(
|| "now".to_owned(),
|duration| format!("{} ago", utils::time::pretty(duration)),
)
},
);
let avatar_type = if let Some(avatar_src) = avatar_src {
AvatarType::Image(avatar_src)
} else if let Some(initial) =
display_name.as_ref().and_then(|name| name.chars().next())
{
if oauth_metadata.is_some() {
AvatarType::Initial(initial)
} else {
AvatarType::Initial('❖')
}
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
};
let last_seen_ts = device.as_ref().and_then(|device| device.last_seen_ts);
let last_active = last_seen_ts.map_or_else(
|| "unknown".to_owned(),
|last_seen_ts| {
last_seen_ts
.to_system_time()
.and_then(|t| SystemTime::now().duration_since(t).ok())
.map_or_else(
|| "now".to_owned(),
|duration| format!("{} ago", utils::time::pretty(duration)),
)
},
);
Self {
device_id,
display_name,
avatar_src,
avatar,
last_active,
last_seen_ts: last_seen_ts.map(|last_seen_ts| last_seen_ts.as_secs().into()),
oauth_metadata,
}
}
fn avatar(&self) -> Avatar<'_> {
let avatar_type = if let Some(avatar_src) = &self.avatar_src {
AvatarType::Image(avatar_src.as_str())
} else if let Some(initial) = self
.display_name
.as_ref()
.and_then(|name| name.chars().next())
{
if self.oauth_metadata.is_some() {
AvatarType::Initial(initial)
} else {
AvatarType::Initial('❖')
}
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
}
}
+1
View File
@@ -6,6 +6,7 @@ pub(super) mod account;
mod components;
pub(super) mod debug;
pub(super) mod index;
pub(super) mod oauth;
pub(super) mod resources;
pub(super) mod threepid;
+113
View File
@@ -0,0 +1,113 @@
use std::collections::BTreeSet;
use axum::{
Router,
extract::{Query, State},
response::{IntoResponse, Redirect},
routing::on,
};
use conduwuit_service::{
oauth::{
client_metadata::{self, ClientMetadata},
grant::{AuthorizationCodeQuery, Scope},
},
rooms::user,
};
use ruma::{OwnedDeviceId, OwnedUserId};
use serde::Deserialize;
use url::Url;
use crate::{
WebError,
extract::{Expect, PostForm},
pages::{
GET_POST, Result,
components::{Avatar, AvatarType},
},
response,
session::{LoginQuery, LoginTarget, User},
template,
};
pub(crate) fn build() -> Router<crate::State> {
Router::new().route("/authorization_code", on(GET_POST, route_authorization_code))
}
template! {
struct Grant use "grant.html.j2" {
logout_query: String,
user_id: OwnedUserId,
user_avatar: Avatar,
client_uri: Url,
client_name: String,
client_avatar: Avatar,
policy_uri: Option<Url>,
tos_uri: Option<Url>,
scopes: BTreeSet<Scope>
}
}
async fn route_authorization_code(
State(services): State<crate::State>,
user: User,
Expect(Query(query)): Expect<Query<AuthorizationCodeQuery>>,
PostForm(form): PostForm<()>,
) -> Result {
let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?;
if form.is_some() {
let redirect_uri = services
.oauth
.request_authorization_code(user_id, query)
.await
.map_err(WebError::BadRequest)?;
return response!(Redirect::to(&redirect_uri));
}
let Some(client) = services.oauth.get_client_metadata(&query.client_id).await else {
return Err(WebError::BadRequest("Invalid client ID".to_owned()));
};
let scopes = query.scope.to_scopes().map_err(WebError::BadRequest)?;
let client_name = if let Some(name) = &client.client_name {
name
} else {
"Unknown application"
}
.to_owned();
let client_avatar = {
let avatar_type = if let Some(logo) = &client.logo_uri {
AvatarType::Image(logo.to_string())
} else if let Some(name) = &client.client_name
&& let Some(char) = name.chars().next()
{
AvatarType::Initial(char)
} else {
AvatarType::Initial('?')
};
Avatar { avatar_type }
};
let user_avatar = Avatar::for_local_user(&services, &user_id).await;
response!(Grant::new(
&services,
serde_urlencoded::to_string(LoginQuery {
next: Some(LoginTarget::AuthorizationCode(query)),
reauthenticate: false,
})
.unwrap(),
user_id,
user_avatar,
client.client_uri.clone(),
client_name,
client_avatar,
client.policy_uri.clone(),
client.tos_uri.clone(),
scopes,
))
}
+10
View File
@@ -0,0 +1,10 @@
use axum::Router;
mod grant;
pub(crate) fn build() -> Router<crate::State> {
#[allow(clippy::wildcard_imports)]
use self::*;
Router::new().nest("/grant/", grant::build())
}
+6 -1
View File
@@ -123,8 +123,9 @@ small.error {
.panel {
--preferred-width: 12rem + 40dvw;
--maximum-width: 48rem;
--minimum-width: 32rem;
width: min(clamp(24rem, var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
width: min(clamp(var(--minimum-width), var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
border-radius: var(--border-radius-lg);
background-color: var(--panel-bg);
padding-inline: 1.5rem;
@@ -184,6 +185,10 @@ a, a:visited {
color: oklch(from var(--c1) var(--name-lightness) c h);
}
code {
color: oklch(from var(--secondary) var(--name-lightness) c h);
}
input, button, a.button {
display: inline-block;
padding: 0.5em;
+1 -1
View File
@@ -17,7 +17,7 @@
background-color: var(--avatar-color);
}
.green-avatar {
.red-avatar {
--avatar-color: var(--c1);
}
+22
View File
@@ -0,0 +1,22 @@
.avatars {
justify-content: center;
display: flex;
flex-direction: row;
.separator {
align-self: center;
margin-inline: 1em;
color: var(--secondary);
font-size: x-large;
font-weight: bold;
user-select: none;
}
}
.identity {
margin-block: 1em;
color: var(--secondary);
font-size: small;
font-style: italic;
text-align: center;
}
@@ -1,19 +1,25 @@
<div class="card">
{{ avatar() }}
{{ avatar }}
<div class="info">
<p class="name">
{% if let Some(display_name) = display_name %}
{{ display_name }}
{% if let Some(metadata) = oauth_metadata %}
<a href="{{ metadata.client_uri }}">{{ display_name }}</a>
{% else %}
{{ display_name }}
{% endif %}
{% else %}
Unknown device
{% endif %}
&nbsp;<span class="id">{{ device_id }}</span>
<span class="id">
&bullet;&nbsp;{{ device_id }}
{% if oauth_metadata.is_none() %}
(legacy)
{% endif %}
</span>
</p>
<p>
Last active: {{ last_active }}
{% if let Some(metadata) = oauth_metadata %}
&nbsp;&bullet;&nbsp;<a href="{{ metadata.client_uri }}">Client information</a>
{% endif %}
</p>
</div>
</div>
@@ -1,5 +1,5 @@
<div class="card green-avatar">
{{ avatar() }}
<div class="card red-avatar">
{{ avatar }}
<div class="info">
{% if let Some(display_name) = display_name %}
<p class="name">{{ display_name }}</p>
+64
View File
@@ -0,0 +1,64 @@
{% extends "_layout.html.j2" %}
{%- block head -%}
<link rel="stylesheet" href="{{ crate::ROUTE_PREFIX }}/resources/grant.css">
{%- endblock -%}
{%- block title -%}
Authorize client
{%- endblock -%}
{%- block content -%}
<div class="panel narrow">
<h1>Authorize {{ client_name }}</h1>
<div class="avatars">
<div class="red-avatar">
{{ user_avatar }}
</div>
<div class="separator" aria-hidden>
</div>
{{ client_avatar }}
</div>
<div class="identity">
Signed in as <code>{{ user_id }}</code>. <a href="{{ crate::ROUTE_PREFIX }}/account/logout?{{ logout_query }}">Switch accounts</a>
</div>
<p>
<b>{{ client_name }}</b> (<a href="{{ client_uri }}">{{ client_uri.domain().unwrap() }}</a>) would like
your permission to:
<ul>
{% for scope in scopes %}
{% match scope %}
{% when Scope::ClientApi %}
<li>Interact with Matrix on your behalf</li>
{% when Scope::Device(_) %}
<li>Connect to your Matrix account</li>
{% endmatch %}
{% endfor %}
</ul>
</p>
{% match (&policy_uri, &tos_uri) %}
{% when (Some(policy_uri), Some(tos_uri)) %}
<p>
{{ client_name }}'s <a href="{{ policy_uri }}">policies</a>
and <a href="{{ tos_uri }}">terms of service</a> apply.
</p>
{% when (Some(policy_uri), None) %}
<p>
{{ client_name }}'s <a href="{{ policy_uri }}">policies</a> apply.
</p>
{% when (None, Some(tos_uri)) %}
<p>
{{ client_name }}'s <a href="{{ tos_uri }}">terms of service</a> apply.
</p>
{% when (None, None) %}
<p>
Make sure you trust {{ client_name }} with access to your data.
</p>
{% endmatch %}
<form method="post">
<button type="submit">Continue</button>
</form>
</div>
{%- endblock -%}
+5 -1
View File
@@ -13,7 +13,11 @@ Log in
{% match body %}
{% when LoginBody::Unauthenticated { server_name } %}
<h1 class="with-matrix-icon">
Log in to Matrix
{% if has_next %}
Log in to continue
{% else %}
Log in to Matrix
{% endif %}
<a href="https://matrix.org" target="_blank" noreferer>
<img class="matrix-icon" alt="Matrix logo" aria-ignore src="{{ crate::ROUTE_PREFIX }}/resources/matrix-icon.svg">
</a>
+37 -12
View File
@@ -1,8 +1,14 @@
use std::time::{Duration, SystemTime};
use std::{
borrow::Cow,
collections::HashMap,
mem::discriminant,
time::{Duration, SystemTime},
};
use axum::{extract::FromRequestParts, http::request::Parts};
use conduwuit_service::oauth::grant::AuthorizationCodeQuery;
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use tower_sessions::Session;
use crate::{ROUTE_PREFIX, WebError};
@@ -12,7 +18,7 @@ pub(crate) mod store;
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct LoginQuery {
#[serde(flatten)]
pub next: LoginTarget,
pub next: Option<LoginTarget>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub reauthenticate: bool,
}
@@ -20,6 +26,7 @@ pub(crate) struct LoginQuery {
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(tag = "next", rename_all = "snake_case")]
pub(crate) enum LoginTarget {
AuthorizationCode(AuthorizationCodeQuery),
#[default]
Account,
ChangePassword,
@@ -28,14 +35,23 @@ pub(crate) enum LoginTarget {
Deactivate,
}
impl PartialEq for LoginTarget {
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
}
impl LoginTarget {
pub(crate) fn target_path(&self) -> String {
let path = match self {
| Self::Account => "account/",
| Self::ChangePassword => "account/password/change",
| Self::ChangeEmail => "account/email/change/",
| Self::CrossSigningReset => "account/cross_signing_reset",
| Self::Deactivate => "account/deactivate",
let path: Cow<'_, str> = match self {
| Self::AuthorizationCode(code) => format!(
"oauth2/grant/authorization_code?{}",
serde_urlencoded::to_string(code).unwrap()
)
.into(),
| Self::Account => "account/".into(),
| Self::ChangePassword => "account/password/change".into(),
| Self::ChangeEmail => "account/email/change/".into(),
| Self::CrossSigningReset => "account/cross_signing_reset".into(),
| Self::Deactivate => "account/deactivate".into(),
};
format!("{ROUTE_PREFIX}/{path}")
@@ -80,7 +96,10 @@ impl User {
if let Some(session) = self.0 {
Ok(session.user_id)
} else {
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false }))
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: false,
}))
}
}
@@ -91,10 +110,16 @@ impl User {
if session.is_recent() {
Ok(session.user_id)
} else {
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: true }))
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: true,
}))
}
} else {
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false }))
Err(WebError::LoginRequired(LoginQuery {
next: Some(or_else),
reauthenticate: false,
}))
}
}
}