feat: Implement oauth service and client registration

This commit is contained in:
Ginger
2026-04-14 19:14:27 -04:00
parent 30c9d6d2df
commit 02948960fa
10 changed files with 358 additions and 2 deletions
+2
View File
@@ -16,6 +16,7 @@ pub(super) mod media_legacy;
pub(super) mod membership; pub(super) mod membership;
pub(super) mod message; pub(super) mod message;
pub(super) mod mutual_rooms; pub(super) mod mutual_rooms;
pub(super) mod oauth;
pub(super) mod openid; pub(super) mod openid;
pub(super) mod presence; pub(super) mod presence;
pub(super) mod profile; pub(super) mod profile;
@@ -61,6 +62,7 @@ pub(super) use membership::*;
pub use membership::{leave_all_rooms, leave_room, remote_leave_room}; pub use membership::{leave_all_rooms, leave_room, remote_leave_room};
pub(super) use message::*; pub(super) use message::*;
pub(super) use mutual_rooms::*; pub(super) use mutual_rooms::*;
pub(super) use oauth::*;
pub(super) use openid::*; pub(super) use openid::*;
pub(super) use presence::*; pub(super) use presence::*;
pub(super) use profile::*; pub(super) use profile::*;
+17
View File
@@ -0,0 +1,17 @@
mod register_client;
mod server_metadata;
use axum::{
Json, Router,
routing::method_routing::{get, post},
};
use serde_json::json;
pub(crate) use server_metadata::*;
pub(crate) const BASE_PATH: &str = "/_continuwuity/oauth2/";
pub(crate) fn router() -> Router<crate::State> {
Router::new()
.route("/client/register", post(register_client::register_client_route))
.route("/client/keys.json", get(async || Json(json!({"keys": []}))))
}
+28
View File
@@ -0,0 +1,28 @@
use axum::{
Json,
extract::State,
response::{IntoResponse, Response},
};
use http::StatusCode;
use serde::Serialize;
use service::oauth::client_metadata::ClientMetadata;
#[derive(Serialize)]
struct RegisteredClient {
client_id: String,
#[serde(flatten)]
metadata: ClientMetadata,
}
pub(crate) async fn register_client_route(
State(services): State<crate::State>,
Json(metadata): Json<ClientMetadata>,
) -> Result<Response, Response> {
let client_id = services
.oauth
.register_client(&metadata)
.await
.map_err(|err| (StatusCode::BAD_REQUEST, err.to_owned()).into_response())?;
Ok(Json(RegisteredClient { client_id, metadata }).into_response())
}
+34
View File
@@ -0,0 +1,34 @@
use axum::extract::State;
use conduwuit::Result;
use ruma::{api::client::discovery::get_authorization_server_metadata, serde::Raw};
use serde_json::json;
use crate::Ruma;
pub(crate) async fn get_authorization_server_metadata_route(
State(services): State<crate::State>,
_body: Ruma<get_authorization_server_metadata::v1::Request>,
) -> Result<get_authorization_server_metadata::v1::Response> {
let endpoint_base = services
.config
.get_client_domain()
.join(super::BASE_PATH)
.unwrap();
let metadata = Raw::new(&json!({
"authorization_endpoint": endpoint_base.join("grant/authorization_code").unwrap(),
"code_challenge_methods_supported": ["S256"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"issuer": services.config.get_client_domain(),
"jwks_uri": endpoint_base.join("client/keys.json").unwrap(),
"prompt_values_supported": ["create"],
"registration_endpoint": endpoint_base.join("client/register").unwrap(),
"response_modes_supported": ["query", "fragment"],
"response_types_supported": ["code"],
"revocation_endpoint": endpoint_base.join("client/revoke").unwrap(),
"token_endpoint": endpoint_base.join("grant/token").unwrap(),
}))
.unwrap();
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
}
+2
View File
@@ -186,6 +186,8 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::well_known_policy_server) .ruma_route(&client::well_known_policy_server)
.ruma_route(&client::get_rtc_transports) .ruma_route(&client::get_rtc_transports)
.ruma_route(&client::room_initial_sync_route) .ruma_route(&client::room_initial_sync_route)
.ruma_route(&client::get_authorization_server_metadata_route)
.nest(client::oauth::BASE_PATH, client::oauth::router())
.route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_continuwuity/server_version", get(client::conduwuit_server_version)) .route("/_continuwuity/server_version", get(client::conduwuit_server_version))
.ruma_route(&admin::rooms::ban::ban_room) .ruma_route(&admin::rooms::ban::ban_room)
+4
View File
@@ -49,6 +49,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "bannedroomids", name: "bannedroomids",
..descriptor::RANDOM_SMALL ..descriptor::RANDOM_SMALL
}, },
Descriptor {
name: "clientid_clientmetadata",
..descriptor::RANDOM_SMALL
},
Descriptor { Descriptor {
name: "disabledroomids", name: "disabledroomids",
..descriptor::RANDOM_SMALL ..descriptor::RANDOM_SMALL
+1
View File
@@ -27,6 +27,7 @@ pub mod key_backups;
pub mod mailer; pub mod mailer;
pub mod media; pub mod media;
pub mod moderation; pub mod moderation;
pub mod oauth;
pub mod password_reset; pub mod password_reset;
pub mod presence; pub mod presence;
pub mod pusher; pub mod pusher;
+194
View File
@@ -0,0 +1,194 @@
use std::{collections::BTreeSet, hash::Hash};
use itertools::Itertools;
use serde::{Deserialize, Deserializer, Serialize};
use url::Url;
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct ClientMetadata {
#[serde(default)]
application_type: ApplicationType,
#[serde(default, skip_serializing_if = "Option::is_none")]
client_name: Option<String>,
client_uri: Url,
#[serde(default, deserialize_with = "btreeset_skip_err")]
grant_types: BTreeSet<GrantType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
logo_uri: Option<Url>,
#[serde(default, skip_serializing_if = "Option::is_none")]
policy_uri: Option<Url>,
#[serde(default)]
redirect_uris: Vec<Url>,
#[serde(default, deserialize_with = "btreeset_skip_err")]
response_types: BTreeSet<ResponseType>,
#[serde(default, skip_serializing_if = "Option::is_none")]
token_endpoint_auth_method: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
tos_uri: Option<Url>,
}
impl ClientMetadata {
const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
pub(super) fn validate(&self) -> Result<(), &'static str> {
let Some(client_domain) = self.client_uri.domain() else {
return Err("Client URI must have a domain.");
};
if self.client_uri.scheme() != "https" {
return Err("Client URI must be HTTPS.");
}
if !self.client_uri.username().is_empty() || self.client_uri.password().is_some() {
return Err("Client URI must not include credentials.");
}
for uri in [&self.logo_uri, &self.policy_uri, &self.tos_uri]
.iter()
.filter_map(|uri| uri.as_ref())
{
if uri.scheme() != "https" {
return Err("All metadata URIs must be HTTPS.");
}
if !uri.username().is_empty() || uri.password().is_some() {
return Err("All metadata URIs must not include credentials.");
}
if !uri
.domain()
.is_some_and(|domain| is_subdomain(domain, client_domain))
{
return Err("All metadata URIs must be subdomains of the client URI.");
}
}
for uri in &self.redirect_uris {
match uri.scheme() {
| "https" => {
// HTTPS URIs are okay for native and web clients
if !uri.username().is_empty() || uri.password().is_some() {
return Err("HTTPS redirect URIs must not contain credentials.");
}
},
| "http" if self.application_type == ApplicationType::Native => {
if uri
.host_str()
.is_none_or(|host| !Self::ACCEPTABLE_LOCALHOSTS.contains(&host))
{
return Err("HTTP redirect URIs for native applications must only \
refer to localhost.");
}
if uri.port().is_some() {
return Err("HTTP redirect URIs for native applications do not need to \
specify a port. All ports will be accepted during \
authorization.");
}
},
| private_scheme if self.application_type == ApplicationType::Native => {
let rdns_client_uri = client_domain.split('.').rev().join(".");
if !private_scheme.starts_with(&rdns_client_uri) {
return Err("Private-use scheme URIs for native applications must \
begin with the application's client URI domain in \
reverse-DNS notation.");
}
if uri.has_authority() {
return Err("Private-use scheme URIs for native applications must not \
have an authority.");
}
},
| _ =>
return Err("A redirect URI's scheme is not valid for this application type."),
}
}
Ok(())
}
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ApplicationType {
#[default]
Web,
Native,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
AuthorizationCode,
RefreshToken,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseType {
Code,
}
/// Deserialize a BTreeSet from a sequence, skipping items which fail to
/// deserialize. This is used as a deserialize helper for ClientMetadata to
/// ignore unknown enum variants in a few fields.
fn btreeset_skip_err<'de, D, V>(de: D) -> Result<BTreeSet<V>, D::Error>
where
D: Deserializer<'de>,
V: Deserialize<'de> + Hash + Eq + Ord,
{
use std::marker::PhantomData;
use serde::de::{SeqAccess, Visitor};
struct BTreeSetVisitor<V> {
item: PhantomData<V>,
}
impl<'de, V> Visitor<'de> for BTreeSetVisitor<V>
where
V: Deserialize<'de> + Hash + Eq + Ord,
{
type Value = BTreeSet<V>;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut set = BTreeSet::new();
while let Some(element) = seq.next_element().transpose() {
if let Ok(element) = element {
set.insert(element);
}
}
Ok(set)
}
}
de.deserialize_seq(BTreeSetVisitor { item: PhantomData })
}
fn is_subdomain(subdomain: &str, domain: &str) -> bool {
if subdomain == domain {
return true;
}
subdomain.ends_with(&format!(".{domain}"))
}
+72
View File
@@ -0,0 +1,72 @@
use std::sync::Arc;
use base64::Engine;
use conduwuit::{Result, utils::hash::sha256};
use database::{Deserialized, Json, Map};
use crate::{Dep, config, oauth::client_metadata::ClientMetadata};
pub mod client_metadata;
pub struct Service {
services: Services,
db: Data,
}
struct Data {
clientid_clientmetadata: Arc<Map>,
}
struct Services {
config: Dep<config::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
config: args.depend::<config::Service>("config"),
},
db: Data {
clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
pub async fn register_client(
&self,
metadata: &ClientMetadata,
) -> Result<String, &'static str> {
metadata.validate()?;
let client_id = base64::prelude::BASE64_STANDARD
.encode(sha256::hash(serde_json::to_string(metadata).unwrap().as_bytes()));
if self
.db
.clientid_clientmetadata
.exists(&client_id)
.await
.is_err()
{
self.db
.clientid_clientmetadata
.raw_put(&client_id, Json(metadata.clone()));
}
Ok(client_id)
}
async fn get_client_registration(&self, client_id: &str) -> Option<ClientMetadata> {
self.db
.clientid_clientmetadata
.get(client_id)
.await
.deserialized()
.ok()
}
}
+4 -2
View File
@@ -11,8 +11,8 @@ use crate::{
account_data, admin, announcements, antispam, appservice, client, config, emergency, account_data, admin, announcements, antispam, appservice, client, config, emergency,
federation, firstrun, globals, key_backups, mailer, federation, firstrun, globals, key_backups, mailer,
manager::Manager, manager::Manager,
media, moderation, password_reset, presence, pusher, registration_tokens, resolver, rooms, media, moderation, oauth, password_reset, presence, pusher, registration_tokens, resolver,
sending, server_keys, rooms, sending, server_keys,
service::{self, Args, Map, Service}, service::{self, Args, Map, Service},
sync, threepid, transactions, uiaa, users, sync, threepid, transactions, uiaa, users,
}; };
@@ -27,6 +27,7 @@ pub struct Services {
pub globals: Arc<globals::Service>, pub globals: Arc<globals::Service>,
pub key_backups: Arc<key_backups::Service>, pub key_backups: Arc<key_backups::Service>,
pub media: Arc<media::Service>, pub media: Arc<media::Service>,
pub oauth: Arc<oauth::Service>,
pub password_reset: Arc<password_reset::Service>, pub password_reset: Arc<password_reset::Service>,
pub mailer: Arc<mailer::Service>, pub mailer: Arc<mailer::Service>,
pub presence: Arc<presence::Service>, pub presence: Arc<presence::Service>,
@@ -84,6 +85,7 @@ impl Services {
globals: build!(globals::Service), globals: build!(globals::Service),
key_backups: build!(key_backups::Service), key_backups: build!(key_backups::Service),
media: build!(media::Service), media: build!(media::Service),
oauth: build!(oauth::Service),
password_reset: build!(password_reset::Service), password_reset: build!(password_reset::Service),
mailer: build!(mailer::Service), mailer: build!(mailer::Service),
presence: build!(presence::Service), presence: build!(presence::Service),