mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
feat: Implement oauth auth code and refresh token flows
This commit is contained in:
Generated
+2
@@ -1088,6 +1088,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde-saphyr",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sha2 0.11.0",
|
||||
"termimad",
|
||||
"tokio",
|
||||
@@ -1125,6 +1126,7 @@ dependencies = [
|
||||
"tower-sessions",
|
||||
"tower-sessions-core",
|
||||
"tracing",
|
||||
"url",
|
||||
"validator",
|
||||
]
|
||||
|
||||
|
||||
@@ -559,6 +559,9 @@ features = ["std"]
|
||||
[workspace.dependencies.nonzero_ext]
|
||||
version = "0.3.0"
|
||||
|
||||
[workspace.dependencies.serde_urlencoded]
|
||||
version = "0.7.1"
|
||||
|
||||
#
|
||||
# Patches
|
||||
#
|
||||
|
||||
@@ -179,6 +179,7 @@ pub(crate) async fn register_route(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&new_token,
|
||||
None,
|
||||
body.initial_device_display_name.clone(),
|
||||
Some(client.to_string()),
|
||||
)
|
||||
|
||||
@@ -94,6 +94,7 @@ pub(crate) async fn update_device_route(
|
||||
&device_id,
|
||||
&appservice.registration.as_token,
|
||||
None,
|
||||
None,
|
||||
Some(client.to_string()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1,17 +1,31 @@
|
||||
mod register_client;
|
||||
mod server_metadata;
|
||||
mod token;
|
||||
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
routing::method_routing::{get, post},
|
||||
};
|
||||
use serde_json::json;
|
||||
pub(crate) use server_metadata::*;
|
||||
|
||||
pub(crate) const BASE_PATH: &str = "/_continuwuity/oauth2/";
|
||||
const BASE_PATH: &str = "/_continuwuity/oauth2/";
|
||||
|
||||
pub(crate) fn router() -> Router<crate::State> {
|
||||
Router::new().nest(BASE_PATH, oauth_router())
|
||||
// TODO(unspecced): used by old versions of the matrix-js-sdk
|
||||
// .route("/.well-known/openid-configuration", get(
|
||||
// async |State(services): State<crate::State>| {
|
||||
// Json(authorization_server_metadata(&services).await)
|
||||
// }
|
||||
// ))
|
||||
}
|
||||
|
||||
fn oauth_router() -> Router<crate::State> {
|
||||
Router::new()
|
||||
.route("/client/register", post(register_client::register_client_route))
|
||||
// TODO(unspecced): used by old versions of the matrix-js-sdk
|
||||
.route("/client/keys.json", get(async || Json(json!({"keys": []}))))
|
||||
.route("/grant/token", post(token::token_route))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use axum::extract::State;
|
||||
use conduwuit::Result;
|
||||
use ruma::{api::client::discovery::get_authorization_server_metadata, serde::Raw};
|
||||
use serde_json::json;
|
||||
use serde_json::{Value, json};
|
||||
use service::Services;
|
||||
|
||||
use crate::Ruma;
|
||||
|
||||
@@ -9,13 +10,19 @@ pub(crate) async fn get_authorization_server_metadata_route(
|
||||
State(services): State<crate::State>,
|
||||
_body: Ruma<get_authorization_server_metadata::v1::Request>,
|
||||
) -> Result<get_authorization_server_metadata::v1::Response> {
|
||||
let metadata = Raw::new(&authorization_server_metadata(&services).await).unwrap();
|
||||
|
||||
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
|
||||
}
|
||||
|
||||
pub(crate) async fn authorization_server_metadata(services: &Services) -> Value {
|
||||
let endpoint_base = services
|
||||
.config
|
||||
.get_client_domain()
|
||||
.join(super::BASE_PATH)
|
||||
.unwrap();
|
||||
|
||||
let metadata = Raw::new(&json!({
|
||||
json!({
|
||||
"authorization_endpoint": endpoint_base.join("grant/authorization_code").unwrap(),
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
@@ -27,8 +34,5 @@ pub(crate) async fn get_authorization_server_metadata_route(
|
||||
"response_types_supported": ["code"],
|
||||
"revocation_endpoint": endpoint_base.join("client/revoke").unwrap(),
|
||||
"token_endpoint": endpoint_base.join("grant/token").unwrap(),
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
use axum::{Form, Json, extract::State, response::IntoResponse};
|
||||
use http::StatusCode;
|
||||
use service::oauth::grant::TokenRequest;
|
||||
|
||||
pub(crate) async fn token_route(
|
||||
State(services): State<crate::State>,
|
||||
Form(request): Form<TokenRequest>,
|
||||
) -> impl IntoResponse {
|
||||
match services.oauth.issue_token(request).await {
|
||||
| Ok(response) => Ok(Json(response).into_response()),
|
||||
| Err(err) => Err((StatusCode::BAD_REQUEST, err.message())),
|
||||
}
|
||||
}
|
||||
@@ -202,7 +202,7 @@ pub(crate) async fn login_route(
|
||||
if device_exists {
|
||||
services
|
||||
.users
|
||||
.set_token(&user_id, &device_id, &token)
|
||||
.set_token(&user_id, &device_id, &token, None)
|
||||
.await?;
|
||||
} else {
|
||||
services
|
||||
@@ -211,6 +211,7 @@ pub(crate) async fn login_route(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
None,
|
||||
body.initial_device_display_name.clone(),
|
||||
Some(client.to_string()),
|
||||
)
|
||||
|
||||
@@ -69,7 +69,6 @@ pub(crate) async fn sync_events_v5_route(
|
||||
ClientIp(client_ip): ClientIp,
|
||||
body: Ruma<sync_events::v5::Request>,
|
||||
) -> Result<sync_events::v5::Response> {
|
||||
debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted");
|
||||
let sender_user = body.identity.sender_user();
|
||||
let sender_device = body.identity.expect_sender_device()?;
|
||||
|
||||
|
||||
+1
-1
@@ -187,7 +187,7 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
||||
.ruma_route(&client::get_rtc_transports)
|
||||
.ruma_route(&client::room_initial_sync_route)
|
||||
.ruma_route(&client::get_authorization_server_metadata_route)
|
||||
.nest(client::oauth::BASE_PATH, client::oauth::router())
|
||||
.merge(client::oauth::router())
|
||||
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
|
||||
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
|
||||
.ruma_route(&admin::rooms::ban::ban_room)
|
||||
|
||||
@@ -153,7 +153,7 @@ impl CheckAuth for AccessToken {
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> Result<Self::Identity> {
|
||||
if let Ok((sender_user, sender_device)) = services.users.find_from_token(&output).await {
|
||||
if let Some((sender_user, sender_device)) = services.users.find_from_token(&output).await {
|
||||
// Locked users can only use /logout and /logout/all
|
||||
if services
|
||||
.users
|
||||
|
||||
+16
-10
@@ -61,17 +61,23 @@ pub fn format(ts: SystemTime, str: &str) -> String {
|
||||
pub fn pretty(d: Duration) -> String {
|
||||
use Unit::*;
|
||||
|
||||
let fmt = |w, f, u| format!("{w}.{f} {u}");
|
||||
let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u);
|
||||
let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u);
|
||||
let fmt = |w, u| {
|
||||
if w == 1 {
|
||||
format!("{w} {u}")
|
||||
} else {
|
||||
format!("{w} {u}s")
|
||||
}
|
||||
};
|
||||
let gen64 = |w, u| fmt(w, u);
|
||||
let gen128 = |w, u| gen64(u64::try_from(w).expect("u128 to u64"), u);
|
||||
match whole_and_frac(d) {
|
||||
| (Days(whole), frac) => gen64(whole, frac, "days"),
|
||||
| (Hours(whole), frac) => gen64(whole, frac, "hours"),
|
||||
| (Mins(whole), frac) => gen64(whole, frac, "minutes"),
|
||||
| (Secs(whole), frac) => gen64(whole, frac, "seconds"),
|
||||
| (Millis(whole), frac) => gen128(whole, frac, "milliseconds"),
|
||||
| (Micros(whole), frac) => gen128(whole, frac, "microseconds"),
|
||||
| (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"),
|
||||
| (Days(whole), _) => gen64(whole, "day"),
|
||||
| (Hours(whole), _) => gen64(whole, "hour"),
|
||||
| (Mins(whole), _) => gen64(whole, "minute"),
|
||||
| (Secs(whole), _) => gen64(whole, "second"),
|
||||
| (Millis(whole), _) => gen128(whole, "millisecond"),
|
||||
| (Micros(whole), _) => gen128(whole, "microsecond"),
|
||||
| (Nanos(whole), _) => gen128(whole, "nanosecond"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -161,6 +161,10 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "referencedevents",
|
||||
..descriptor::RANDOM
|
||||
},
|
||||
Descriptor {
|
||||
name: "refreshtoken_refreshtokeninfo",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "registrationtoken_info",
|
||||
..descriptor::RANDOM_SMALL
|
||||
@@ -375,6 +379,14 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "userdevicetxnid_response",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userdeviceid_oauthsessioninfo",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userdeviceid_tokenexpires",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userfilterid_filter",
|
||||
..descriptor::RANDOM_SMALL
|
||||
|
||||
@@ -119,6 +119,7 @@ recaptcha-verify = { version = "0.2.0", default-features = false }
|
||||
reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated
|
||||
yansi.workspace = true
|
||||
lettre.workspace = true
|
||||
serde_urlencoded.workspace = true
|
||||
|
||||
[target.'cfg(all(unix, target_os = "linux"))'.dependencies]
|
||||
sd-notify.workspace = true
|
||||
|
||||
@@ -67,7 +67,7 @@ impl crate::Service for Service {
|
||||
for (id, registration) in appservices {
|
||||
// During startup, resolve any token collisions in favour of appservices
|
||||
// by logging out conflicting user devices
|
||||
if let Ok((user_id, device_id)) = self
|
||||
if let Some((user_id, device_id)) = self
|
||||
.services
|
||||
.users
|
||||
.find_from_token(®istration.as_token)
|
||||
@@ -158,7 +158,7 @@ impl Service {
|
||||
.users
|
||||
.find_from_token(®istration.as_token)
|
||||
.await
|
||||
.is_ok()
|
||||
.is_some()
|
||||
{
|
||||
return Err(err!(Request(InvalidParam(
|
||||
"Cannot register appservice: The provided token is already in use by a user \
|
||||
|
||||
@@ -38,7 +38,7 @@ pub struct ClientMetadata {
|
||||
}
|
||||
|
||||
impl ClientMetadata {
|
||||
const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
|
||||
pub(super) const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
|
||||
|
||||
pub(super) fn validate(&self) -> Result<(), &'static str> {
|
||||
let Some(client_domain) = self.client_uri.domain() else {
|
||||
@@ -137,6 +137,7 @@ pub enum GrantType {
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum ResponseType {
|
||||
Code,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
use std::{
|
||||
collections::{BTreeSet, HashSet},
|
||||
fmt::Debug,
|
||||
hash::Hash,
|
||||
mem::discriminant,
|
||||
};
|
||||
|
||||
use regex::Regex;
|
||||
use ruma::OwnedDeviceId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use super::client_metadata::ResponseType;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AuthorizationCodeQuery {
|
||||
pub response_type: ResponseType,
|
||||
pub client_id: String,
|
||||
pub redirect_uri: Url,
|
||||
pub scope: RawScopes,
|
||||
pub state: String,
|
||||
#[serde(default)]
|
||||
pub response_mode: ResponseMode,
|
||||
pub code_challenge: String,
|
||||
pub code_challenge_method: CodeChallengeMethod,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[non_exhaustive]
|
||||
pub enum ResponseMode {
|
||||
#[default]
|
||||
// default for `code` response type, see https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#:~:text=Client%2E-,For,encoding%2E,-See
|
||||
Query,
|
||||
Fragment,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum CodeChallengeMethod {
|
||||
S256,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)]
|
||||
pub enum Scope {
|
||||
Device(OwnedDeviceId),
|
||||
ClientApi,
|
||||
}
|
||||
|
||||
impl PartialEq for Scope {
|
||||
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
|
||||
}
|
||||
|
||||
impl Eq for Scope {}
|
||||
|
||||
impl Hash for Scope {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { discriminant(self).hash(state); }
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Scope {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let urn = match self {
|
||||
| Self::ClientApi => "urn:matrix:client:api:*".to_owned(),
|
||||
| Self::Device(device_id) => format!("urn:matrix:client:device:{device_id}"),
|
||||
};
|
||||
|
||||
f.write_str(&urn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RawScopes(String);
|
||||
|
||||
impl RawScopes {
|
||||
pub fn to_scopes(&self) -> Result<BTreeSet<Scope>, String> {
|
||||
let client_api_token_regex =
|
||||
Regex::new(r"urn:matrix:(client|org.matrix.msc2967.client):api:\*").unwrap();
|
||||
let device_token_regex = Regex::new(
|
||||
r"urn:matrix:(client|org.matrix.msc2967.client):device:([a-zA-Z0-9-._~]{5,})",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut scopes = BTreeSet::new();
|
||||
|
||||
for token in self.0.split(' ') {
|
||||
let scope_was_new = {
|
||||
if client_api_token_regex.is_match(token) {
|
||||
scopes.insert(Scope::ClientApi)
|
||||
} else if let Some(captures) = device_token_regex.captures(token) {
|
||||
scopes.insert(Scope::Device(captures.get(2).unwrap().as_str().into()))
|
||||
} else if token == "openid" {
|
||||
// TODO(unspecced): Element sets this scope but doesn't use it for anything
|
||||
true
|
||||
} else {
|
||||
return Err(format!("Invalid scope: {token}"));
|
||||
}
|
||||
};
|
||||
|
||||
if !scope_was_new {
|
||||
return Err("Scope was specified more than once".to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scopes)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AuthorizationCodeResponse {
|
||||
pub state: String,
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "grant_type", rename_all = "snake_case")]
|
||||
pub enum TokenRequest {
|
||||
AuthorizationCode {
|
||||
code: String,
|
||||
redirect_uri: Url,
|
||||
client_id: String,
|
||||
code_verifier: String,
|
||||
},
|
||||
RefreshToken {
|
||||
client_id: String,
|
||||
refresh_token: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl TokenRequest {
|
||||
pub fn client_id(&self) -> &str {
|
||||
match self {
|
||||
| Self::AuthorizationCode { client_id, .. }
|
||||
| Self::RefreshToken { client_id, .. } => client_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: TokenType,
|
||||
pub expires_in: u64,
|
||||
pub refresh_token: String,
|
||||
pub scope: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub enum TokenType {
|
||||
Bearer,
|
||||
}
|
||||
+370
-9
@@ -1,30 +1,93 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
sync::{Arc, Mutex},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use base64::Engine;
|
||||
use conduwuit::{Result, utils::hash::sha256};
|
||||
use conduwuit::{
|
||||
Err, Result, err, info,
|
||||
utils::{self, hash::sha256},
|
||||
};
|
||||
use database::{Deserialized, Json, Map};
|
||||
use ruma::DeviceId;
|
||||
use itertools::Itertools;
|
||||
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use crate::{Dep, config, oauth::client_metadata::ClientMetadata};
|
||||
use crate::{
|
||||
Dep, config,
|
||||
oauth::{
|
||||
client_metadata::{ApplicationType, ClientMetadata, ResponseType},
|
||||
grant::{
|
||||
AuthorizationCodeQuery, AuthorizationCodeResponse, CodeChallengeMethod, ResponseMode,
|
||||
Scope, TokenRequest, TokenResponse, TokenType,
|
||||
},
|
||||
},
|
||||
users,
|
||||
};
|
||||
|
||||
pub mod client_metadata;
|
||||
pub mod grant;
|
||||
|
||||
pub struct Service {
|
||||
services: Services,
|
||||
db: Data,
|
||||
tickets: Mutex<HashMap<String, HashMap<OAuthTicket, SystemTime>>>,
|
||||
pending_code_grants: tokio::sync::Mutex<HashMap<String, PendingCodeGrant>>,
|
||||
}
|
||||
|
||||
struct Data {
|
||||
clientid_clientmetadata: Arc<Map>,
|
||||
userdeviceid_oauthsessioninfo: Arc<Map>,
|
||||
refreshtoken_refreshtokeninfo: Arc<Map>,
|
||||
}
|
||||
|
||||
struct Services {
|
||||
config: Dep<config::Service>,
|
||||
users: Dep<users::Service>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct SessionInfo {
|
||||
client_id: String,
|
||||
current_refresh_token: String,
|
||||
scopes: BTreeSet<Scope>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct RefreshTokenInfo {
|
||||
client_id: String,
|
||||
user_id: OwnedUserId,
|
||||
device_id: OwnedDeviceId,
|
||||
}
|
||||
|
||||
struct PendingCodeGrant {
|
||||
authorizing_user: OwnedUserId,
|
||||
requested_scopes: BTreeSet<Scope>,
|
||||
client_name: Option<String>,
|
||||
expected_client_id: String,
|
||||
expected_redirect_uri: Url,
|
||||
code_challenge: String,
|
||||
requested_at: SystemTime,
|
||||
}
|
||||
|
||||
impl PendingCodeGrant {
|
||||
const MAX_AGE: Duration = Duration::from_mins(1);
|
||||
const RANDOM_CODE_LENGTH: usize = 32;
|
||||
|
||||
#[must_use]
|
||||
pub fn generate_code() -> String { utils::random_string(Self::RANDOM_CODE_LENGTH) }
|
||||
|
||||
#[must_use]
|
||||
pub fn is_valid_for(&self, client_id: &str) -> bool {
|
||||
let now = SystemTime::now();
|
||||
|
||||
self.expected_client_id == client_id
|
||||
&& now
|
||||
.duration_since(self.requested_at)
|
||||
.is_ok_and(|age| age < Self::MAX_AGE)
|
||||
}
|
||||
}
|
||||
|
||||
/// A time-limited grant for a client to perform some sensitive action.
|
||||
@@ -49,11 +112,15 @@ impl crate::Service for Service {
|
||||
Ok(Arc::new(Self {
|
||||
services: Services {
|
||||
config: args.depend::<config::Service>("config"),
|
||||
users: args.depend::<users::Service>("users"),
|
||||
},
|
||||
db: Data {
|
||||
clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(),
|
||||
userdeviceid_oauthsessioninfo: args.db["userdeviceid_oauthsessioninfo"].clone(),
|
||||
refreshtoken_refreshtokeninfo: args.db["refreshtoken_refreshtokeninfo"].clone(),
|
||||
},
|
||||
tickets: Mutex::default(),
|
||||
pending_code_grants: tokio::sync::Mutex::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -61,6 +128,11 @@ impl crate::Service for Service {
|
||||
}
|
||||
|
||||
impl Service {
|
||||
const ACCESS_TOKEN_MAX_AGE: Duration = Duration::from_hours(1);
|
||||
const RANDOM_TOKEN_LENGTH: usize = 32;
|
||||
|
||||
fn generate_token() -> String { utils::random_string(Self::RANDOM_TOKEN_LENGTH) }
|
||||
|
||||
pub async fn register_client(
|
||||
&self,
|
||||
metadata: &ClientMetadata,
|
||||
@@ -85,7 +157,7 @@ impl Service {
|
||||
Ok(client_id)
|
||||
}
|
||||
|
||||
pub async fn get_client_registration(&self, client_id: &str) -> Option<ClientMetadata> {
|
||||
pub async fn get_client_metadata(&self, client_id: &str) -> Option<ClientMetadata> {
|
||||
self.db
|
||||
.clientid_clientmetadata
|
||||
.get(client_id)
|
||||
@@ -94,15 +166,304 @@ impl Service {
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub async fn get_client_id_for_device(&self, _device_id: &DeviceId) -> Option<String> {
|
||||
None // TODO
|
||||
pub async fn get_client_id_for_device(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Option<String> {
|
||||
self.db
|
||||
.userdeviceid_oauthsessioninfo
|
||||
.qry(&(user_id, device_id))
|
||||
.await
|
||||
.deserialized::<SessionInfo>()
|
||||
.ok()
|
||||
.map(|session| session.client_id)
|
||||
}
|
||||
|
||||
pub async fn request_authorization_code(
|
||||
&self,
|
||||
authorizing_user: OwnedUserId,
|
||||
query: AuthorizationCodeQuery,
|
||||
) -> Result<String, String> {
|
||||
let Some(client_metadata) = self.get_client_metadata(&query.client_id).await else {
|
||||
return Err("Invalid client ID".to_owned());
|
||||
};
|
||||
|
||||
if !(client_metadata
|
||||
.response_types
|
||||
.contains(&query.response_type)
|
||||
&& matches!(query.response_type, ResponseType::Code))
|
||||
{
|
||||
return Err("Invalid response type".to_owned());
|
||||
}
|
||||
|
||||
if !matches!(query.code_challenge_method, CodeChallengeMethod::S256) {
|
||||
return Err("Invalid code challenge type".to_owned());
|
||||
}
|
||||
|
||||
{
|
||||
let mut stripped_uri = query.redirect_uri.clone();
|
||||
|
||||
if client_metadata.application_type == ApplicationType::Native
|
||||
&& query
|
||||
.redirect_uri
|
||||
.host_str()
|
||||
.is_some_and(|host| ClientMetadata::ACCEPTABLE_LOCALHOSTS.contains(&host))
|
||||
{
|
||||
// Remove the port from localhost redirect URIs for native applications when
|
||||
// checking if it's valid
|
||||
stripped_uri.set_port(None).unwrap();
|
||||
}
|
||||
|
||||
if !client_metadata.redirect_uris.contains(&stripped_uri) {
|
||||
return Err("Invalid redirect URI".to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let requested_scopes = query.scope.to_scopes()?;
|
||||
|
||||
let redirect_uri_query_separator = match query.response_mode {
|
||||
| ResponseMode::Fragment => '#',
|
||||
| ResponseMode::Query => '?',
|
||||
};
|
||||
|
||||
let code = PendingCodeGrant::generate_code();
|
||||
|
||||
info!(
|
||||
client_id = &query.client_id,
|
||||
client_name = &client_metadata.client_name,
|
||||
?requested_scopes,
|
||||
?authorizing_user,
|
||||
"Issuing oauth authorization code"
|
||||
);
|
||||
|
||||
let redirect_uri = format!(
|
||||
"{}{}{}",
|
||||
query.redirect_uri,
|
||||
redirect_uri_query_separator,
|
||||
serde_urlencoded::to_string(AuthorizationCodeResponse {
|
||||
state: query.state,
|
||||
code: code.clone(),
|
||||
})
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let pending_grant = PendingCodeGrant {
|
||||
authorizing_user,
|
||||
requested_scopes,
|
||||
client_name: client_metadata.client_name,
|
||||
expected_client_id: query.client_id,
|
||||
expected_redirect_uri: query.redirect_uri,
|
||||
code_challenge: query.code_challenge,
|
||||
requested_at: SystemTime::now(),
|
||||
};
|
||||
|
||||
self.pending_code_grants
|
||||
.lock()
|
||||
.await
|
||||
.insert(code, pending_grant);
|
||||
|
||||
Ok(redirect_uri)
|
||||
}
|
||||
|
||||
pub async fn issue_token(&self, request: TokenRequest) -> Result<TokenResponse> {
|
||||
match request {
|
||||
| TokenRequest::AuthorizationCode {
|
||||
code,
|
||||
redirect_uri,
|
||||
client_id,
|
||||
code_verifier,
|
||||
} => {
|
||||
let mut pending_grants = self.pending_code_grants.lock().await;
|
||||
|
||||
let Some(pending_grant) = pending_grants
|
||||
.remove(&code)
|
||||
.filter(|grant| grant.is_valid_for(&client_id))
|
||||
else {
|
||||
return Err!("Invalid code");
|
||||
};
|
||||
|
||||
if redirect_uri != pending_grant.expected_redirect_uri {
|
||||
return Err!("Unexpected redirect uri");
|
||||
}
|
||||
|
||||
let expected_code_challenge =
|
||||
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sha256::hash(&code_verifier));
|
||||
if expected_code_challenge != pending_grant.code_challenge {
|
||||
return Err!("Invalid code challenge");
|
||||
}
|
||||
|
||||
self.create_session(
|
||||
pending_grant.authorizing_user,
|
||||
pending_grant.requested_scopes,
|
||||
pending_grant.client_name,
|
||||
client_id,
|
||||
)
|
||||
.await
|
||||
},
|
||||
| TokenRequest::RefreshToken { client_id, refresh_token } =>
|
||||
self.refresh_session(client_id, refresh_token).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_session(
|
||||
&self,
|
||||
authorizing_user: OwnedUserId,
|
||||
requested_scopes: BTreeSet<Scope>,
|
||||
client_name: Option<String>,
|
||||
client_id: String,
|
||||
) -> Result<TokenResponse> {
|
||||
let access_token = Self::generate_token();
|
||||
let refresh_token = Self::generate_token();
|
||||
|
||||
let device_id = requested_scopes
|
||||
.iter()
|
||||
.find_map(|scope| {
|
||||
if let Scope::Device(device_id) = scope {
|
||||
Some(device_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| err!("No device ID scope supplied"))?;
|
||||
|
||||
self.services
|
||||
.users
|
||||
.create_device(
|
||||
&authorizing_user,
|
||||
device_id,
|
||||
&access_token,
|
||||
Some(Self::ACCESS_TOKEN_MAX_AGE),
|
||||
client_name,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.db.userdeviceid_oauthsessioninfo.put(
|
||||
(&authorizing_user, device_id),
|
||||
Json(SessionInfo {
|
||||
client_id: client_id.clone(),
|
||||
current_refresh_token: refresh_token.clone(),
|
||||
scopes: requested_scopes.clone(),
|
||||
}),
|
||||
);
|
||||
|
||||
self.db.refreshtoken_refreshtokeninfo.raw_put(
|
||||
&refresh_token,
|
||||
Json(RefreshTokenInfo {
|
||||
client_id: client_id.clone(),
|
||||
user_id: authorizing_user.clone(),
|
||||
device_id: device_id.to_owned(),
|
||||
}),
|
||||
);
|
||||
|
||||
info!(
|
||||
?client_id,
|
||||
?authorizing_user,
|
||||
?device_id,
|
||||
?requested_scopes,
|
||||
"Created new oauth session"
|
||||
);
|
||||
|
||||
Ok(TokenResponse {
|
||||
access_token,
|
||||
token_type: TokenType::Bearer,
|
||||
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
|
||||
scope: requested_scopes.iter().join(" "),
|
||||
refresh_token,
|
||||
})
|
||||
}
|
||||
|
||||
async fn refresh_session(
|
||||
&self,
|
||||
client_id: String,
|
||||
refresh_token: String,
|
||||
) -> Result<TokenResponse> {
|
||||
let Some(refresh_token_info) = self
|
||||
.db
|
||||
.refreshtoken_refreshtokeninfo
|
||||
.get(&refresh_token)
|
||||
.await
|
||||
.deserialized::<RefreshTokenInfo>()
|
||||
.ok()
|
||||
else {
|
||||
return Err!("Invalid refresh token");
|
||||
};
|
||||
|
||||
assert_eq!(&client_id, &refresh_token_info.client_id, "refresh token client id mismatch");
|
||||
|
||||
let mut session_info = self
|
||||
.db
|
||||
.userdeviceid_oauthsessioninfo
|
||||
.qry(&(&refresh_token_info.user_id, &refresh_token_info.device_id))
|
||||
.await
|
||||
.deserialized::<SessionInfo>()
|
||||
.expect("session info should exist");
|
||||
|
||||
assert_eq!(&client_id, &session_info.client_id, "session info client id mismatch");
|
||||
|
||||
let new_access_token = Self::generate_token();
|
||||
let new_refresh_token = Self::generate_token();
|
||||
let scope = session_info.scopes.iter().join(" ");
|
||||
session_info
|
||||
.current_refresh_token
|
||||
.clone_from(&new_refresh_token);
|
||||
|
||||
self.services
|
||||
.users
|
||||
.set_token(
|
||||
&refresh_token_info.user_id,
|
||||
&refresh_token_info.device_id,
|
||||
&new_access_token,
|
||||
Some(Self::ACCESS_TOKEN_MAX_AGE),
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.db.userdeviceid_oauthsessioninfo.put(
|
||||
(&refresh_token_info.user_id, &refresh_token_info.device_id),
|
||||
Json(session_info),
|
||||
);
|
||||
|
||||
self.db.refreshtoken_refreshtokeninfo.remove(&refresh_token);
|
||||
drop(refresh_token);
|
||||
self.db
|
||||
.refreshtoken_refreshtokeninfo
|
||||
.raw_put(&new_refresh_token, Json(refresh_token_info));
|
||||
|
||||
Ok(TokenResponse {
|
||||
access_token: new_access_token,
|
||||
token_type: TokenType::Bearer,
|
||||
expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(),
|
||||
scope,
|
||||
refresh_token: new_refresh_token,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn remove_session(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
let session_info = self
|
||||
.db
|
||||
.userdeviceid_oauthsessioninfo
|
||||
.qry(&(user_id, device_id))
|
||||
.await
|
||||
.deserialized::<SessionInfo>()
|
||||
.ok();
|
||||
|
||||
if let Some(session_info) = session_info {
|
||||
self.db
|
||||
.refreshtoken_refreshtokeninfo
|
||||
.remove(&session_info.current_refresh_token);
|
||||
self.db
|
||||
.userdeviceid_oauthsessioninfo
|
||||
.del(&(user_id, device_id));
|
||||
info!(?user_id, ?device_id, "Removed OAuth session");
|
||||
}
|
||||
}
|
||||
|
||||
/// Issue a ticket for `localpart` to perform some action.
|
||||
pub fn issue_ticket(&self, localpart: String, ticket: OAuthTicket) {
|
||||
self.tickets
|
||||
.lock()
|
||||
.expect("should be able to lock tickets")
|
||||
.unwrap()
|
||||
.entry(localpart)
|
||||
.or_default()
|
||||
.insert(ticket, SystemTime::now());
|
||||
@@ -114,7 +475,7 @@ impl Service {
|
||||
|
||||
self.tickets
|
||||
.lock()
|
||||
.expect("should be able to lock tickets")
|
||||
.unwrap()
|
||||
.get_mut(localpart)
|
||||
.and_then(|tickets| tickets.remove(&ticket))
|
||||
.is_some_and(|issued| {
|
||||
|
||||
+23
-12
@@ -253,17 +253,17 @@ impl Service {
|
||||
let mut info = assign!(UiaaInfo::new(flows), { params: Some(params), session: Some(session_id.clone()) });
|
||||
|
||||
let session_metadata = if let Some(initiator) = initiator {
|
||||
let is_oauth = OptionFuture::from(
|
||||
initiator.device_id.map(async |device_id| {
|
||||
self
|
||||
.services
|
||||
.oauth
|
||||
.get_client_id_for_device(device_id)
|
||||
.await
|
||||
})
|
||||
)
|
||||
.await
|
||||
.is_some();
|
||||
let is_oauth = if let Some(device_id) = initiator.device_id {
|
||||
self
|
||||
.services
|
||||
.oauth
|
||||
.get_client_id_for_device(initiator.user_id, device_id)
|
||||
.await
|
||||
.is_some()
|
||||
} else {
|
||||
// Appservices never have oauth sessions
|
||||
false
|
||||
};
|
||||
|
||||
if is_oauth {
|
||||
if let Some(oauth_ticket) = initiator.oauth_ticket {
|
||||
@@ -279,7 +279,18 @@ impl Service {
|
||||
.unwrap();
|
||||
|
||||
info.flows = vec![AuthFlow::new(vec![AuthType::OAuth])];
|
||||
info.params = Some(to_raw_value(&json!({"url": ticket_url})).unwrap());
|
||||
info.params = Some(
|
||||
to_raw_value(&json!({
|
||||
AuthType::OAuth.as_str(): {
|
||||
"url": ticket_url,
|
||||
},
|
||||
// TODO(compat): This is necessary for older versions of matrix-rust-sdk
|
||||
"org.matrix.cross_signing_reset": {
|
||||
"url": ticket_url,
|
||||
}
|
||||
}))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
UiaaSessionMetadata::OAuth {
|
||||
localpart: initiator.user_id.localpart().to_owned(),
|
||||
|
||||
@@ -54,6 +54,7 @@ pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) ->
|
||||
user_id,
|
||||
&request.device_id,
|
||||
"",
|
||||
None,
|
||||
request.initial_device_display_name.clone(),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
pub(super) mod dehydrated_device;
|
||||
|
||||
use std::{collections::BTreeMap, mem, net::IpAddr, sync::Arc};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
mem,
|
||||
net::IpAddr,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use conduwuit::{
|
||||
Err, Error, Result, Server, debug_error, debug_warn, err, trace,
|
||||
@@ -26,7 +32,7 @@ use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigE
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{Dep, account_data, admin, appservice, globals, rooms};
|
||||
use crate::{Dep, account_data, admin, appservice, globals, oauth, rooms};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserSuspension {
|
||||
@@ -62,6 +68,7 @@ struct Services {
|
||||
admin: Dep<admin::Service>,
|
||||
appservice: Dep<appservice::Service>,
|
||||
globals: Dep<globals::Service>,
|
||||
oauth: Dep<oauth::Service>,
|
||||
state_accessor: Dep<rooms::state_accessor::Service>,
|
||||
state_cache: Dep<rooms::state_cache::Service>,
|
||||
}
|
||||
@@ -75,6 +82,7 @@ struct Data {
|
||||
logintoken_expiresatuserid: Arc<Map>,
|
||||
todeviceid_events: Arc<Map>,
|
||||
token_userdeviceid: Arc<Map>,
|
||||
userdeviceid_tokenexpires: Arc<Map>,
|
||||
userdeviceid_metadata: Arc<Map>,
|
||||
userdeviceid_token: Arc<Map>,
|
||||
userfilterid_filter: Arc<Map>,
|
||||
@@ -102,6 +110,7 @@ impl crate::Service for Service {
|
||||
admin: args.depend::<admin::Service>("admin"),
|
||||
appservice: args.depend::<appservice::Service>("appservice"),
|
||||
globals: args.depend::<globals::Service>("globals"),
|
||||
oauth: args.depend::<oauth::Service>("oauth"),
|
||||
state_accessor: args
|
||||
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
|
||||
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
|
||||
@@ -131,6 +140,7 @@ impl crate::Service for Service {
|
||||
userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(),
|
||||
userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(),
|
||||
useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
|
||||
userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -337,8 +347,37 @@ impl Service {
|
||||
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
|
||||
|
||||
/// Find out which user an access token belongs to.
|
||||
pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
|
||||
self.db.token_userdeviceid.get(token).await.deserialized()
|
||||
pub async fn find_from_token(&self, token: &str) -> Option<(OwnedUserId, OwnedDeviceId)> {
|
||||
let user = self
|
||||
.db
|
||||
.token_userdeviceid
|
||||
.get(token)
|
||||
.await
|
||||
.deserialized()
|
||||
.ok();
|
||||
|
||||
// Check if the token has expired
|
||||
if let Some(user) = &user {
|
||||
if let Some(expires) = self
|
||||
.db
|
||||
.userdeviceid_tokenexpires
|
||||
.qry(user)
|
||||
.await
|
||||
.deserialized::<u64>()
|
||||
.ok()
|
||||
.map(Duration::from_secs)
|
||||
{
|
||||
let expires_at = SystemTime::UNIX_EPOCH
|
||||
.checked_add(expires)
|
||||
.expect("expiry time should not overflow SystemTime");
|
||||
|
||||
if SystemTime::now() > expires_at {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user
|
||||
}
|
||||
|
||||
/// Returns an iterator over all users on this homeserver.
|
||||
@@ -434,6 +473,7 @@ impl Service {
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
token_max_age: Option<Duration>,
|
||||
initial_device_display_name: Option<String>,
|
||||
client_ip: Option<String>,
|
||||
) -> Result<()> {
|
||||
@@ -451,7 +491,8 @@ impl Service {
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
self.db.userdeviceid_metadata.put(key, Json(device));
|
||||
self.set_token(user_id, device_id, token).await
|
||||
self.set_token(user_id, device_id, token, token_max_age)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Removes a device from a user.
|
||||
@@ -467,6 +508,7 @@ impl Service {
|
||||
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
|
||||
self.db.userdeviceid_token.del(userdeviceid);
|
||||
self.db.token_userdeviceid.remove(&old_token);
|
||||
self.db.userdeviceid_tokenexpires.del(userdeviceid);
|
||||
}
|
||||
|
||||
// Remove todevice events
|
||||
@@ -480,6 +522,9 @@ impl Service {
|
||||
|
||||
// TODO: Remove onetimekeys
|
||||
|
||||
// Remove OAuth session information
|
||||
self.services.oauth.remove_session(user_id, device_id).await;
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
self.db.userdeviceid_metadata.del(userdeviceid);
|
||||
@@ -535,6 +580,7 @@ impl Service {
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
token_max_age: Option<Duration>,
|
||||
) -> Result<()> {
|
||||
let key = (user_id, device_id);
|
||||
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
|
||||
@@ -561,6 +607,7 @@ impl Service {
|
||||
// Remove old token
|
||||
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
|
||||
self.db.token_userdeviceid.remove(&old_token);
|
||||
self.db.userdeviceid_tokenexpires.remove(&old_token);
|
||||
// It will be removed from userdeviceid_token by the insert later
|
||||
}
|
||||
|
||||
@@ -568,6 +615,18 @@ impl Service {
|
||||
self.db.userdeviceid_token.put_raw(key, token);
|
||||
self.db.token_userdeviceid.raw_put(token, key);
|
||||
|
||||
if let Some(max_age) = token_max_age {
|
||||
let expires = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.expect("system time should not be before the epoch")
|
||||
.saturating_add(max_age)
|
||||
.as_secs();
|
||||
|
||||
self.db.userdeviceid_tokenexpires.put(key, expires);
|
||||
} else {
|
||||
self.db.userdeviceid_tokenexpires.del(key);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -43,7 +43,8 @@ validator = { version = "0.20.0", features = ["derive"] }
|
||||
tower-sec-fetch = { version = "0.1.2", features = ["tracing"] }
|
||||
tower-sessions = { version = "0.15.0", default-features = false, features = ["axum-core"] }
|
||||
tower-sessions-core = { version = "0.15.0", features = ["deletion-task"] }
|
||||
serde_urlencoded = "0.7.1"
|
||||
serde_urlencoded.workspace = true
|
||||
url.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
memory-serve = "2.1.0"
|
||||
|
||||
+2
-1
@@ -127,6 +127,7 @@ pub fn build(services: &Services) -> Router<state::State> {
|
||||
Router::new()
|
||||
.nest("/account/", account::build())
|
||||
.merge(debug::build())
|
||||
.nest("/oauth2/", oauth::build())
|
||||
.merge(resources::build())
|
||||
.merge(threepid::build())
|
||||
.fallback(async || WebError::NotFound),
|
||||
@@ -145,7 +146,7 @@ pub fn build(services: &Services) -> Router<state::State> {
|
||||
}))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
HeaderValue::from_static("default-src 'self'; img-src 'self' data:;"),
|
||||
HeaderValue::from_static("default-src 'self'; img-src 'self' https: data:;"),
|
||||
))
|
||||
.layer(SecFetchLayer::new(|policy| {
|
||||
policy.allow_safe_methods().reject_missing_metadata();
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::time::SystemTime;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Query, State},
|
||||
extract::{Query, RawQuery, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::{get, on},
|
||||
};
|
||||
@@ -15,11 +15,11 @@ use serde::Deserialize;
|
||||
use tower_sessions::Session;
|
||||
|
||||
use crate::{
|
||||
WebError,
|
||||
ROUTE_PREFIX, WebError,
|
||||
extract::{Expect, PostForm},
|
||||
pages::{GET_POST, Result, components::UserCard},
|
||||
response,
|
||||
session::{LoginQuery, User, UserSession},
|
||||
session::{LoginQuery, LoginTarget, User, UserSession},
|
||||
template,
|
||||
};
|
||||
|
||||
@@ -32,6 +32,7 @@ pub(crate) fn build() -> Router<crate::State> {
|
||||
template! {
|
||||
struct Login use "login.html.j2" {
|
||||
body: LoginBody,
|
||||
has_next: bool,
|
||||
login_error: Option<String>
|
||||
}
|
||||
}
|
||||
@@ -54,11 +55,12 @@ struct LoginForm {
|
||||
|
||||
async fn route_login(
|
||||
State(services): State<crate::State>,
|
||||
Expect(Query(query)): Expect<Query<LoginQuery>>,
|
||||
Expect(Query(LoginQuery { next, reauthenticate })): Expect<Query<LoginQuery>>,
|
||||
session_store: Session,
|
||||
user: User,
|
||||
PostForm(form): PostForm<LoginForm>,
|
||||
) -> Result {
|
||||
let next = next.unwrap_or_default();
|
||||
let user_id = user.into_session().map(|session| session.user_id);
|
||||
|
||||
let body = match &user_id {
|
||||
@@ -66,8 +68,8 @@ async fn route_login(
|
||||
server_name: services.globals.server_name().to_string(),
|
||||
},
|
||||
| Some(user_id) => {
|
||||
if !query.reauthenticate {
|
||||
return response!(Redirect::to(&query.next.target_path()));
|
||||
if !reauthenticate {
|
||||
return response!(Redirect::to(&next.target_path()));
|
||||
}
|
||||
|
||||
let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await;
|
||||
@@ -76,7 +78,7 @@ async fn route_login(
|
||||
},
|
||||
};
|
||||
|
||||
let mut template = Login::new(&services, body, None);
|
||||
let mut template = Login::new(&services, body, next != LoginTarget::Account, None);
|
||||
|
||||
if let Some(form) = form {
|
||||
let login_result = match (user_id, form.identifier) {
|
||||
@@ -86,8 +88,6 @@ async fn route_login(
|
||||
},
|
||||
| (None, Some(identifier)) => {
|
||||
// The user isn't authenticated, we need to log them in
|
||||
// Yes, this does parse the email twice (handle_login does it again). I don't
|
||||
// think this really needs to be optimized.
|
||||
let identifier = if identifier.parse::<lettre::Address>().is_ok() {
|
||||
UserIdentifier::Email(EmailUserIdentifier::new(identifier))
|
||||
} else {
|
||||
@@ -123,14 +123,14 @@ async fn route_login(
|
||||
.await
|
||||
.expect("should be able to serialize user session");
|
||||
|
||||
return response!(Redirect::to(&query.next.target_path()));
|
||||
return response!(Redirect::to(&next.target_path()));
|
||||
}
|
||||
|
||||
response!(template)
|
||||
}
|
||||
|
||||
async fn get_logout(session: Session) -> impl IntoResponse {
|
||||
async fn get_logout(session: Session, RawQuery(query): RawQuery) -> impl IntoResponse {
|
||||
let _ = session.remove::<OwnedUserId>(User::KEY).await;
|
||||
|
||||
Redirect::to("/_continuwuity/account/")
|
||||
Redirect::to(&format!("{}/account/login?{}", ROUTE_PREFIX, query.unwrap_or_default()))
|
||||
}
|
||||
|
||||
@@ -61,12 +61,14 @@ async fn get_account(
|
||||
|
||||
let user_card = UserCard::for_local_user(&services, user_id.clone()).await;
|
||||
|
||||
let devices = services
|
||||
let mut devices: Vec<_> = services
|
||||
.users
|
||||
.all_device_ids(&user_id)
|
||||
.then(async |device_id| DeviceCard::for_device(&services, &user_id, device_id).await)
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
devices.sort_unstable_by(|a, b| a.last_seen_ts.cmp(&b.last_seen_ts).reverse());
|
||||
|
||||
response!(Account::new(&services, user_card, email_requirement, email, devices))
|
||||
}
|
||||
|
||||
@@ -4,36 +4,26 @@ use askama::{Template, filters::HtmlSafe};
|
||||
use base64::Engine;
|
||||
use conduwuit_core::{result::FlatOk, utils};
|
||||
use conduwuit_service::{Services, media::mxc::Mxc, oauth::client_metadata::ClientMetadata};
|
||||
use ruma::{OwnedDeviceId, OwnedUserId, UserId};
|
||||
use ruma::{MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId};
|
||||
|
||||
pub(super) mod form;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) enum AvatarType<'a> {
|
||||
pub(super) enum AvatarType {
|
||||
Initial(char),
|
||||
Image(&'a str),
|
||||
Image(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Template)]
|
||||
#[template(path = "_components/avatar.html.j2")]
|
||||
pub(super) struct Avatar<'a> {
|
||||
pub(super) avatar_type: AvatarType<'a>,
|
||||
pub(super) struct Avatar {
|
||||
pub(super) avatar_type: AvatarType,
|
||||
}
|
||||
|
||||
impl HtmlSafe for Avatar<'_> {}
|
||||
impl HtmlSafe for Avatar {}
|
||||
|
||||
#[derive(Debug, Template)]
|
||||
#[template(path = "_components/user_card.html.j2")]
|
||||
pub(super) struct UserCard {
|
||||
pub user_id: OwnedUserId,
|
||||
pub display_name: Option<String>,
|
||||
pub avatar_src: Option<String>,
|
||||
}
|
||||
|
||||
impl HtmlSafe for UserCard {}
|
||||
|
||||
impl UserCard {
|
||||
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
|
||||
impl Avatar {
|
||||
pub(super) async fn for_local_user(services: &Services, user_id: &UserId) -> Self {
|
||||
let display_name = services.users.displayname(&user_id).await.ok();
|
||||
|
||||
let avatar_src = async {
|
||||
@@ -56,33 +46,48 @@ impl UserCard {
|
||||
}
|
||||
.await;
|
||||
|
||||
Self { user_id, display_name, avatar_src }
|
||||
}
|
||||
|
||||
fn avatar(&self) -> Avatar<'_> {
|
||||
let avatar_type = if let Some(ref avatar_src) = self.avatar_src {
|
||||
let avatar_type = if let Some(avatar_src) = avatar_src {
|
||||
AvatarType::Image(avatar_src)
|
||||
} else if let Some(initial) = self
|
||||
.display_name
|
||||
} else if let Some(initial) = display_name
|
||||
.as_ref()
|
||||
.and_then(|display_name| display_name.chars().next())
|
||||
{
|
||||
AvatarType::Initial(initial)
|
||||
} else {
|
||||
AvatarType::Initial(self.user_id.localpart().chars().next().unwrap())
|
||||
AvatarType::Initial(user_id.localpart().chars().next().unwrap())
|
||||
};
|
||||
|
||||
Avatar { avatar_type }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Template)]
|
||||
#[template(path = "_components/user_card.html.j2")]
|
||||
pub(super) struct UserCard {
|
||||
pub user_id: OwnedUserId,
|
||||
pub display_name: Option<String>,
|
||||
pub avatar: Avatar,
|
||||
}
|
||||
|
||||
impl HtmlSafe for UserCard {}
|
||||
|
||||
impl UserCard {
|
||||
pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self {
|
||||
let display_name = services.users.displayname(&user_id).await.ok();
|
||||
let avatar = Avatar::for_local_user(services, &user_id).await;
|
||||
|
||||
Self { user_id, display_name, avatar }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Template)]
|
||||
#[template(path = "_components/device_card.html.j2")]
|
||||
pub(super) struct DeviceCard {
|
||||
pub device_id: OwnedDeviceId,
|
||||
pub display_name: Option<String>,
|
||||
pub avatar_src: Option<String>,
|
||||
pub avatar: Avatar,
|
||||
pub last_active: String,
|
||||
pub last_seen_ts: Option<u64>,
|
||||
pub oauth_metadata: Option<ClientMetadata>,
|
||||
}
|
||||
|
||||
@@ -101,12 +106,15 @@ impl DeviceCard {
|
||||
.ok();
|
||||
|
||||
let oauth_metadata = async {
|
||||
let client_id = services.oauth.get_client_id_for_device(&device_id).await?;
|
||||
let client_id = services
|
||||
.oauth
|
||||
.get_client_id_for_device(user_id, &device_id)
|
||||
.await?;
|
||||
|
||||
Some(
|
||||
services
|
||||
.oauth
|
||||
.get_client_registration(&client_id)
|
||||
.get_client_metadata(&client_id)
|
||||
.await
|
||||
.expect("client should exist"),
|
||||
)
|
||||
@@ -122,53 +130,51 @@ impl DeviceCard {
|
||||
.and_then(|device| device.display_name.clone())
|
||||
});
|
||||
|
||||
let avatar_src = oauth_metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.logo_uri.as_ref())
|
||||
.map(|uri| uri.as_str().to_owned());
|
||||
let avatar = {
|
||||
let avatar_src = oauth_metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.logo_uri.as_ref())
|
||||
.map(|uri| uri.as_str().to_owned());
|
||||
|
||||
let last_active = device
|
||||
.as_ref()
|
||||
.and_then(|device| device.last_seen_ts)
|
||||
.map_or_else(
|
||||
|| "unknown".to_owned(),
|
||||
|active| {
|
||||
active
|
||||
.to_system_time()
|
||||
.and_then(|t| SystemTime::now().duration_since(t).ok())
|
||||
.map_or_else(
|
||||
|| "now".to_owned(),
|
||||
|duration| format!("{} ago", utils::time::pretty(duration)),
|
||||
)
|
||||
},
|
||||
);
|
||||
let avatar_type = if let Some(avatar_src) = avatar_src {
|
||||
AvatarType::Image(avatar_src)
|
||||
} else if let Some(initial) =
|
||||
display_name.as_ref().and_then(|name| name.chars().next())
|
||||
{
|
||||
if oauth_metadata.is_some() {
|
||||
AvatarType::Initial(initial)
|
||||
} else {
|
||||
AvatarType::Initial('❖')
|
||||
}
|
||||
} else {
|
||||
AvatarType::Initial('?')
|
||||
};
|
||||
|
||||
Avatar { avatar_type }
|
||||
};
|
||||
|
||||
let last_seen_ts = device.as_ref().and_then(|device| device.last_seen_ts);
|
||||
|
||||
let last_active = last_seen_ts.map_or_else(
|
||||
|| "unknown".to_owned(),
|
||||
|last_seen_ts| {
|
||||
last_seen_ts
|
||||
.to_system_time()
|
||||
.and_then(|t| SystemTime::now().duration_since(t).ok())
|
||||
.map_or_else(
|
||||
|| "now".to_owned(),
|
||||
|duration| format!("{} ago", utils::time::pretty(duration)),
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
Self {
|
||||
device_id,
|
||||
display_name,
|
||||
avatar_src,
|
||||
avatar,
|
||||
last_active,
|
||||
last_seen_ts: last_seen_ts.map(|last_seen_ts| last_seen_ts.as_secs().into()),
|
||||
oauth_metadata,
|
||||
}
|
||||
}
|
||||
|
||||
fn avatar(&self) -> Avatar<'_> {
|
||||
let avatar_type = if let Some(avatar_src) = &self.avatar_src {
|
||||
AvatarType::Image(avatar_src.as_str())
|
||||
} else if let Some(initial) = self
|
||||
.display_name
|
||||
.as_ref()
|
||||
.and_then(|name| name.chars().next())
|
||||
{
|
||||
if self.oauth_metadata.is_some() {
|
||||
AvatarType::Initial(initial)
|
||||
} else {
|
||||
AvatarType::Initial('❖')
|
||||
}
|
||||
} else {
|
||||
AvatarType::Initial('?')
|
||||
};
|
||||
|
||||
Avatar { avatar_type }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ pub(super) mod account;
|
||||
mod components;
|
||||
pub(super) mod debug;
|
||||
pub(super) mod index;
|
||||
pub(super) mod oauth;
|
||||
pub(super) mod resources;
|
||||
pub(super) mod threepid;
|
||||
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Query, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::on,
|
||||
};
|
||||
use conduwuit_service::{
|
||||
oauth::{
|
||||
client_metadata::{self, ClientMetadata},
|
||||
grant::{AuthorizationCodeQuery, Scope},
|
||||
},
|
||||
rooms::user,
|
||||
};
|
||||
use ruma::{OwnedDeviceId, OwnedUserId};
|
||||
use serde::Deserialize;
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
WebError,
|
||||
extract::{Expect, PostForm},
|
||||
pages::{
|
||||
GET_POST, Result,
|
||||
components::{Avatar, AvatarType},
|
||||
},
|
||||
response,
|
||||
session::{LoginQuery, LoginTarget, User},
|
||||
template,
|
||||
};
|
||||
|
||||
pub(crate) fn build() -> Router<crate::State> {
|
||||
Router::new().route("/authorization_code", on(GET_POST, route_authorization_code))
|
||||
}
|
||||
|
||||
template! {
|
||||
struct Grant use "grant.html.j2" {
|
||||
logout_query: String,
|
||||
user_id: OwnedUserId,
|
||||
user_avatar: Avatar,
|
||||
client_uri: Url,
|
||||
client_name: String,
|
||||
client_avatar: Avatar,
|
||||
policy_uri: Option<Url>,
|
||||
tos_uri: Option<Url>,
|
||||
scopes: BTreeSet<Scope>
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_authorization_code(
|
||||
State(services): State<crate::State>,
|
||||
user: User,
|
||||
Expect(Query(query)): Expect<Query<AuthorizationCodeQuery>>,
|
||||
PostForm(form): PostForm<()>,
|
||||
) -> Result {
|
||||
let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?;
|
||||
|
||||
if form.is_some() {
|
||||
let redirect_uri = services
|
||||
.oauth
|
||||
.request_authorization_code(user_id, query)
|
||||
.await
|
||||
.map_err(WebError::BadRequest)?;
|
||||
|
||||
return response!(Redirect::to(&redirect_uri));
|
||||
}
|
||||
|
||||
let Some(client) = services.oauth.get_client_metadata(&query.client_id).await else {
|
||||
return Err(WebError::BadRequest("Invalid client ID".to_owned()));
|
||||
};
|
||||
|
||||
let scopes = query.scope.to_scopes().map_err(WebError::BadRequest)?;
|
||||
|
||||
let client_name = if let Some(name) = &client.client_name {
|
||||
name
|
||||
} else {
|
||||
"Unknown application"
|
||||
}
|
||||
.to_owned();
|
||||
|
||||
let client_avatar = {
|
||||
let avatar_type = if let Some(logo) = &client.logo_uri {
|
||||
AvatarType::Image(logo.to_string())
|
||||
} else if let Some(name) = &client.client_name
|
||||
&& let Some(char) = name.chars().next()
|
||||
{
|
||||
AvatarType::Initial(char)
|
||||
} else {
|
||||
AvatarType::Initial('?')
|
||||
};
|
||||
|
||||
Avatar { avatar_type }
|
||||
};
|
||||
|
||||
let user_avatar = Avatar::for_local_user(&services, &user_id).await;
|
||||
|
||||
response!(Grant::new(
|
||||
&services,
|
||||
serde_urlencoded::to_string(LoginQuery {
|
||||
next: Some(LoginTarget::AuthorizationCode(query)),
|
||||
reauthenticate: false,
|
||||
})
|
||||
.unwrap(),
|
||||
user_id,
|
||||
user_avatar,
|
||||
client.client_uri.clone(),
|
||||
client_name,
|
||||
client_avatar,
|
||||
client.policy_uri.clone(),
|
||||
client.tos_uri.clone(),
|
||||
scopes,
|
||||
))
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
use axum::Router;
|
||||
|
||||
mod grant;
|
||||
|
||||
pub(crate) fn build() -> Router<crate::State> {
|
||||
#[allow(clippy::wildcard_imports)]
|
||||
use self::*;
|
||||
|
||||
Router::new().nest("/grant/", grant::build())
|
||||
}
|
||||
@@ -123,8 +123,9 @@ small.error {
|
||||
.panel {
|
||||
--preferred-width: 12rem + 40dvw;
|
||||
--maximum-width: 48rem;
|
||||
--minimum-width: 32rem;
|
||||
|
||||
width: min(clamp(24rem, var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
|
||||
width: min(clamp(var(--minimum-width), var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem));
|
||||
border-radius: var(--border-radius-lg);
|
||||
background-color: var(--panel-bg);
|
||||
padding-inline: 1.5rem;
|
||||
@@ -184,6 +185,10 @@ a, a:visited {
|
||||
color: oklch(from var(--c1) var(--name-lightness) c h);
|
||||
}
|
||||
|
||||
code {
|
||||
color: oklch(from var(--secondary) var(--name-lightness) c h);
|
||||
}
|
||||
|
||||
input, button, a.button {
|
||||
display: inline-block;
|
||||
padding: 0.5em;
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
background-color: var(--avatar-color);
|
||||
}
|
||||
|
||||
.green-avatar {
|
||||
.red-avatar {
|
||||
--avatar-color: var(--c1);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
.avatars {
|
||||
justify-content: center;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
|
||||
.separator {
|
||||
align-self: center;
|
||||
margin-inline: 1em;
|
||||
color: var(--secondary);
|
||||
font-size: x-large;
|
||||
font-weight: bold;
|
||||
user-select: none;
|
||||
}
|
||||
}
|
||||
|
||||
.identity {
|
||||
margin-block: 1em;
|
||||
color: var(--secondary);
|
||||
font-size: small;
|
||||
font-style: italic;
|
||||
text-align: center;
|
||||
}
|
||||
@@ -1,19 +1,25 @@
|
||||
<div class="card">
|
||||
{{ avatar() }}
|
||||
{{ avatar }}
|
||||
<div class="info">
|
||||
<p class="name">
|
||||
{% if let Some(display_name) = display_name %}
|
||||
{{ display_name }}
|
||||
{% if let Some(metadata) = oauth_metadata %}
|
||||
<a href="{{ metadata.client_uri }}">{{ display_name }}</a>
|
||||
{% else %}
|
||||
{{ display_name }}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
Unknown device
|
||||
{% endif %}
|
||||
<span class="id">{{ device_id }}</span>
|
||||
<span class="id">
|
||||
• {{ device_id }}
|
||||
{% if oauth_metadata.is_none() %}
|
||||
(legacy)
|
||||
{% endif %}
|
||||
</span>
|
||||
</p>
|
||||
<p>
|
||||
Last active: {{ last_active }}
|
||||
{% if let Some(metadata) = oauth_metadata %}
|
||||
• <a href="{{ metadata.client_uri }}">Client information</a>
|
||||
{% endif %}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<div class="card green-avatar">
|
||||
{{ avatar() }}
|
||||
<div class="card red-avatar">
|
||||
{{ avatar }}
|
||||
<div class="info">
|
||||
{% if let Some(display_name) = display_name %}
|
||||
<p class="name">{{ display_name }}</p>
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
{% extends "_layout.html.j2" %}
|
||||
|
||||
{%- block head -%}
|
||||
<link rel="stylesheet" href="{{ crate::ROUTE_PREFIX }}/resources/grant.css">
|
||||
{%- endblock -%}
|
||||
|
||||
{%- block title -%}
|
||||
Authorize client
|
||||
{%- endblock -%}
|
||||
|
||||
{%- block content -%}
|
||||
<div class="panel narrow">
|
||||
<h1>Authorize {{ client_name }}</h1>
|
||||
<div class="avatars">
|
||||
<div class="red-avatar">
|
||||
{{ user_avatar }}
|
||||
</div>
|
||||
<div class="separator" aria-hidden>
|
||||
⇄
|
||||
</div>
|
||||
{{ client_avatar }}
|
||||
</div>
|
||||
<div class="identity">
|
||||
Signed in as <code>{{ user_id }}</code>. <a href="{{ crate::ROUTE_PREFIX }}/account/logout?{{ logout_query }}">Switch accounts</a>
|
||||
</div>
|
||||
<p>
|
||||
<b>{{ client_name }}</b> (<a href="{{ client_uri }}">{{ client_uri.domain().unwrap() }}</a>) would like
|
||||
your permission to:
|
||||
<ul>
|
||||
{% for scope in scopes %}
|
||||
{% match scope %}
|
||||
{% when Scope::ClientApi %}
|
||||
<li>Interact with Matrix on your behalf</li>
|
||||
{% when Scope::Device(_) %}
|
||||
<li>Connect to your Matrix account</li>
|
||||
{% endmatch %}
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</p>
|
||||
{% match (&policy_uri, &tos_uri) %}
|
||||
{% when (Some(policy_uri), Some(tos_uri)) %}
|
||||
<p>
|
||||
{{ client_name }}'s <a href="{{ policy_uri }}">policies</a>
|
||||
and <a href="{{ tos_uri }}">terms of service</a> apply.
|
||||
</p>
|
||||
{% when (Some(policy_uri), None) %}
|
||||
<p>
|
||||
{{ client_name }}'s <a href="{{ policy_uri }}">policies</a> apply.
|
||||
</p>
|
||||
{% when (None, Some(tos_uri)) %}
|
||||
<p>
|
||||
{{ client_name }}'s <a href="{{ tos_uri }}">terms of service</a> apply.
|
||||
</p>
|
||||
{% when (None, None) %}
|
||||
<p>
|
||||
Make sure you trust {{ client_name }} with access to your data.
|
||||
</p>
|
||||
{% endmatch %}
|
||||
|
||||
<form method="post">
|
||||
<button type="submit">Continue</button>
|
||||
</form>
|
||||
</div>
|
||||
{%- endblock -%}
|
||||
@@ -13,7 +13,11 @@ Log in
|
||||
{% match body %}
|
||||
{% when LoginBody::Unauthenticated { server_name } %}
|
||||
<h1 class="with-matrix-icon">
|
||||
Log in to Matrix
|
||||
{% if has_next %}
|
||||
Log in to continue
|
||||
{% else %}
|
||||
Log in to Matrix
|
||||
{% endif %}
|
||||
<a href="https://matrix.org" target="_blank" noreferer>
|
||||
<img class="matrix-icon" alt="Matrix logo" aria-ignore src="{{ crate::ROUTE_PREFIX }}/resources/matrix-icon.svg">
|
||||
</a>
|
||||
|
||||
+37
-12
@@ -1,8 +1,14 @@
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
mem::discriminant,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||
use conduwuit_service::oauth::grant::AuthorizationCodeQuery;
|
||||
use ruma::{OwnedUserId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use tower_sessions::Session;
|
||||
|
||||
use crate::{ROUTE_PREFIX, WebError};
|
||||
@@ -12,7 +18,7 @@ pub(crate) mod store;
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct LoginQuery {
|
||||
#[serde(flatten)]
|
||||
pub next: LoginTarget,
|
||||
pub next: Option<LoginTarget>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub reauthenticate: bool,
|
||||
}
|
||||
@@ -20,6 +26,7 @@ pub(crate) struct LoginQuery {
|
||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||
#[serde(tag = "next", rename_all = "snake_case")]
|
||||
pub(crate) enum LoginTarget {
|
||||
AuthorizationCode(AuthorizationCodeQuery),
|
||||
#[default]
|
||||
Account,
|
||||
ChangePassword,
|
||||
@@ -28,14 +35,23 @@ pub(crate) enum LoginTarget {
|
||||
Deactivate,
|
||||
}
|
||||
|
||||
impl PartialEq for LoginTarget {
|
||||
fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) }
|
||||
}
|
||||
|
||||
impl LoginTarget {
|
||||
pub(crate) fn target_path(&self) -> String {
|
||||
let path = match self {
|
||||
| Self::Account => "account/",
|
||||
| Self::ChangePassword => "account/password/change",
|
||||
| Self::ChangeEmail => "account/email/change/",
|
||||
| Self::CrossSigningReset => "account/cross_signing_reset",
|
||||
| Self::Deactivate => "account/deactivate",
|
||||
let path: Cow<'_, str> = match self {
|
||||
| Self::AuthorizationCode(code) => format!(
|
||||
"oauth2/grant/authorization_code?{}",
|
||||
serde_urlencoded::to_string(code).unwrap()
|
||||
)
|
||||
.into(),
|
||||
| Self::Account => "account/".into(),
|
||||
| Self::ChangePassword => "account/password/change".into(),
|
||||
| Self::ChangeEmail => "account/email/change/".into(),
|
||||
| Self::CrossSigningReset => "account/cross_signing_reset".into(),
|
||||
| Self::Deactivate => "account/deactivate".into(),
|
||||
};
|
||||
|
||||
format!("{ROUTE_PREFIX}/{path}")
|
||||
@@ -80,7 +96,10 @@ impl User {
|
||||
if let Some(session) = self.0 {
|
||||
Ok(session.user_id)
|
||||
} else {
|
||||
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false }))
|
||||
Err(WebError::LoginRequired(LoginQuery {
|
||||
next: Some(or_else),
|
||||
reauthenticate: false,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,10 +110,16 @@ impl User {
|
||||
if session.is_recent() {
|
||||
Ok(session.user_id)
|
||||
} else {
|
||||
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: true }))
|
||||
Err(WebError::LoginRequired(LoginQuery {
|
||||
next: Some(or_else),
|
||||
reauthenticate: true,
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false }))
|
||||
Err(WebError::LoginRequired(LoginQuery {
|
||||
next: Some(or_else),
|
||||
reauthenticate: false,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user