mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
feat: Implement oauth service and client registration
This commit is contained in:
@@ -16,6 +16,7 @@ pub(super) mod media_legacy;
|
|||||||
pub(super) mod membership;
|
pub(super) mod membership;
|
||||||
pub(super) mod message;
|
pub(super) mod message;
|
||||||
pub(super) mod mutual_rooms;
|
pub(super) mod mutual_rooms;
|
||||||
|
pub(super) mod oauth;
|
||||||
pub(super) mod openid;
|
pub(super) mod openid;
|
||||||
pub(super) mod presence;
|
pub(super) mod presence;
|
||||||
pub(super) mod profile;
|
pub(super) mod profile;
|
||||||
@@ -61,6 +62,7 @@ pub(super) use membership::*;
|
|||||||
pub use membership::{leave_all_rooms, leave_room, remote_leave_room};
|
pub use membership::{leave_all_rooms, leave_room, remote_leave_room};
|
||||||
pub(super) use message::*;
|
pub(super) use message::*;
|
||||||
pub(super) use mutual_rooms::*;
|
pub(super) use mutual_rooms::*;
|
||||||
|
pub(super) use oauth::*;
|
||||||
pub(super) use openid::*;
|
pub(super) use openid::*;
|
||||||
pub(super) use presence::*;
|
pub(super) use presence::*;
|
||||||
pub(super) use profile::*;
|
pub(super) use profile::*;
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
mod register_client;
|
||||||
|
mod server_metadata;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
Json, Router,
|
||||||
|
routing::method_routing::{get, post},
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
pub(crate) use server_metadata::*;
|
||||||
|
|
||||||
|
pub(crate) const BASE_PATH: &str = "/_continuwuity/oauth2/";
|
||||||
|
|
||||||
|
pub(crate) fn router() -> Router<crate::State> {
|
||||||
|
Router::new()
|
||||||
|
.route("/client/register", post(register_client::register_client_route))
|
||||||
|
.route("/client/keys.json", get(async || Json(json!({"keys": []}))))
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
use axum::{
|
||||||
|
Json,
|
||||||
|
extract::State,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use http::StatusCode;
|
||||||
|
use serde::Serialize;
|
||||||
|
use service::oauth::client_metadata::ClientMetadata;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct RegisteredClient {
|
||||||
|
client_id: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
metadata: ClientMetadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn register_client_route(
|
||||||
|
State(services): State<crate::State>,
|
||||||
|
Json(metadata): Json<ClientMetadata>,
|
||||||
|
) -> Result<Response, Response> {
|
||||||
|
let client_id = services
|
||||||
|
.oauth
|
||||||
|
.register_client(&metadata)
|
||||||
|
.await
|
||||||
|
.map_err(|err| (StatusCode::BAD_REQUEST, err.to_owned()).into_response())?;
|
||||||
|
|
||||||
|
Ok(Json(RegisteredClient { client_id, metadata }).into_response())
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
use axum::extract::State;
|
||||||
|
use conduwuit::Result;
|
||||||
|
use ruma::{api::client::discovery::get_authorization_server_metadata, serde::Raw};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::Ruma;
|
||||||
|
|
||||||
|
pub(crate) async fn get_authorization_server_metadata_route(
|
||||||
|
State(services): State<crate::State>,
|
||||||
|
_body: Ruma<get_authorization_server_metadata::v1::Request>,
|
||||||
|
) -> Result<get_authorization_server_metadata::v1::Response> {
|
||||||
|
let endpoint_base = services
|
||||||
|
.config
|
||||||
|
.get_client_domain()
|
||||||
|
.join(super::BASE_PATH)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let metadata = Raw::new(&json!({
|
||||||
|
"authorization_endpoint": endpoint_base.join("grant/authorization_code").unwrap(),
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
|
"issuer": services.config.get_client_domain(),
|
||||||
|
"jwks_uri": endpoint_base.join("client/keys.json").unwrap(),
|
||||||
|
"prompt_values_supported": ["create"],
|
||||||
|
"registration_endpoint": endpoint_base.join("client/register").unwrap(),
|
||||||
|
"response_modes_supported": ["query", "fragment"],
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
"revocation_endpoint": endpoint_base.join("client/revoke").unwrap(),
|
||||||
|
"token_endpoint": endpoint_base.join("grant/token").unwrap(),
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(get_authorization_server_metadata::v1::Response::new(metadata.cast_unchecked()))
|
||||||
|
}
|
||||||
@@ -186,6 +186,8 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
|||||||
.ruma_route(&client::well_known_policy_server)
|
.ruma_route(&client::well_known_policy_server)
|
||||||
.ruma_route(&client::get_rtc_transports)
|
.ruma_route(&client::get_rtc_transports)
|
||||||
.ruma_route(&client::room_initial_sync_route)
|
.ruma_route(&client::room_initial_sync_route)
|
||||||
|
.ruma_route(&client::get_authorization_server_metadata_route)
|
||||||
|
.nest(client::oauth::BASE_PATH, client::oauth::router())
|
||||||
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
|
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
|
||||||
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
|
.route("/_continuwuity/server_version", get(client::conduwuit_server_version))
|
||||||
.ruma_route(&admin::rooms::ban::ban_room)
|
.ruma_route(&admin::rooms::ban::ban_room)
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ pub(super) static MAPS: &[Descriptor] = &[
|
|||||||
name: "bannedroomids",
|
name: "bannedroomids",
|
||||||
..descriptor::RANDOM_SMALL
|
..descriptor::RANDOM_SMALL
|
||||||
},
|
},
|
||||||
|
Descriptor {
|
||||||
|
name: "clientid_clientmetadata",
|
||||||
|
..descriptor::RANDOM_SMALL
|
||||||
|
},
|
||||||
Descriptor {
|
Descriptor {
|
||||||
name: "disabledroomids",
|
name: "disabledroomids",
|
||||||
..descriptor::RANDOM_SMALL
|
..descriptor::RANDOM_SMALL
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ pub mod key_backups;
|
|||||||
pub mod mailer;
|
pub mod mailer;
|
||||||
pub mod media;
|
pub mod media;
|
||||||
pub mod moderation;
|
pub mod moderation;
|
||||||
|
pub mod oauth;
|
||||||
pub mod password_reset;
|
pub mod password_reset;
|
||||||
pub mod presence;
|
pub mod presence;
|
||||||
pub mod pusher;
|
pub mod pusher;
|
||||||
|
|||||||
@@ -0,0 +1,194 @@
|
|||||||
|
use std::{collections::BTreeSet, hash::Hash};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
|
pub struct ClientMetadata {
|
||||||
|
#[serde(default)]
|
||||||
|
application_type: ApplicationType,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
client_name: Option<String>,
|
||||||
|
|
||||||
|
client_uri: Url,
|
||||||
|
|
||||||
|
#[serde(default, deserialize_with = "btreeset_skip_err")]
|
||||||
|
grant_types: BTreeSet<GrantType>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
logo_uri: Option<Url>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
policy_uri: Option<Url>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
redirect_uris: Vec<Url>,
|
||||||
|
|
||||||
|
#[serde(default, deserialize_with = "btreeset_skip_err")]
|
||||||
|
response_types: BTreeSet<ResponseType>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
token_endpoint_auth_method: Option<String>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
tos_uri: Option<Url>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientMetadata {
|
||||||
|
const ACCEPTABLE_LOCALHOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];
|
||||||
|
|
||||||
|
pub(super) fn validate(&self) -> Result<(), &'static str> {
|
||||||
|
let Some(client_domain) = self.client_uri.domain() else {
|
||||||
|
return Err("Client URI must have a domain.");
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.client_uri.scheme() != "https" {
|
||||||
|
return Err("Client URI must be HTTPS.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.client_uri.username().is_empty() || self.client_uri.password().is_some() {
|
||||||
|
return Err("Client URI must not include credentials.");
|
||||||
|
}
|
||||||
|
|
||||||
|
for uri in [&self.logo_uri, &self.policy_uri, &self.tos_uri]
|
||||||
|
.iter()
|
||||||
|
.filter_map(|uri| uri.as_ref())
|
||||||
|
{
|
||||||
|
if uri.scheme() != "https" {
|
||||||
|
return Err("All metadata URIs must be HTTPS.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !uri.username().is_empty() || uri.password().is_some() {
|
||||||
|
return Err("All metadata URIs must not include credentials.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !uri
|
||||||
|
.domain()
|
||||||
|
.is_some_and(|domain| is_subdomain(domain, client_domain))
|
||||||
|
{
|
||||||
|
return Err("All metadata URIs must be subdomains of the client URI.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for uri in &self.redirect_uris {
|
||||||
|
match uri.scheme() {
|
||||||
|
| "https" => {
|
||||||
|
// HTTPS URIs are okay for native and web clients
|
||||||
|
|
||||||
|
if !uri.username().is_empty() || uri.password().is_some() {
|
||||||
|
return Err("HTTPS redirect URIs must not contain credentials.");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
| "http" if self.application_type == ApplicationType::Native => {
|
||||||
|
if uri
|
||||||
|
.host_str()
|
||||||
|
.is_none_or(|host| !Self::ACCEPTABLE_LOCALHOSTS.contains(&host))
|
||||||
|
{
|
||||||
|
return Err("HTTP redirect URIs for native applications must only \
|
||||||
|
refer to localhost.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if uri.port().is_some() {
|
||||||
|
return Err("HTTP redirect URIs for native applications do not need to \
|
||||||
|
specify a port. All ports will be accepted during \
|
||||||
|
authorization.");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
| private_scheme if self.application_type == ApplicationType::Native => {
|
||||||
|
let rdns_client_uri = client_domain.split('.').rev().join(".");
|
||||||
|
|
||||||
|
if !private_scheme.starts_with(&rdns_client_uri) {
|
||||||
|
return Err("Private-use scheme URIs for native applications must \
|
||||||
|
begin with the application's client URI domain in \
|
||||||
|
reverse-DNS notation.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if uri.has_authority() {
|
||||||
|
return Err("Private-use scheme URIs for native applications must not \
|
||||||
|
have an authority.");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
| _ =>
|
||||||
|
return Err("A redirect URI's scheme is not valid for this application type."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ApplicationType {
|
||||||
|
#[default]
|
||||||
|
Web,
|
||||||
|
Native,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum GrantType {
|
||||||
|
AuthorizationCode,
|
||||||
|
RefreshToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ResponseType {
|
||||||
|
Code,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deserialize a BTreeSet from a sequence, skipping items which fail to
|
||||||
|
/// deserialize. This is used as a deserialize helper for ClientMetadata to
|
||||||
|
/// ignore unknown enum variants in a few fields.
|
||||||
|
fn btreeset_skip_err<'de, D, V>(de: D) -> Result<BTreeSet<V>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
V: Deserialize<'de> + Hash + Eq + Ord,
|
||||||
|
{
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use serde::de::{SeqAccess, Visitor};
|
||||||
|
|
||||||
|
struct BTreeSetVisitor<V> {
|
||||||
|
item: PhantomData<V>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de, V> Visitor<'de> for BTreeSetVisitor<V>
|
||||||
|
where
|
||||||
|
V: Deserialize<'de> + Hash + Eq + Ord,
|
||||||
|
{
|
||||||
|
type Value = BTreeSet<V>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(formatter, "a sequence")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: SeqAccess<'de>,
|
||||||
|
{
|
||||||
|
let mut set = BTreeSet::new();
|
||||||
|
|
||||||
|
while let Some(element) = seq.next_element().transpose() {
|
||||||
|
if let Ok(element) = element {
|
||||||
|
set.insert(element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(set)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
de.deserialize_seq(BTreeSetVisitor { item: PhantomData })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_subdomain(subdomain: &str, domain: &str) -> bool {
|
||||||
|
if subdomain == domain {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
subdomain.ends_with(&format!(".{domain}"))
|
||||||
|
}
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use base64::Engine;
|
||||||
|
use conduwuit::{Result, utils::hash::sha256};
|
||||||
|
use database::{Deserialized, Json, Map};
|
||||||
|
|
||||||
|
use crate::{Dep, config, oauth::client_metadata::ClientMetadata};
|
||||||
|
|
||||||
|
pub mod client_metadata;
|
||||||
|
|
||||||
|
pub struct Service {
|
||||||
|
services: Services,
|
||||||
|
db: Data,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Data {
|
||||||
|
clientid_clientmetadata: Arc<Map>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Services {
|
||||||
|
config: Dep<config::Service>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::Service for Service {
|
||||||
|
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
||||||
|
Ok(Arc::new(Self {
|
||||||
|
services: Services {
|
||||||
|
config: args.depend::<config::Service>("config"),
|
||||||
|
},
|
||||||
|
db: Data {
|
||||||
|
clientid_clientmetadata: args.db["clientid_clientmetadata"].clone(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service {
|
||||||
|
pub async fn register_client(
|
||||||
|
&self,
|
||||||
|
metadata: &ClientMetadata,
|
||||||
|
) -> Result<String, &'static str> {
|
||||||
|
metadata.validate()?;
|
||||||
|
|
||||||
|
let client_id = base64::prelude::BASE64_STANDARD
|
||||||
|
.encode(sha256::hash(serde_json::to_string(metadata).unwrap().as_bytes()));
|
||||||
|
|
||||||
|
if self
|
||||||
|
.db
|
||||||
|
.clientid_clientmetadata
|
||||||
|
.exists(&client_id)
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
self.db
|
||||||
|
.clientid_clientmetadata
|
||||||
|
.raw_put(&client_id, Json(metadata.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(client_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_client_registration(&self, client_id: &str) -> Option<ClientMetadata> {
|
||||||
|
self.db
|
||||||
|
.clientid_clientmetadata
|
||||||
|
.get(client_id)
|
||||||
|
.await
|
||||||
|
.deserialized()
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,8 +11,8 @@ use crate::{
|
|||||||
account_data, admin, announcements, antispam, appservice, client, config, emergency,
|
account_data, admin, announcements, antispam, appservice, client, config, emergency,
|
||||||
federation, firstrun, globals, key_backups, mailer,
|
federation, firstrun, globals, key_backups, mailer,
|
||||||
manager::Manager,
|
manager::Manager,
|
||||||
media, moderation, password_reset, presence, pusher, registration_tokens, resolver, rooms,
|
media, moderation, oauth, password_reset, presence, pusher, registration_tokens, resolver,
|
||||||
sending, server_keys,
|
rooms, sending, server_keys,
|
||||||
service::{self, Args, Map, Service},
|
service::{self, Args, Map, Service},
|
||||||
sync, threepid, transactions, uiaa, users,
|
sync, threepid, transactions, uiaa, users,
|
||||||
};
|
};
|
||||||
@@ -27,6 +27,7 @@ pub struct Services {
|
|||||||
pub globals: Arc<globals::Service>,
|
pub globals: Arc<globals::Service>,
|
||||||
pub key_backups: Arc<key_backups::Service>,
|
pub key_backups: Arc<key_backups::Service>,
|
||||||
pub media: Arc<media::Service>,
|
pub media: Arc<media::Service>,
|
||||||
|
pub oauth: Arc<oauth::Service>,
|
||||||
pub password_reset: Arc<password_reset::Service>,
|
pub password_reset: Arc<password_reset::Service>,
|
||||||
pub mailer: Arc<mailer::Service>,
|
pub mailer: Arc<mailer::Service>,
|
||||||
pub presence: Arc<presence::Service>,
|
pub presence: Arc<presence::Service>,
|
||||||
@@ -84,6 +85,7 @@ impl Services {
|
|||||||
globals: build!(globals::Service),
|
globals: build!(globals::Service),
|
||||||
key_backups: build!(key_backups::Service),
|
key_backups: build!(key_backups::Service),
|
||||||
media: build!(media::Service),
|
media: build!(media::Service),
|
||||||
|
oauth: build!(oauth::Service),
|
||||||
password_reset: build!(password_reset::Service),
|
password_reset: build!(password_reset::Service),
|
||||||
mailer: build!(mailer::Service),
|
mailer: build!(mailer::Service),
|
||||||
presence: build!(presence::Service),
|
presence: build!(presence::Service),
|
||||||
|
|||||||
Reference in New Issue
Block a user