diff --git a/Cargo.lock b/Cargo.lock index 8cc54f925..ab9d2a626 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1088,6 +1088,7 @@ dependencies = [ "serde", "serde-saphyr", "serde_json", + "serde_urlencoded", "sha2 0.11.0", "termimad", "tokio", @@ -1125,6 +1126,7 @@ dependencies = [ "tower-sessions", "tower-sessions-core", "tracing", + "url", "validator", ] diff --git a/Cargo.toml b/Cargo.toml index dfce2a2be..96864f928 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -559,6 +559,9 @@ features = ["std"] [workspace.dependencies.nonzero_ext] version = "0.3.0" +[workspace.dependencies.serde_urlencoded] +version = "0.7.1" + # # Patches # diff --git a/src/api/client/account/register.rs b/src/api/client/account/register.rs index 1f3d17311..601dc3615 100644 --- a/src/api/client/account/register.rs +++ b/src/api/client/account/register.rs @@ -179,6 +179,7 @@ pub(crate) async fn register_route( &user_id, &device_id, &new_token, + None, body.initial_device_display_name.clone(), Some(client.to_string()), ) diff --git a/src/api/client/device.rs b/src/api/client/device.rs index 1bd798bb5..ce5ab469e 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -94,6 +94,7 @@ pub(crate) async fn update_device_route( &device_id, &appservice.registration.as_token, None, + None, Some(client.to_string()), ) .await?; diff --git a/src/api/client/oauth/mod.rs b/src/api/client/oauth/mod.rs index 9c564e6d3..4ba91cd75 100644 --- a/src/api/client/oauth/mod.rs +++ b/src/api/client/oauth/mod.rs @@ -1,17 +1,31 @@ mod register_client; mod server_metadata; +mod token; use axum::{ Json, Router, + extract::State, routing::method_routing::{get, post}, }; use serde_json::json; pub(crate) use server_metadata::*; -pub(crate) const BASE_PATH: &str = "/_continuwuity/oauth2/"; +const BASE_PATH: &str = "/_continuwuity/oauth2/"; pub(crate) fn router() -> Router { + Router::new().nest(BASE_PATH, oauth_router()) + // TODO(unspecced): used by old versions of the matrix-js-sdk + // .route("/.well-known/openid-configuration", get( + // async |State(services): State| { + // Json(authorization_server_metadata(&services).await) + // } + // )) +} + +fn oauth_router() -> Router { Router::new() .route("/client/register", post(register_client::register_client_route)) + // TODO(unspecced): used by old versions of the matrix-js-sdk .route("/client/keys.json", get(async || Json(json!({"keys": []})))) + .route("/grant/token", post(token::token_route)) } diff --git a/src/api/client/oauth/server_metadata.rs b/src/api/client/oauth/server_metadata.rs index e292a4377..a2f646f1b 100644 --- a/src/api/client/oauth/server_metadata.rs +++ b/src/api/client/oauth/server_metadata.rs @@ -1,7 +1,8 @@ use axum::extract::State; use conduwuit::Result; use ruma::{api::client::discovery::get_authorization_server_metadata, serde::Raw}; -use serde_json::json; +use serde_json::{Value, json}; +use service::Services; use crate::Ruma; @@ -9,13 +10,19 @@ pub(crate) async fn get_authorization_server_metadata_route( State(services): State, _body: Ruma, ) -> Result { + let metadata = Raw::new(&authorization_server_metadata(&services).await).unwrap(); + + Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked())) +} + +pub(crate) async fn authorization_server_metadata(services: &Services) -> Value { let endpoint_base = services .config .get_client_domain() .join(super::BASE_PATH) .unwrap(); - let metadata = Raw::new(&json!({ + json!({ "authorization_endpoint": endpoint_base.join("grant/authorization_code").unwrap(), "code_challenge_methods_supported": ["S256"], "grant_types_supported": ["authorization_code", "refresh_token"], @@ -27,8 +34,5 @@ pub(crate) async fn get_authorization_server_metadata_route( "response_types_supported": ["code"], "revocation_endpoint": endpoint_base.join("client/revoke").unwrap(), "token_endpoint": endpoint_base.join("grant/token").unwrap(), - })) - .unwrap(); - - Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked())) + }) } diff --git a/src/api/client/oauth/token.rs b/src/api/client/oauth/token.rs new file mode 100644 index 000000000..7b2d0a2d5 --- /dev/null +++ b/src/api/client/oauth/token.rs @@ -0,0 +1,13 @@ +use axum::{Form, Json, extract::State, response::IntoResponse}; +use http::StatusCode; +use service::oauth::grant::TokenRequest; + +pub(crate) async fn token_route( + State(services): State, + Form(request): Form, +) -> impl IntoResponse { + match services.oauth.issue_token(request).await { + | Ok(response) => Ok(Json(response).into_response()), + | Err(err) => Err((StatusCode::BAD_REQUEST, err.message())), + } +} diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 3132613cf..dbd653033 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -202,7 +202,7 @@ pub(crate) async fn login_route( if device_exists { services .users - .set_token(&user_id, &device_id, &token) + .set_token(&user_id, &device_id, &token, None) .await?; } else { services @@ -211,6 +211,7 @@ pub(crate) async fn login_route( &user_id, &device_id, &token, + None, body.initial_device_display_name.clone(), Some(client.to_string()), ) diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index 900619c5f..8b52cbe09 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -69,7 +69,6 @@ pub(crate) async fn sync_events_v5_route( ClientIp(client_ip): ClientIp, body: Ruma, ) -> Result { - debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted"); let sender_user = body.identity.sender_user(); let sender_device = body.identity.expect_sender_device()?; diff --git a/src/api/router.rs b/src/api/router.rs index 2b6cb7a01..463ca98ff 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -187,7 +187,7 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::get_rtc_transports) .ruma_route(&client::room_initial_sync_route) .ruma_route(&client::get_authorization_server_metadata_route) - .nest(client::oauth::BASE_PATH, client::oauth::router()) + .merge(client::oauth::router()) .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_continuwuity/server_version", get(client::conduwuit_server_version)) .ruma_route(&admin::rooms::ban::ban_room) diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 3cbeb6ccb..f60c46382 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -153,7 +153,7 @@ impl CheckAuth for AccessToken { query: AuthQueryParams, route: TypeId, ) -> Result { - if let Ok((sender_user, sender_device)) = services.users.find_from_token(&output).await { + if let Some((sender_user, sender_device)) = services.users.find_from_token(&output).await { // Locked users can only use /logout and /logout/all if services .users diff --git a/src/core/utils/time.rs b/src/core/utils/time.rs index b8fad41d8..e606ca188 100644 --- a/src/core/utils/time.rs +++ b/src/core/utils/time.rs @@ -61,17 +61,23 @@ pub fn format(ts: SystemTime, str: &str) -> String { pub fn pretty(d: Duration) -> String { use Unit::*; - let fmt = |w, f, u| format!("{w}.{f} {u}"); - let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u); - let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u); + let fmt = |w, u| { + if w == 1 { + format!("{w} {u}") + } else { + format!("{w} {u}s") + } + }; + let gen64 = |w, u| fmt(w, u); + let gen128 = |w, u| gen64(u64::try_from(w).expect("u128 to u64"), u); match whole_and_frac(d) { - | (Days(whole), frac) => gen64(whole, frac, "days"), - | (Hours(whole), frac) => gen64(whole, frac, "hours"), - | (Mins(whole), frac) => gen64(whole, frac, "minutes"), - | (Secs(whole), frac) => gen64(whole, frac, "seconds"), - | (Millis(whole), frac) => gen128(whole, frac, "milliseconds"), - | (Micros(whole), frac) => gen128(whole, frac, "microseconds"), - | (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"), + | (Days(whole), _) => gen64(whole, "day"), + | (Hours(whole), _) => gen64(whole, "hour"), + | (Mins(whole), _) => gen64(whole, "minute"), + | (Secs(whole), _) => gen64(whole, "second"), + | (Millis(whole), _) => gen128(whole, "millisecond"), + | (Micros(whole), _) => gen128(whole, "microsecond"), + | (Nanos(whole), _) => gen128(whole, "nanosecond"), } } diff --git a/src/database/maps.rs b/src/database/maps.rs index ecdf4a8f7..0a9ef8b75 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -161,6 +161,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "referencedevents", ..descriptor::RANDOM }, + Descriptor { + name: "refreshtoken_refreshtokeninfo", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "registrationtoken_info", ..descriptor::RANDOM_SMALL @@ -375,6 +379,14 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "userdevicetxnid_response", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "userdeviceid_oauthsessioninfo", + ..descriptor::RANDOM_SMALL + }, + Descriptor { + name: "userdeviceid_tokenexpires", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "userfilterid_filter", ..descriptor::RANDOM_SMALL diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index a0568db0b..e302ff6c9 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -119,6 +119,7 @@ recaptcha-verify = { version = "0.2.0", default-features = false } reqwest_recaptcha = { package = "reqwest", version = "0.12.28", default-features = false, features = ["rustls-tls-native-roots-no-provider"] } # As long as recaptcha-verify's reqwest is outdated yansi.workspace = true lettre.workspace = true +serde_urlencoded.workspace = true [target.'cfg(all(unix, target_os = "linux"))'.dependencies] sd-notify.workspace = true diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index d492f3188..0d79d4505 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -67,7 +67,7 @@ impl crate::Service for Service { for (id, registration) in appservices { // During startup, resolve any token collisions in favour of appservices // by logging out conflicting user devices - if let Ok((user_id, device_id)) = self + if let Some((user_id, device_id)) = self .services .users .find_from_token(®istration.as_token) @@ -158,7 +158,7 @@ impl Service { .users .find_from_token(®istration.as_token) .await - .is_ok() + .is_some() { return Err(err!(Request(InvalidParam( "Cannot register appservice: The provided token is already in use by a user \ diff --git a/src/service/oauth/client_metadata.rs b/src/service/oauth/client_metadata.rs index 879ec1fbc..99d6435ce 100644 --- a/src/service/oauth/client_metadata.rs +++ b/src/service/oauth/client_metadata.rs @@ -38,7 +38,7 @@ pub struct ClientMetadata { } impl ClientMetadata { - const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"]; + pub(super) const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"]; pub(super) fn validate(&self) -> Result<(), &'static str> { let Some(client_domain) = self.client_uri.domain() else { @@ -137,6 +137,7 @@ pub enum GrantType { #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] +#[non_exhaustive] pub enum ResponseType { Code, } diff --git a/src/service/oauth/grant.rs b/src/service/oauth/grant.rs new file mode 100644 index 000000000..261910501 --- /dev/null +++ b/src/service/oauth/grant.rs @@ -0,0 +1,150 @@ +use std::{ + collections::{BTreeSet, HashSet}, + fmt::Debug, + hash::Hash, + mem::discriminant, +}; + +use regex::Regex; +use ruma::OwnedDeviceId; +use serde::{Deserialize, Serialize}; +use url::Url; + +use super::client_metadata::ResponseType; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AuthorizationCodeQuery { + pub response_type: ResponseType, + pub client_id: String, + pub redirect_uri: Url, + pub scope: RawScopes, + pub state: String, + #[serde(default)] + pub response_mode: ResponseMode, + pub code_challenge: String, + pub code_challenge_method: CodeChallengeMethod, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum ResponseMode { + #[default] + // default for `code` response type, see https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#:~:text=Client%2E-,For,encoding%2E,-See + Query, + Fragment, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[non_exhaustive] +pub enum CodeChallengeMethod { + S256, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialOrd, Ord)] +pub enum Scope { + Device(OwnedDeviceId), + ClientApi, +} + +impl PartialEq for Scope { + fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) } +} + +impl Eq for Scope {} + +impl Hash for Scope { + fn hash(&self, state: &mut H) { discriminant(self).hash(state); } +} + +impl std::fmt::Display for Scope { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let urn = match self { + | Self::ClientApi => "urn:matrix:client:api:*".to_owned(), + | Self::Device(device_id) => format!("urn:matrix:client:device:{device_id}"), + }; + + f.write_str(&urn) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct RawScopes(String); + +impl RawScopes { + pub fn to_scopes(&self) -> Result, String> { + let client_api_token_regex = + Regex::new(r"urn:matrix:(client|org.matrix.msc2967.client):api:\*").unwrap(); + let device_token_regex = Regex::new( + r"urn:matrix:(client|org.matrix.msc2967.client):device:([a-zA-Z0-9-._~]{5,})", + ) + .unwrap(); + + let mut scopes = BTreeSet::new(); + + for token in self.0.split(' ') { + let scope_was_new = { + if client_api_token_regex.is_match(token) { + scopes.insert(Scope::ClientApi) + } else if let Some(captures) = device_token_regex.captures(token) { + scopes.insert(Scope::Device(captures.get(2).unwrap().as_str().into())) + } else if token == "openid" { + // TODO(unspecced): Element sets this scope but doesn't use it for anything + true + } else { + return Err(format!("Invalid scope: {token}")); + } + }; + + if !scope_was_new { + return Err("Scope was specified more than once".to_owned()); + } + } + + Ok(scopes) + } +} + +#[derive(Serialize)] +pub struct AuthorizationCodeResponse { + pub state: String, + pub code: String, +} + +#[derive(Deserialize)] +#[serde(tag = "grant_type", rename_all = "snake_case")] +pub enum TokenRequest { + AuthorizationCode { + code: String, + redirect_uri: Url, + client_id: String, + code_verifier: String, + }, + RefreshToken { + client_id: String, + refresh_token: String, + }, +} + +impl TokenRequest { + pub fn client_id(&self) -> &str { + match self { + | Self::AuthorizationCode { client_id, .. } + | Self::RefreshToken { client_id, .. } => client_id, + } + } +} + +#[derive(Serialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: TokenType, + pub expires_in: u64, + pub refresh_token: String, + pub scope: String, +} + +#[derive(Serialize)] +pub enum TokenType { + Bearer, +} diff --git a/src/service/oauth/mod.rs b/src/service/oauth/mod.rs index 9d85623ec..33c9dd40c 100644 --- a/src/service/oauth/mod.rs +++ b/src/service/oauth/mod.rs @@ -1,30 +1,93 @@ use std::{ - collections::HashMap, + collections::{BTreeMap, BTreeSet, HashMap}, sync::{Arc, Mutex}, time::{Duration, SystemTime}, }; use base64::Engine; -use conduwuit::{Result, utils::hash::sha256}; +use conduwuit::{ + Err, Result, err, info, + utils::{self, hash::sha256}, +}; use database::{Deserialized, Json, Map}; -use ruma::DeviceId; +use itertools::Itertools; +use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId}; +use serde::{Deserialize, Serialize}; +use url::Url; -use crate::{Dep, config, oauth::client_metadata::ClientMetadata}; +use crate::{ + Dep, config, + oauth::{ + client_metadata::{ApplicationType, ClientMetadata, ResponseType}, + grant::{ + AuthorizationCodeQuery, AuthorizationCodeResponse, CodeChallengeMethod, ResponseMode, + Scope, TokenRequest, TokenResponse, TokenType, + }, + }, + users, +}; pub mod client_metadata; +pub mod grant; pub struct Service { services: Services, db: Data, tickets: Mutex>>, + pending_code_grants: tokio::sync::Mutex>, } struct Data { clientid_clientmetadata: Arc, + userdeviceid_oauthsessioninfo: Arc, + refreshtoken_refreshtokeninfo: Arc, } struct Services { config: Dep, + users: Dep, +} + +#[derive(Deserialize, Serialize)] +struct SessionInfo { + client_id: String, + current_refresh_token: String, + scopes: BTreeSet, +} + +#[derive(Deserialize, Serialize)] +struct RefreshTokenInfo { + client_id: String, + user_id: OwnedUserId, + device_id: OwnedDeviceId, +} + +struct PendingCodeGrant { + authorizing_user: OwnedUserId, + requested_scopes: BTreeSet, + client_name: Option, + expected_client_id: String, + expected_redirect_uri: Url, + code_challenge: String, + requested_at: SystemTime, +} + +impl PendingCodeGrant { + const MAX_AGE: Duration = Duration::from_mins(1); + const RANDOM_CODE_LENGTH: usize = 32; + + #[must_use] + pub fn generate_code() -> String { utils::random_string(Self::RANDOM_CODE_LENGTH) } + + #[must_use] + pub fn is_valid_for(&self, client_id: &str) -> bool { + let now = SystemTime::now(); + + self.expected_client_id == client_id + && now + .duration_since(self.requested_at) + .is_ok_and(|age| age < Self::MAX_AGE) + } } /// A time-limited grant for a client to perform some sensitive action. @@ -49,11 +112,15 @@ impl crate::Service for Service { Ok(Arc::new(Self { services: Services { config: args.depend::("config"), + users: args.depend::("users"), }, db: Data { clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(), + userdeviceid_oauthsessioninfo: args.db["userdeviceid_oauthsessioninfo"].clone(), + refreshtoken_refreshtokeninfo: args.db["refreshtoken_refreshtokeninfo"].clone(), }, tickets: Mutex::default(), + pending_code_grants: tokio::sync::Mutex::default(), })) } @@ -61,6 +128,11 @@ impl crate::Service for Service { } impl Service { + const ACCESS_TOKEN_MAX_AGE: Duration = Duration::from_hours(1); + const RANDOM_TOKEN_LENGTH: usize = 32; + + fn generate_token() -> String { utils::random_string(Self::RANDOM_TOKEN_LENGTH) } + pub async fn register_client( &self, metadata: &ClientMetadata, @@ -85,7 +157,7 @@ impl Service { Ok(client_id) } - pub async fn get_client_registration(&self, client_id: &str) -> Option { + pub async fn get_client_metadata(&self, client_id: &str) -> Option { self.db .clientid_clientmetadata .get(client_id) @@ -94,15 +166,304 @@ impl Service { .ok() } - pub async fn get_client_id_for_device(&self, _device_id: &DeviceId) -> Option { - None // TODO + pub async fn get_client_id_for_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Option { + self.db + .userdeviceid_oauthsessioninfo + .qry(&(user_id, device_id)) + .await + .deserialized::() + .ok() + .map(|session| session.client_id) + } + + pub async fn request_authorization_code( + &self, + authorizing_user: OwnedUserId, + query: AuthorizationCodeQuery, + ) -> Result { + let Some(client_metadata) = self.get_client_metadata(&query.client_id).await else { + return Err("Invalid client ID".to_owned()); + }; + + if !(client_metadata + .response_types + .contains(&query.response_type) + && matches!(query.response_type, ResponseType::Code)) + { + return Err("Invalid response type".to_owned()); + } + + if !matches!(query.code_challenge_method, CodeChallengeMethod::S256) { + return Err("Invalid code challenge type".to_owned()); + } + + { + let mut stripped_uri = query.redirect_uri.clone(); + + if client_metadata.application_type == ApplicationType::Native + && query + .redirect_uri + .host_str() + .is_some_and(|host| ClientMetadata::ACCEPTABLE_LOCALHOSTS.contains(&host)) + { + // Remove the port from localhost redirect URIs for native applications when + // checking if it's valid + stripped_uri.set_port(None).unwrap(); + } + + if !client_metadata.redirect_uris.contains(&stripped_uri) { + return Err("Invalid redirect URI".to_owned()); + } + } + + let requested_scopes = query.scope.to_scopes()?; + + let redirect_uri_query_separator = match query.response_mode { + | ResponseMode::Fragment => '#', + | ResponseMode::Query => '?', + }; + + let code = PendingCodeGrant::generate_code(); + + info!( + client_id = &query.client_id, + client_name = &client_metadata.client_name, + ?requested_scopes, + ?authorizing_user, + "Issuing oauth authorization code" + ); + + let redirect_uri = format!( + "{}{}{}", + query.redirect_uri, + redirect_uri_query_separator, + serde_urlencoded::to_string(AuthorizationCodeResponse { + state: query.state, + code: code.clone(), + }) + .unwrap(), + ); + + let pending_grant = PendingCodeGrant { + authorizing_user, + requested_scopes, + client_name: client_metadata.client_name, + expected_client_id: query.client_id, + expected_redirect_uri: query.redirect_uri, + code_challenge: query.code_challenge, + requested_at: SystemTime::now(), + }; + + self.pending_code_grants + .lock() + .await + .insert(code, pending_grant); + + Ok(redirect_uri) + } + + pub async fn issue_token(&self, request: TokenRequest) -> Result { + match request { + | TokenRequest::AuthorizationCode { + code, + redirect_uri, + client_id, + code_verifier, + } => { + let mut pending_grants = self.pending_code_grants.lock().await; + + let Some(pending_grant) = pending_grants + .remove(&code) + .filter(|grant| grant.is_valid_for(&client_id)) + else { + return Err!("Invalid code"); + }; + + if redirect_uri != pending_grant.expected_redirect_uri { + return Err!("Unexpected redirect uri"); + } + + let expected_code_challenge = + base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sha256::hash(&code_verifier)); + if expected_code_challenge != pending_grant.code_challenge { + return Err!("Invalid code challenge"); + } + + self.create_session( + pending_grant.authorizing_user, + pending_grant.requested_scopes, + pending_grant.client_name, + client_id, + ) + .await + }, + | TokenRequest::RefreshToken { client_id, refresh_token } => + self.refresh_session(client_id, refresh_token).await, + } + } + + async fn create_session( + &self, + authorizing_user: OwnedUserId, + requested_scopes: BTreeSet, + client_name: Option, + client_id: String, + ) -> Result { + let access_token = Self::generate_token(); + let refresh_token = Self::generate_token(); + + let device_id = requested_scopes + .iter() + .find_map(|scope| { + if let Scope::Device(device_id) = scope { + Some(device_id) + } else { + None + } + }) + .ok_or_else(|| err!("No device ID scope supplied"))?; + + self.services + .users + .create_device( + &authorizing_user, + device_id, + &access_token, + Some(Self::ACCESS_TOKEN_MAX_AGE), + client_name, + None, + ) + .await?; + + self.db.userdeviceid_oauthsessioninfo.put( + (&authorizing_user, device_id), + Json(SessionInfo { + client_id: client_id.clone(), + current_refresh_token: refresh_token.clone(), + scopes: requested_scopes.clone(), + }), + ); + + self.db.refreshtoken_refreshtokeninfo.raw_put( + &refresh_token, + Json(RefreshTokenInfo { + client_id: client_id.clone(), + user_id: authorizing_user.clone(), + device_id: device_id.to_owned(), + }), + ); + + info!( + ?client_id, + ?authorizing_user, + ?device_id, + ?requested_scopes, + "Created new oauth session" + ); + + Ok(TokenResponse { + access_token, + token_type: TokenType::Bearer, + expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(), + scope: requested_scopes.iter().join(" "), + refresh_token, + }) + } + + async fn refresh_session( + &self, + client_id: String, + refresh_token: String, + ) -> Result { + let Some(refresh_token_info) = self + .db + .refreshtoken_refreshtokeninfo + .get(&refresh_token) + .await + .deserialized::() + .ok() + else { + return Err!("Invalid refresh token"); + }; + + assert_eq!(&client_id, &refresh_token_info.client_id, "refresh token client id mismatch"); + + let mut session_info = self + .db + .userdeviceid_oauthsessioninfo + .qry(&(&refresh_token_info.user_id, &refresh_token_info.device_id)) + .await + .deserialized::() + .expect("session info should exist"); + + assert_eq!(&client_id, &session_info.client_id, "session info client id mismatch"); + + let new_access_token = Self::generate_token(); + let new_refresh_token = Self::generate_token(); + let scope = session_info.scopes.iter().join(" "); + session_info + .current_refresh_token + .clone_from(&new_refresh_token); + + self.services + .users + .set_token( + &refresh_token_info.user_id, + &refresh_token_info.device_id, + &new_access_token, + Some(Self::ACCESS_TOKEN_MAX_AGE), + ) + .await?; + + self.db.userdeviceid_oauthsessioninfo.put( + (&refresh_token_info.user_id, &refresh_token_info.device_id), + Json(session_info), + ); + + self.db.refreshtoken_refreshtokeninfo.remove(&refresh_token); + drop(refresh_token); + self.db + .refreshtoken_refreshtokeninfo + .raw_put(&new_refresh_token, Json(refresh_token_info)); + + Ok(TokenResponse { + access_token: new_access_token, + token_type: TokenType::Bearer, + expires_in: Self::ACCESS_TOKEN_MAX_AGE.as_secs(), + scope, + refresh_token: new_refresh_token, + }) + } + + pub async fn remove_session(&self, user_id: &UserId, device_id: &DeviceId) { + let session_info = self + .db + .userdeviceid_oauthsessioninfo + .qry(&(user_id, device_id)) + .await + .deserialized::() + .ok(); + + if let Some(session_info) = session_info { + self.db + .refreshtoken_refreshtokeninfo + .remove(&session_info.current_refresh_token); + self.db + .userdeviceid_oauthsessioninfo + .del(&(user_id, device_id)); + info!(?user_id, ?device_id, "Removed OAuth session"); + } } /// Issue a ticket for `localpart` to perform some action. pub fn issue_ticket(&self, localpart: String, ticket: OAuthTicket) { self.tickets .lock() - .expect("should be able to lock tickets") + .unwrap() .entry(localpart) .or_default() .insert(ticket, SystemTime::now()); @@ -114,7 +475,7 @@ impl Service { self.tickets .lock() - .expect("should be able to lock tickets") + .unwrap() .get_mut(localpart) .and_then(|tickets| tickets.remove(&ticket)) .is_some_and(|issued| { diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index cb7ced129..081cd58f7 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -253,17 +253,17 @@ impl Service { let mut info = assign!(UiaaInfo::new(flows), { params: Some(params), session: Some(session_id.clone()) }); let session_metadata = if let Some(initiator) = initiator { - let is_oauth = OptionFuture::from( - initiator.device_id.map(async |device_id| { - self - .services - .oauth - .get_client_id_for_device(device_id) - .await - }) - ) - .await - .is_some(); + let is_oauth = if let Some(device_id) = initiator.device_id { + self + .services + .oauth + .get_client_id_for_device(initiator.user_id, device_id) + .await + .is_some() + } else { + // Appservices never have oauth sessions + false + }; if is_oauth { if let Some(oauth_ticket) = initiator.oauth_ticket { @@ -279,7 +279,18 @@ impl Service { .unwrap(); info.flows = vec![AuthFlow::new(vec![AuthType::OAuth])]; - info.params = Some(to_raw_value(&json!({"url": ticket_url})).unwrap()); + info.params = Some( + to_raw_value(&json!({ + AuthType::OAuth.as_str(): { + "url": ticket_url, + }, + // TODO(compat): This is necessary for older versions of matrix-rust-sdk + "org.matrix.cross_signing_reset": { + "url": ticket_url, + } + })) + .unwrap(), + ); UiaaSessionMetadata::OAuth { localpart: initiator.user_id.localpart().to_owned(), diff --git a/src/service/users/dehydrated_device.rs b/src/service/users/dehydrated_device.rs index cacce5f6b..d54e0f368 100644 --- a/src/service/users/dehydrated_device.rs +++ b/src/service/users/dehydrated_device.rs @@ -54,6 +54,7 @@ pub async fn set_dehydrated_device(&self, user_id: &UserId, request: Request) -> user_id, &request.device_id, "", + None, request.initial_device_display_name.clone(), None, ) diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 1948fbf08..609c7e02c 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,6 +1,12 @@ pub(super) mod dehydrated_device; -use std::{collections::BTreeMap, mem, net::IpAddr, sync::Arc}; +use std::{ + collections::BTreeMap, + mem, + net::IpAddr, + sync::Arc, + time::{Duration, SystemTime}, +}; use conduwuit::{ Err, Error, Result, Server, debug_error, debug_warn, err, trace, @@ -26,7 +32,7 @@ use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigE use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::{Dep, account_data, admin, appservice, globals, rooms}; +use crate::{Dep, account_data, admin, appservice, globals, oauth, rooms}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UserSuspension { @@ -62,6 +68,7 @@ struct Services { admin: Dep, appservice: Dep, globals: Dep, + oauth: Dep, state_accessor: Dep, state_cache: Dep, } @@ -75,6 +82,7 @@ struct Data { logintoken_expiresatuserid: Arc, todeviceid_events: Arc, token_userdeviceid: Arc, + userdeviceid_tokenexpires: Arc, userdeviceid_metadata: Arc, userdeviceid_token: Arc, userfilterid_filter: Arc, @@ -102,6 +110,7 @@ impl crate::Service for Service { admin: args.depend::("admin"), appservice: args.depend::("appservice"), globals: args.depend::("globals"), + oauth: args.depend::("oauth"), state_accessor: args .depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), @@ -131,6 +140,7 @@ impl crate::Service for Service { userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(), userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(), useridprofilekey_value: args.db["useridprofilekey_value"].clone(), + userdeviceid_tokenexpires: args.db["userdeviceid_tokenexpires"].clone(), }, })) } @@ -337,8 +347,37 @@ impl Service { pub async fn count(&self) -> usize { self.db.userid_password.count().await } /// Find out which user an access token belongs to. - pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> { - self.db.token_userdeviceid.get(token).await.deserialized() + pub async fn find_from_token(&self, token: &str) -> Option<(OwnedUserId, OwnedDeviceId)> { + let user = self + .db + .token_userdeviceid + .get(token) + .await + .deserialized() + .ok(); + + // Check if the token has expired + if let Some(user) = &user { + if let Some(expires) = self + .db + .userdeviceid_tokenexpires + .qry(user) + .await + .deserialized::() + .ok() + .map(Duration::from_secs) + { + let expires_at = SystemTime::UNIX_EPOCH + .checked_add(expires) + .expect("expiry time should not overflow SystemTime"); + + if SystemTime::now() > expires_at { + return None; + } + } + } + + user } /// Returns an iterator over all users on this homeserver. @@ -434,6 +473,7 @@ impl Service { user_id: &UserId, device_id: &DeviceId, token: &str, + token_max_age: Option, initial_device_display_name: Option, client_ip: Option, ) -> Result<()> { @@ -451,7 +491,8 @@ impl Service { increment(&self.db.userid_devicelistversion, user_id.as_bytes()); self.db.userdeviceid_metadata.put(key, Json(device)); - self.set_token(user_id, device_id, token).await + self.set_token(user_id, device_id, token, token_max_age) + .await } /// Removes a device from a user. @@ -467,6 +508,7 @@ impl Service { if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await { self.db.userdeviceid_token.del(userdeviceid); self.db.token_userdeviceid.remove(&old_token); + self.db.userdeviceid_tokenexpires.del(userdeviceid); } // Remove todevice events @@ -480,6 +522,9 @@ impl Service { // TODO: Remove onetimekeys + // Remove OAuth session information + self.services.oauth.remove_session(user_id, device_id).await; + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); self.db.userdeviceid_metadata.del(userdeviceid); @@ -535,6 +580,7 @@ impl Service { user_id: &UserId, device_id: &DeviceId, token: &str, + token_max_age: Option, ) -> Result<()> { let key = (user_id, device_id); if self.db.userdeviceid_metadata.qry(&key).await.is_err() { @@ -561,6 +607,7 @@ impl Service { // Remove old token if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await { self.db.token_userdeviceid.remove(&old_token); + self.db.userdeviceid_tokenexpires.remove(&old_token); // It will be removed from userdeviceid_token by the insert later } @@ -568,6 +615,18 @@ impl Service { self.db.userdeviceid_token.put_raw(key, token); self.db.token_userdeviceid.raw_put(token, key); + if let Some(max_age) = token_max_age { + let expires = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("system time should not be before the epoch") + .saturating_add(max_age) + .as_secs(); + + self.db.userdeviceid_tokenexpires.put(key, expires); + } else { + self.db.userdeviceid_tokenexpires.del(key); + } + Ok(()) } diff --git a/src/web/Cargo.toml b/src/web/Cargo.toml index 04396e959..767ec2c06 100644 --- a/src/web/Cargo.toml +++ b/src/web/Cargo.toml @@ -43,7 +43,8 @@ validator = { version = "0.20.0", features = ["derive"] } tower-sec-fetch = { version = "0.1.2", features = ["tracing"] } tower-sessions = { version = "0.15.0", default-features = false, features = ["axum-core"] } tower-sessions-core = { version = "0.15.0", features = ["deletion-task"] } -serde_urlencoded = "0.7.1" +serde_urlencoded.workspace = true +url.workspace = true [build-dependencies] memory-serve = "2.1.0" diff --git a/src/web/mod.rs b/src/web/mod.rs index f24b2be5e..ebe4b7cb6 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -127,6 +127,7 @@ pub fn build(services: &Services) -> Router { Router::new() .nest("/account/", account::build()) .merge(debug::build()) + .nest("/oauth2/", oauth::build()) .merge(resources::build()) .merge(threepid::build()) .fallback(async || WebError::NotFound), @@ -145,7 +146,7 @@ pub fn build(services: &Services) -> Router { })) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_static("default-src 'self'; img-src 'self' data:;"), + HeaderValue::from_static("default-src 'self'; img-src 'self' https: data:;"), )) .layer(SecFetchLayer::new(|policy| { policy.allow_safe_methods().reject_missing_metadata(); diff --git a/src/web/pages/account/login.rs b/src/web/pages/account/login.rs index d3772aa7d..918252504 100644 --- a/src/web/pages/account/login.rs +++ b/src/web/pages/account/login.rs @@ -2,7 +2,7 @@ use std::time::SystemTime; use axum::{ Router, - extract::{Query, State}, + extract::{Query, RawQuery, State}, response::{IntoResponse, Redirect}, routing::{get, on}, }; @@ -15,11 +15,11 @@ use serde::Deserialize; use tower_sessions::Session; use crate::{ - WebError, + ROUTE_PREFIX, WebError, extract::{Expect, PostForm}, pages::{GET_POST, Result, components::UserCard}, response, - session::{LoginQuery, User, UserSession}, + session::{LoginQuery, LoginTarget, User, UserSession}, template, }; @@ -32,6 +32,7 @@ pub(crate) fn build() -> Router { template! { struct Login use "login.html.j2" { body: LoginBody, + has_next: bool, login_error: Option } } @@ -54,11 +55,12 @@ struct LoginForm { async fn route_login( State(services): State, - Expect(Query(query)): Expect>, + Expect(Query(LoginQuery { next, reauthenticate })): Expect>, session_store: Session, user: User, PostForm(form): PostForm, ) -> Result { + let next = next.unwrap_or_default(); let user_id = user.into_session().map(|session| session.user_id); let body = match &user_id { @@ -66,8 +68,8 @@ async fn route_login( server_name: services.globals.server_name().to_string(), }, | Some(user_id) => { - if !query.reauthenticate { - return response!(Redirect::to(&query.next.target_path())); + if !reauthenticate { + return response!(Redirect::to(&next.target_path())); } let user_card = UserCard::for_local_user(&services, user_id.to_owned()).await; @@ -76,7 +78,7 @@ async fn route_login( }, }; - let mut template = Login::new(&services, body, None); + let mut template = Login::new(&services, body, next != LoginTarget::Account, None); if let Some(form) = form { let login_result = match (user_id, form.identifier) { @@ -86,8 +88,6 @@ async fn route_login( }, | (None, Some(identifier)) => { // The user isn't authenticated, we need to log them in - // Yes, this does parse the email twice (handle_login does it again). I don't - // think this really needs to be optimized. let identifier = if identifier.parse::().is_ok() { UserIdentifier::Email(EmailUserIdentifier::new(identifier)) } else { @@ -123,14 +123,14 @@ async fn route_login( .await .expect("should be able to serialize user session"); - return response!(Redirect::to(&query.next.target_path())); + return response!(Redirect::to(&next.target_path())); } response!(template) } -async fn get_logout(session: Session) -> impl IntoResponse { +async fn get_logout(session: Session, RawQuery(query): RawQuery) -> impl IntoResponse { let _ = session.remove::(User::KEY).await; - Redirect::to("/_continuwuity/account/") + Redirect::to(&format!("{}/account/login?{}", ROUTE_PREFIX, query.unwrap_or_default())) } diff --git a/src/web/pages/account/mod.rs b/src/web/pages/account/mod.rs index b670e83f4..e3c3ab079 100644 --- a/src/web/pages/account/mod.rs +++ b/src/web/pages/account/mod.rs @@ -61,12 +61,14 @@ async fn get_account( let user_card = UserCard::for_local_user(&services, user_id.clone()).await; - let devices = services + let mut devices: Vec<_> = services .users .all_device_ids(&user_id) .then(async |device_id| DeviceCard::for_device(&services, &user_id, device_id).await) .collect() .await; + devices.sort_unstable_by(|a, b| a.last_seen_ts.cmp(&b.last_seen_ts).reverse()); + response!(Account::new(&services, user_card, email_requirement, email, devices)) } diff --git a/src/web/pages/components/mod.rs b/src/web/pages/components/mod.rs index 0c72c31f7..ff336aa0c 100644 --- a/src/web/pages/components/mod.rs +++ b/src/web/pages/components/mod.rs @@ -4,36 +4,26 @@ use askama::{Template, filters::HtmlSafe}; use base64::Engine; use conduwuit_core::{result::FlatOk, utils}; use conduwuit_service::{Services, media::mxc::Mxc, oauth::client_metadata::ClientMetadata}; -use ruma::{OwnedDeviceId, OwnedUserId, UserId}; +use ruma::{MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId}; pub(super) mod form; #[derive(Debug)] -pub(super) enum AvatarType<'a> { +pub(super) enum AvatarType { Initial(char), - Image(&'a str), + Image(String), } #[derive(Debug, Template)] #[template(path = "_components/avatar.html.j2")] -pub(super) struct Avatar<'a> { - pub(super) avatar_type: AvatarType<'a>, +pub(super) struct Avatar { + pub(super) avatar_type: AvatarType, } -impl HtmlSafe for Avatar<'_> {} +impl HtmlSafe for Avatar {} -#[derive(Debug, Template)] -#[template(path = "_components/user_card.html.j2")] -pub(super) struct UserCard { - pub user_id: OwnedUserId, - pub display_name: Option, - pub avatar_src: Option, -} - -impl HtmlSafe for UserCard {} - -impl UserCard { - pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self { +impl Avatar { + pub(super) async fn for_local_user(services: &Services, user_id: &UserId) -> Self { let display_name = services.users.displayname(&user_id).await.ok(); let avatar_src = async { @@ -56,33 +46,48 @@ impl UserCard { } .await; - Self { user_id, display_name, avatar_src } - } - - fn avatar(&self) -> Avatar<'_> { - let avatar_type = if let Some(ref avatar_src) = self.avatar_src { + let avatar_type = if let Some(avatar_src) = avatar_src { AvatarType::Image(avatar_src) - } else if let Some(initial) = self - .display_name + } else if let Some(initial) = display_name .as_ref() .and_then(|display_name| display_name.chars().next()) { AvatarType::Initial(initial) } else { - AvatarType::Initial(self.user_id.localpart().chars().next().unwrap()) + AvatarType::Initial(user_id.localpart().chars().next().unwrap()) }; Avatar { avatar_type } } } +#[derive(Debug, Template)] +#[template(path = "_components/user_card.html.j2")] +pub(super) struct UserCard { + pub user_id: OwnedUserId, + pub display_name: Option, + pub avatar: Avatar, +} + +impl HtmlSafe for UserCard {} + +impl UserCard { + pub(super) async fn for_local_user(services: &Services, user_id: OwnedUserId) -> Self { + let display_name = services.users.displayname(&user_id).await.ok(); + let avatar = Avatar::for_local_user(services, &user_id).await; + + Self { user_id, display_name, avatar } + } +} + #[derive(Debug, Template)] #[template(path = "_components/device_card.html.j2")] pub(super) struct DeviceCard { pub device_id: OwnedDeviceId, pub display_name: Option, - pub avatar_src: Option, + pub avatar: Avatar, pub last_active: String, + pub last_seen_ts: Option, pub oauth_metadata: Option, } @@ -101,12 +106,15 @@ impl DeviceCard { .ok(); let oauth_metadata = async { - let client_id = services.oauth.get_client_id_for_device(&device_id).await?; + let client_id = services + .oauth + .get_client_id_for_device(user_id, &device_id) + .await?; Some( services .oauth - .get_client_registration(&client_id) + .get_client_metadata(&client_id) .await .expect("client should exist"), ) @@ -122,53 +130,51 @@ impl DeviceCard { .and_then(|device| device.display_name.clone()) }); - let avatar_src = oauth_metadata - .as_ref() - .and_then(|metadata| metadata.logo_uri.as_ref()) - .map(|uri| uri.as_str().to_owned()); + let avatar = { + let avatar_src = oauth_metadata + .as_ref() + .and_then(|metadata| metadata.logo_uri.as_ref()) + .map(|uri| uri.as_str().to_owned()); - let last_active = device - .as_ref() - .and_then(|device| device.last_seen_ts) - .map_or_else( - || "unknown".to_owned(), - |active| { - active - .to_system_time() - .and_then(|t| SystemTime::now().duration_since(t).ok()) - .map_or_else( - || "now".to_owned(), - |duration| format!("{} ago", utils::time::pretty(duration)), - ) - }, - ); + let avatar_type = if let Some(avatar_src) = avatar_src { + AvatarType::Image(avatar_src) + } else if let Some(initial) = + display_name.as_ref().and_then(|name| name.chars().next()) + { + if oauth_metadata.is_some() { + AvatarType::Initial(initial) + } else { + AvatarType::Initial('❖') + } + } else { + AvatarType::Initial('?') + }; + + Avatar { avatar_type } + }; + + let last_seen_ts = device.as_ref().and_then(|device| device.last_seen_ts); + + let last_active = last_seen_ts.map_or_else( + || "unknown".to_owned(), + |last_seen_ts| { + last_seen_ts + .to_system_time() + .and_then(|t| SystemTime::now().duration_since(t).ok()) + .map_or_else( + || "now".to_owned(), + |duration| format!("{} ago", utils::time::pretty(duration)), + ) + }, + ); Self { device_id, display_name, - avatar_src, + avatar, last_active, + last_seen_ts: last_seen_ts.map(|last_seen_ts| last_seen_ts.as_secs().into()), oauth_metadata, } } - - fn avatar(&self) -> Avatar<'_> { - let avatar_type = if let Some(avatar_src) = &self.avatar_src { - AvatarType::Image(avatar_src.as_str()) - } else if let Some(initial) = self - .display_name - .as_ref() - .and_then(|name| name.chars().next()) - { - if self.oauth_metadata.is_some() { - AvatarType::Initial(initial) - } else { - AvatarType::Initial('❖') - } - } else { - AvatarType::Initial('?') - }; - - Avatar { avatar_type } - } } diff --git a/src/web/pages/mod.rs b/src/web/pages/mod.rs index fe154b8ab..39bed84be 100644 --- a/src/web/pages/mod.rs +++ b/src/web/pages/mod.rs @@ -6,6 +6,7 @@ pub(super) mod account; mod components; pub(super) mod debug; pub(super) mod index; +pub(super) mod oauth; pub(super) mod resources; pub(super) mod threepid; diff --git a/src/web/pages/oauth/grant.rs b/src/web/pages/oauth/grant.rs new file mode 100644 index 000000000..c45ea5435 --- /dev/null +++ b/src/web/pages/oauth/grant.rs @@ -0,0 +1,113 @@ +use std::collections::BTreeSet; + +use axum::{ + Router, + extract::{Query, State}, + response::{IntoResponse, Redirect}, + routing::on, +}; +use conduwuit_service::{ + oauth::{ + client_metadata::{self, ClientMetadata}, + grant::{AuthorizationCodeQuery, Scope}, + }, + rooms::user, +}; +use ruma::{OwnedDeviceId, OwnedUserId}; +use serde::Deserialize; +use url::Url; + +use crate::{ + WebError, + extract::{Expect, PostForm}, + pages::{ + GET_POST, Result, + components::{Avatar, AvatarType}, + }, + response, + session::{LoginQuery, LoginTarget, User}, + template, +}; + +pub(crate) fn build() -> Router { + Router::new().route("/authorization_code", on(GET_POST, route_authorization_code)) +} + +template! { + struct Grant use "grant.html.j2" { + logout_query: String, + user_id: OwnedUserId, + user_avatar: Avatar, + client_uri: Url, + client_name: String, + client_avatar: Avatar, + policy_uri: Option, + tos_uri: Option, + scopes: BTreeSet + } +} + +async fn route_authorization_code( + State(services): State, + user: User, + Expect(Query(query)): Expect>, + PostForm(form): PostForm<()>, +) -> Result { + let user_id = user.expect(LoginTarget::AuthorizationCode(query.clone()))?; + + if form.is_some() { + let redirect_uri = services + .oauth + .request_authorization_code(user_id, query) + .await + .map_err(WebError::BadRequest)?; + + return response!(Redirect::to(&redirect_uri)); + } + + let Some(client) = services.oauth.get_client_metadata(&query.client_id).await else { + return Err(WebError::BadRequest("Invalid client ID".to_owned())); + }; + + let scopes = query.scope.to_scopes().map_err(WebError::BadRequest)?; + + let client_name = if let Some(name) = &client.client_name { + name + } else { + "Unknown application" + } + .to_owned(); + + let client_avatar = { + let avatar_type = if let Some(logo) = &client.logo_uri { + AvatarType::Image(logo.to_string()) + } else if let Some(name) = &client.client_name + && let Some(char) = name.chars().next() + { + AvatarType::Initial(char) + } else { + AvatarType::Initial('?') + }; + + Avatar { avatar_type } + }; + + let user_avatar = Avatar::for_local_user(&services, &user_id).await; + + response!(Grant::new( + &services, + serde_urlencoded::to_string(LoginQuery { + next: Some(LoginTarget::AuthorizationCode(query)), + reauthenticate: false, + }) + .unwrap(), + user_id, + user_avatar, + client.client_uri.clone(), + client_name, + client_avatar, + client.policy_uri.clone(), + client.tos_uri.clone(), + scopes, + )) +} diff --git a/src/web/pages/oauth/mod.rs b/src/web/pages/oauth/mod.rs new file mode 100644 index 000000000..cc9b48f4f --- /dev/null +++ b/src/web/pages/oauth/mod.rs @@ -0,0 +1,10 @@ +use axum::Router; + +mod grant; + +pub(crate) fn build() -> Router { + #[allow(clippy::wildcard_imports)] + use self::*; + + Router::new().nest("/grant/", grant::build()) +} diff --git a/src/web/pages/resources/common.css b/src/web/pages/resources/common.css index 51bd59bcf..75c5432d1 100644 --- a/src/web/pages/resources/common.css +++ b/src/web/pages/resources/common.css @@ -123,8 +123,9 @@ small.error { .panel { --preferred-width: 12rem + 40dvw; --maximum-width: 48rem; + --minimum-width: 32rem; - width: min(clamp(24rem, var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem)); + width: min(clamp(var(--minimum-width), var(--preferred-width), var(--maximum-width)), calc(100dvw - 3rem)); border-radius: var(--border-radius-lg); background-color: var(--panel-bg); padding-inline: 1.5rem; @@ -184,6 +185,10 @@ a, a:visited { color: oklch(from var(--c1) var(--name-lightness) c h); } +code { + color: oklch(from var(--secondary) var(--name-lightness) c h); +} + input, button, a.button { display: inline-block; padding: 0.5em; diff --git a/src/web/pages/resources/components.css b/src/web/pages/resources/components.css index 468eeb6d6..a4cf3b649 100644 --- a/src/web/pages/resources/components.css +++ b/src/web/pages/resources/components.css @@ -17,7 +17,7 @@ background-color: var(--avatar-color); } -.green-avatar { +.red-avatar { --avatar-color: var(--c1); } diff --git a/src/web/pages/resources/grant.css b/src/web/pages/resources/grant.css new file mode 100644 index 000000000..6848756d1 --- /dev/null +++ b/src/web/pages/resources/grant.css @@ -0,0 +1,22 @@ +.avatars { + justify-content: center; + display: flex; + flex-direction: row; + + .separator { + align-self: center; + margin-inline: 1em; + color: var(--secondary); + font-size: x-large; + font-weight: bold; + user-select: none; + } +} + +.identity { + margin-block: 1em; + color: var(--secondary); + font-size: small; + font-style: italic; + text-align: center; +} diff --git a/src/web/pages/templates/_components/device_card.html.j2 b/src/web/pages/templates/_components/device_card.html.j2 index bae091039..7de1e8e2f 100644 --- a/src/web/pages/templates/_components/device_card.html.j2 +++ b/src/web/pages/templates/_components/device_card.html.j2 @@ -1,19 +1,25 @@
- {{ avatar() }} + {{ avatar }}

{% if let Some(display_name) = display_name %} - {{ display_name }} + {% if let Some(metadata) = oauth_metadata %} + {{ display_name }} + {% else %} + {{ display_name }} + {% endif %} {% else %} Unknown device {% endif %} -  {{ device_id }} + + • {{ device_id }} + {% if oauth_metadata.is_none() %} + (legacy) + {% endif %} +

Last active: {{ last_active }} - {% if let Some(metadata) = oauth_metadata %} -  • Client information - {% endif %}

diff --git a/src/web/pages/templates/_components/user_card.html.j2 b/src/web/pages/templates/_components/user_card.html.j2 index 1dbde8687..4d65685f9 100644 --- a/src/web/pages/templates/_components/user_card.html.j2 +++ b/src/web/pages/templates/_components/user_card.html.j2 @@ -1,5 +1,5 @@ -
- {{ avatar() }} +
+ {{ avatar }}
{% if let Some(display_name) = display_name %}

{{ display_name }}

diff --git a/src/web/pages/templates/grant.html.j2 b/src/web/pages/templates/grant.html.j2 new file mode 100644 index 000000000..aee111886 --- /dev/null +++ b/src/web/pages/templates/grant.html.j2 @@ -0,0 +1,64 @@ +{% extends "_layout.html.j2" %} + +{%- block head -%} + +{%- endblock -%} + +{%- block title -%} +Authorize client +{%- endblock -%} + +{%- block content -%} +
+

Authorize {{ client_name }}

+
+
+ {{ user_avatar }} +
+
+ ⇄ +
+ {{ client_avatar }} +
+
+ Signed in as {{ user_id }}. Switch accounts +
+

+ {{ client_name }} ({{ client_uri.domain().unwrap() }}) would like + your permission to: +

    + {% for scope in scopes %} + {% match scope %} + {% when Scope::ClientApi %} +
  • Interact with Matrix on your behalf
  • + {% when Scope::Device(_) %} +
  • Connect to your Matrix account
  • + {% endmatch %} + {% endfor %} +
+

+ {% match (&policy_uri, &tos_uri) %} + {% when (Some(policy_uri), Some(tos_uri)) %} +

+ {{ client_name }}'s policies + and terms of service apply. +

+ {% when (Some(policy_uri), None) %} +

+ {{ client_name }}'s policies apply. +

+ {% when (None, Some(tos_uri)) %} +

+ {{ client_name }}'s terms of service apply. +

+ {% when (None, None) %} +

+ Make sure you trust {{ client_name }} with access to your data. +

+ {% endmatch %} + +
+ +
+
+{%- endblock -%} diff --git a/src/web/pages/templates/login.html.j2 b/src/web/pages/templates/login.html.j2 index b3d3e21bd..4242dcbdb 100644 --- a/src/web/pages/templates/login.html.j2 +++ b/src/web/pages/templates/login.html.j2 @@ -13,7 +13,11 @@ Log in {% match body %} {% when LoginBody::Unauthenticated { server_name } %}

- Log in to Matrix + {% if has_next %} + Log in to continue + {% else %} + Log in to Matrix + {% endif %} Matrix logo diff --git a/src/web/session/mod.rs b/src/web/session/mod.rs index 1f8176148..21833fc06 100644 --- a/src/web/session/mod.rs +++ b/src/web/session/mod.rs @@ -1,8 +1,14 @@ -use std::time::{Duration, SystemTime}; +use std::{ + borrow::Cow, + collections::HashMap, + mem::discriminant, + time::{Duration, SystemTime}, +}; use axum::{extract::FromRequestParts, http::request::Parts}; +use conduwuit_service::oauth::grant::AuthorizationCodeQuery; use ruma::{OwnedUserId, UserId}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use tower_sessions::Session; use crate::{ROUTE_PREFIX, WebError}; @@ -12,7 +18,7 @@ pub(crate) mod store; #[derive(Debug, Deserialize, Serialize)] pub(crate) struct LoginQuery { #[serde(flatten)] - pub next: LoginTarget, + pub next: Option, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub reauthenticate: bool, } @@ -20,6 +26,7 @@ pub(crate) struct LoginQuery { #[derive(Debug, Default, Deserialize, Serialize)] #[serde(tag = "next", rename_all = "snake_case")] pub(crate) enum LoginTarget { + AuthorizationCode(AuthorizationCodeQuery), #[default] Account, ChangePassword, @@ -28,14 +35,23 @@ pub(crate) enum LoginTarget { Deactivate, } +impl PartialEq for LoginTarget { + fn eq(&self, other: &Self) -> bool { discriminant(self) == discriminant(other) } +} + impl LoginTarget { pub(crate) fn target_path(&self) -> String { - let path = match self { - | Self::Account => "account/", - | Self::ChangePassword => "account/password/change", - | Self::ChangeEmail => "account/email/change/", - | Self::CrossSigningReset => "account/cross_signing_reset", - | Self::Deactivate => "account/deactivate", + let path: Cow<'_, str> = match self { + | Self::AuthorizationCode(code) => format!( + "oauth2/grant/authorization_code?{}", + serde_urlencoded::to_string(code).unwrap() + ) + .into(), + | Self::Account => "account/".into(), + | Self::ChangePassword => "account/password/change".into(), + | Self::ChangeEmail => "account/email/change/".into(), + | Self::CrossSigningReset => "account/cross_signing_reset".into(), + | Self::Deactivate => "account/deactivate".into(), }; format!("{ROUTE_PREFIX}/{path}") @@ -80,7 +96,10 @@ impl User { if let Some(session) = self.0 { Ok(session.user_id) } else { - Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false })) + Err(WebError::LoginRequired(LoginQuery { + next: Some(or_else), + reauthenticate: false, + })) } } @@ -91,10 +110,16 @@ impl User { if session.is_recent() { Ok(session.user_id) } else { - Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: true })) + Err(WebError::LoginRequired(LoginQuery { + next: Some(or_else), + reauthenticate: true, + })) } } else { - Err(WebError::LoginRequired(LoginQuery { next: or_else, reauthenticate: false })) + Err(WebError::LoginRequired(LoginQuery { + next: Some(or_else), + reauthenticate: false, + })) } } }