From 02948960fa6c93d0a319225e9322238dccbd2a98 Mon Sep 17 00:00:00 2001 From: Ginger Date: Tue, 14 Apr 2026 19:14:27 -0400 Subject: [PATCH] feat: Implement oauth service and client registration --- src/api/client/mod.rs | 2 + src/api/client/oauth/mod.rs | 17 +++ src/api/client/oauth/register_client.rs | 28 ++++ src/api/client/oauth/server_metadata.rs | 34 +++++ src/api/router.rs | 2 + src/database/maps.rs | 4 + src/service/mod.rs | 1 + src/service/oauth/client_metadata.rs | 194 ++++++++++++++++++++++++ src/service/oauth/mod.rs | 72 +++++++++ src/service/services.rs | 6 +- 10 files changed, 358 insertions(+), 2 deletions(-) create mode 100644 src/api/client/oauth/mod.rs create mode 100644 src/api/client/oauth/register_client.rs create mode 100644 src/api/client/oauth/server_metadata.rs create mode 100644 src/service/oauth/client_metadata.rs create mode 100644 src/service/oauth/mod.rs diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index bc4c0413a..01f811ebc 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -16,6 +16,7 @@ pub(super) mod media_legacy; pub(super) mod membership; pub(super) mod message; pub(super) mod mutual_rooms; +pub(super) mod oauth; pub(super) mod openid; pub(super) mod presence; pub(super) mod profile; @@ -61,6 +62,7 @@ pub(super) use membership::*; pub use membership::{leave_all_rooms, leave_room, remote_leave_room}; pub(super) use message::*; pub(super) use mutual_rooms::*; +pub(super) use oauth::*; pub(super) use openid::*; pub(super) use presence::*; pub(super) use profile::*; diff --git a/src/api/client/oauth/mod.rs b/src/api/client/oauth/mod.rs new file mode 100644 index 000000000..9c564e6d3 --- /dev/null +++ b/src/api/client/oauth/mod.rs @@ -0,0 +1,17 @@ +mod register_client; +mod server_metadata; + +use axum::{ + Json, Router, + routing::method_routing::{get, post}, +}; +use serde_json::json; +pub(crate) use server_metadata::*; + +pub(crate) const BASE_PATH: &str = "/_continuwuity/oauth2/"; + +pub(crate) fn router() -> Router { + Router::new() + .route("/client/register", post(register_client::register_client_route)) + .route("/client/keys.json", get(async || Json(json!({"keys": []})))) +} diff --git a/src/api/client/oauth/register_client.rs b/src/api/client/oauth/register_client.rs new file mode 100644 index 000000000..3d43edcae --- /dev/null +++ b/src/api/client/oauth/register_client.rs @@ -0,0 +1,28 @@ +use axum::{ + Json, + extract::State, + response::{IntoResponse, Response}, +}; +use http::StatusCode; +use serde::Serialize; +use service::oauth::client_metadata::ClientMetadata; + +#[derive(Serialize)] +struct RegisteredClient { + client_id: String, + #[serde(flatten)] + metadata: ClientMetadata, +} + +pub(crate) async fn register_client_route( + State(services): State, + Json(metadata): Json, +) -> Result { + let client_id = services + .oauth + .register_client(&metadata) + .await + .map_err(|err| (StatusCode::BAD_REQUEST, err.to_owned()).into_response())?; + + Ok(Json(RegisteredClient { client_id, metadata }).into_response()) +} diff --git a/src/api/client/oauth/server_metadata.rs b/src/api/client/oauth/server_metadata.rs new file mode 100644 index 000000000..e292a4377 --- /dev/null +++ b/src/api/client/oauth/server_metadata.rs @@ -0,0 +1,34 @@ +use axum::extract::State; +use conduwuit::Result; +use ruma::{api::client::discovery::get_authorization_server_metadata, serde::Raw}; +use serde_json::json; + +use crate::Ruma; + +pub(crate) async fn get_authorization_server_metadata_route( + State(services): State, + _body: Ruma, +) -> Result { + let endpoint_base = services + .config + .get_client_domain() + .join(super::BASE_PATH) + .unwrap(); + + let metadata = Raw::new(&json!({ + "authorization_endpoint": endpoint_base.join("grant/authorization_code").unwrap(), + "code_challenge_methods_supported": ["S256"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "issuer": services.config.get_client_domain(), + "jwks_uri": endpoint_base.join("client/keys.json").unwrap(), + "prompt_values_supported": ["create"], + "registration_endpoint": endpoint_base.join("client/register").unwrap(), + "response_modes_supported": ["query", "fragment"], + "response_types_supported": ["code"], + "revocation_endpoint": endpoint_base.join("client/revoke").unwrap(), + "token_endpoint": endpoint_base.join("grant/token").unwrap(), + })) + .unwrap(); + + Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked())) +} diff --git a/src/api/router.rs b/src/api/router.rs index 45da9e0c7..2b6cb7a01 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -186,6 +186,8 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::well_known_policy_server) .ruma_route(&client::get_rtc_transports) .ruma_route(&client::room_initial_sync_route) + .ruma_route(&client::get_authorization_server_metadata_route) + .nest(client::oauth::BASE_PATH, client::oauth::router()) .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_continuwuity/server_version", get(client::conduwuit_server_version)) .ruma_route(&admin::rooms::ban::ban_room) diff --git a/src/database/maps.rs b/src/database/maps.rs index e8eb02331..955b30e2e 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -49,6 +49,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "bannedroomids", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "clientid_clientmetadata", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "disabledroomids", ..descriptor::RANDOM_SMALL diff --git a/src/service/mod.rs b/src/service/mod.rs index 964916109..674cb399d 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -27,6 +27,7 @@ pub mod key_backups; pub mod mailer; pub mod media; pub mod moderation; +pub mod oauth; pub mod password_reset; pub mod presence; pub mod pusher; diff --git a/src/service/oauth/client_metadata.rs b/src/service/oauth/client_metadata.rs new file mode 100644 index 000000000..4ec03e778 --- /dev/null +++ b/src/service/oauth/client_metadata.rs @@ -0,0 +1,194 @@ +use std::{collections::BTreeSet, hash::Hash}; + +use itertools::Itertools; +use serde::{Deserialize, Deserializer, Serialize}; +use url::Url; + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +pub struct ClientMetadata { + #[serde(default)] + application_type: ApplicationType, + + #[serde(default, skip_serializing_if = "Option::is_none")] + client_name: Option, + + client_uri: Url, + + #[serde(default, deserialize_with = "btreeset_skip_err")] + grant_types: BTreeSet, + + #[serde(default, skip_serializing_if = "Option::is_none")] + logo_uri: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + policy_uri: Option, + + #[serde(default)] + redirect_uris: Vec, + + #[serde(default, deserialize_with = "btreeset_skip_err")] + response_types: BTreeSet, + + #[serde(default, skip_serializing_if = "Option::is_none")] + token_endpoint_auth_method: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + tos_uri: Option, +} + +impl ClientMetadata { + const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"]; + + pub(super) fn validate(&self) -> Result<(), &'static str> { + let Some(client_domain) = self.client_uri.domain() else { + return Err("Client URI must have a domain."); + }; + + if self.client_uri.scheme() != "https" { + return Err("Client URI must be HTTPS."); + } + + if !self.client_uri.username().is_empty() || self.client_uri.password().is_some() { + return Err("Client URI must not include credentials."); + } + + for uri in [&self.logo_uri, &self.policy_uri, &self.tos_uri] + .iter() + .filter_map(|uri| uri.as_ref()) + { + if uri.scheme() != "https" { + return Err("All metadata URIs must be HTTPS."); + } + + if !uri.username().is_empty() || uri.password().is_some() { + return Err("All metadata URIs must not include credentials."); + } + + if !uri + .domain() + .is_some_and(|domain| is_subdomain(domain, client_domain)) + { + return Err("All metadata URIs must be subdomains of the client URI."); + } + } + + for uri in &self.redirect_uris { + match uri.scheme() { + | "https" => { + // HTTPS URIs are okay for native and web clients + + if !uri.username().is_empty() || uri.password().is_some() { + return Err("HTTPS redirect URIs must not contain credentials."); + } + }, + | "http" if self.application_type == ApplicationType::Native => { + if uri + .host_str() + .is_none_or(|host| !Self::ACCEPTABLE_LOCALHOSTS.contains(&host)) + { + return Err("HTTP redirect URIs for native applications must only \ + refer to localhost."); + } + + if uri.port().is_some() { + return Err("HTTP redirect URIs for native applications do not need to \ + specify a port. All ports will be accepted during \ + authorization."); + } + }, + | private_scheme if self.application_type == ApplicationType::Native => { + let rdns_client_uri = client_domain.split('.').rev().join("."); + + if !private_scheme.starts_with(&rdns_client_uri) { + return Err("Private-use scheme URIs for native applications must \ + begin with the application's client URI domain in \ + reverse-DNS notation."); + } + + if uri.has_authority() { + return Err("Private-use scheme URIs for native applications must not \ + have an authority."); + } + }, + | _ => + return Err("A redirect URI's scheme is not valid for this application type."), + } + } + + Ok(()) + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ApplicationType { + #[default] + Web, + Native, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum GrantType { + AuthorizationCode, + RefreshToken, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseType { + Code, +} + +/// Deserialize a BTreeSet from a sequence, skipping items which fail to +/// deserialize. This is used as a deserialize helper for ClientMetadata to +/// ignore unknown enum variants in a few fields. +fn btreeset_skip_err<'de, D, V>(de: D) -> Result, D::Error> +where + D: Deserializer<'de>, + V: Deserialize<'de> + Hash + Eq + Ord, +{ + use std::marker::PhantomData; + + use serde::de::{SeqAccess, Visitor}; + + struct BTreeSetVisitor { + item: PhantomData, + } + + impl<'de, V> Visitor<'de> for BTreeSetVisitor + where + V: Deserialize<'de> + Hash + Eq + Ord, + { + type Value = BTreeSet; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a sequence") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut set = BTreeSet::new(); + + while let Some(element) = seq.next_element().transpose() { + if let Ok(element) = element { + set.insert(element); + } + } + + Ok(set) + } + } + + de.deserialize_seq(BTreeSetVisitor { item: PhantomData }) +} + +fn is_subdomain(subdomain: &str, domain: &str) -> bool { + if subdomain == domain { + return true; + } + + subdomain.ends_with(&format!(".{domain}")) +} diff --git a/src/service/oauth/mod.rs b/src/service/oauth/mod.rs new file mode 100644 index 000000000..a3a9d3fd0 --- /dev/null +++ b/src/service/oauth/mod.rs @@ -0,0 +1,72 @@ +use std::sync::Arc; + +use base64::Engine; +use conduwuit::{Result, utils::hash::sha256}; +use database::{Deserialized, Json, Map}; + +use crate::{Dep, config, oauth::client_metadata::ClientMetadata}; + +pub mod client_metadata; + +pub struct Service { + services: Services, + db: Data, +} + +struct Data { + clientid_clientmetadata: Arc, +} + +struct Services { + config: Dep, +} + +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + services: Services { + config: args.depend::("config"), + }, + db: Data { + clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(), + }, + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub async fn register_client( + &self, + metadata: &ClientMetadata, + ) -> Result { + metadata.validate()?; + + let client_id = base64::prelude::BASE64_STANDARD + .encode(sha256::hash(serde_json::to_string(metadata).unwrap().as_bytes())); + + if self + .db + .clientid_clientmetadata + .exists(&client_id) + .await + .is_err() + { + self.db + .clientid_clientmetadata + .raw_put(&client_id, Json(metadata.clone())); + } + + Ok(client_id) + } + + async fn get_client_registration(&self, client_id: &str) -> Option { + self.db + .clientid_clientmetadata + .get(client_id) + .await + .deserialized() + .ok() + } +} diff --git a/src/service/services.rs b/src/service/services.rs index 7e9470cf0..a1bca1c71 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -11,8 +11,8 @@ use crate::{ account_data, admin, announcements, antispam, appservice, client, config, emergency, federation, firstrun, globals, key_backups, mailer, manager::Manager, - media, moderation, password_reset, presence, pusher, registration_tokens, resolver, rooms, - sending, server_keys, + media, moderation, oauth, password_reset, presence, pusher, registration_tokens, resolver, + rooms, sending, server_keys, service::{self, Args, Map, Service}, sync, threepid, transactions, uiaa, users, }; @@ -27,6 +27,7 @@ pub struct Services { pub globals: Arc, pub key_backups: Arc, pub media: Arc, + pub oauth: Arc, pub password_reset: Arc, pub mailer: Arc, pub presence: Arc, @@ -84,6 +85,7 @@ impl Services { globals: build!(globals::Service), key_backups: build!(key_backups::Service), media: build!(media::Service), + oauth: build!(oauth::Service), password_reset: build!(password_reset::Service), mailer: build!(mailer::Service), presence: build!(presence::Service),