mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
refactor: Represent route auth information in the type system
This commit is contained in:
+13
-68
@@ -6,17 +6,14 @@ use axum::{
|
||||
extract::{FromRequest, Path, Query},
|
||||
};
|
||||
use conduwuit::{Error, Result, err};
|
||||
use ruma::{
|
||||
CanonicalJsonObject, DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName,
|
||||
UserId, api::IncomingRequest,
|
||||
};
|
||||
use ruma::{CanonicalJsonObject, api::IncomingRequest};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{State, router::auth::CheckAuth, service::appservice::RegistrationInfo};
|
||||
use crate::{State, router::auth::CheckAuth};
|
||||
|
||||
/// Query parameters needed to authenticate requests
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct AuthQueryParams {
|
||||
pub(crate) struct AuthQueryParams {
|
||||
pub(super) user_id: Option<String>,
|
||||
/// Device ID for appservice device masquerading (MSC3202/MSC4190).
|
||||
/// Can be provided as `device_id` or `org.matrix.msc3202.device_id`.
|
||||
@@ -25,67 +22,22 @@ pub(super) struct AuthQueryParams {
|
||||
}
|
||||
|
||||
/// Extractor for Ruma request structs
|
||||
pub(crate) struct Args<T> {
|
||||
pub(crate) struct Args<R: IncomingRequest<Authentication: CheckAuth> + Send + Sync + 'static> {
|
||||
/// Request struct body
|
||||
pub(crate) body: T,
|
||||
pub(crate) body: R,
|
||||
|
||||
/// Federation server authentication: X-Matrix origin
|
||||
/// None when not a federation server.
|
||||
pub(crate) origin: Option<OwnedServerName>,
|
||||
|
||||
/// Local user authentication: user_id.
|
||||
/// None when not an authenticated local user.
|
||||
pub(crate) sender_user: Option<OwnedUserId>,
|
||||
|
||||
/// Local user authentication: device_id.
|
||||
/// None when not an authenticated local user or no device.
|
||||
pub(crate) sender_device: Option<OwnedDeviceId>,
|
||||
|
||||
/// Appservice authentication; registration info.
|
||||
/// None when not an appservice.
|
||||
pub(crate) appservice_info: Option<RegistrationInfo>,
|
||||
|
||||
/// Parsed JSON content.
|
||||
/// None when body is not a valid string
|
||||
/// Parsed JSON body. None when body is not JSON.
|
||||
pub(crate) json_body: Option<CanonicalJsonObject>,
|
||||
|
||||
/// Identity of the requesting entity
|
||||
pub(crate) identity: <R::Authentication as CheckAuth>::Identity,
|
||||
}
|
||||
|
||||
impl<T> Args<T>
|
||||
impl<R> Deref for Args<R>
|
||||
where
|
||||
T: IncomingRequest + Send + Sync + 'static,
|
||||
R: IncomingRequest<Authentication: CheckAuth> + Send + Sync + 'static,
|
||||
{
|
||||
#[inline]
|
||||
pub(crate) fn sender(&self) -> (&UserId, &DeviceId) {
|
||||
(self.sender_user(), self.sender_device())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn sender_user(&self) -> &UserId {
|
||||
self.sender_user
|
||||
.as_deref()
|
||||
.expect("user must be authenticated for this handler")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn sender_device(&self) -> &DeviceId {
|
||||
self.sender_device
|
||||
.as_deref()
|
||||
.expect("user must be authenticated and device identified")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn origin(&self) -> &ServerName {
|
||||
self.origin
|
||||
.as_deref()
|
||||
.expect("server must be authenticated for this handler")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for Args<T>
|
||||
where
|
||||
T: IncomingRequest + Send + Sync + 'static,
|
||||
{
|
||||
type Target = T;
|
||||
type Target = R;
|
||||
|
||||
fn deref(&self) -> &Self::Target { &self.body }
|
||||
}
|
||||
@@ -145,13 +97,6 @@ where
|
||||
let body = R::try_from_http_request(request, &path)
|
||||
.map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))?;
|
||||
|
||||
Ok(Self {
|
||||
body,
|
||||
origin: auth.origin,
|
||||
sender_user: auth.sender_user,
|
||||
sender_device: auth.sender_device,
|
||||
appservice_info: auth.appservice_info,
|
||||
json_body,
|
||||
})
|
||||
Ok(Self { body, json_body, identity: auth })
|
||||
}
|
||||
}
|
||||
|
||||
+136
-98
@@ -2,7 +2,7 @@ use std::any::{Any, TypeId};
|
||||
|
||||
use conduwuit::{Err, Result, err};
|
||||
use ruma::{
|
||||
OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
|
||||
DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
|
||||
api::{
|
||||
IncomingRequest,
|
||||
auth_scheme::{
|
||||
@@ -20,20 +20,57 @@ use service::{
|
||||
|
||||
use crate::{router::args::AuthQueryParams, service::appservice::RegistrationInfo};
|
||||
|
||||
#[derive(Default)]
|
||||
pub(super) struct Auth {
|
||||
pub(super) origin: Option<OwnedServerName>,
|
||||
pub(super) sender_user: Option<OwnedUserId>,
|
||||
pub(super) sender_device: Option<OwnedDeviceId>,
|
||||
pub(super) appservice_info: Option<RegistrationInfo>,
|
||||
pub(crate) enum ClientIdentity {
|
||||
User {
|
||||
sender_user: OwnedUserId,
|
||||
sender_device: OwnedDeviceId,
|
||||
},
|
||||
Appservice {
|
||||
sender_user: OwnedUserId,
|
||||
sender_device: Option<OwnedDeviceId>,
|
||||
appservice_info: RegistrationInfo,
|
||||
},
|
||||
}
|
||||
|
||||
pub(super) trait CheckAuth: AuthScheme {
|
||||
impl ClientIdentity {
|
||||
pub(crate) fn sender_user(&self) -> &UserId {
|
||||
match self {
|
||||
| Self::User { sender_user, .. } | Self::Appservice { sender_user, .. } =>
|
||||
sender_user,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sender_device(&self) -> Option<&DeviceId> {
|
||||
match self {
|
||||
| Self::User { sender_device, .. } => Some(sender_device),
|
||||
| Self::Appservice { sender_device, .. } => sender_device.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn expect_sender_device(&self) -> Result<&DeviceId> {
|
||||
self.sender_device().ok_or_else(|| {
|
||||
err!(Request(Forbidden("Appservices must masquerade to use this endpoint")))
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn appservice_info(&self) -> Option<&RegistrationInfo> {
|
||||
match self {
|
||||
| Self::User { .. } => None,
|
||||
| Self::Appservice { appservice_info, .. } => Some(appservice_info),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_appservice(&self) -> bool { matches!(self, Self::Appservice { .. }) }
|
||||
}
|
||||
|
||||
pub(crate) trait CheckAuth: AuthScheme {
|
||||
type Identity: Send;
|
||||
|
||||
fn authenticate<R: IncomingRequest + Any, B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
incoming_request: &hyper::Request<B>,
|
||||
query: AuthQueryParams,
|
||||
) -> impl Future<Output = Result<Auth>> + Send {
|
||||
) -> impl Future<Output = Result<Self::Identity>> + Send {
|
||||
async move {
|
||||
let route = TypeId::of::<R>();
|
||||
|
||||
@@ -54,17 +91,19 @@ pub(super) trait CheckAuth: AuthScheme {
|
||||
request: &hyper::Request<B>,
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> impl Future<Output = Result<Auth>> + Send;
|
||||
) -> impl Future<Output = Result<Self::Identity>> + Send;
|
||||
}
|
||||
|
||||
impl CheckAuth for ServerSignatures {
|
||||
type Identity = OwnedServerName;
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
output: Self::Output,
|
||||
request: &hyper::Request<B>,
|
||||
_query: AuthQueryParams,
|
||||
_route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
) -> Result<Self::Identity> {
|
||||
let destination = services.globals.server_name();
|
||||
if output
|
||||
.destination
|
||||
@@ -96,10 +135,7 @@ impl CheckAuth for ServerSignatures {
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Auth {
|
||||
origin: Some(output.origin.clone()),
|
||||
..Default::default()
|
||||
})
|
||||
Ok(output.origin)
|
||||
},
|
||||
| Err(err) =>
|
||||
Err!(Request(Unauthorized(warn!("Failed to verify X-Matrix header: {err}")))),
|
||||
@@ -108,162 +144,164 @@ impl CheckAuth for ServerSignatures {
|
||||
}
|
||||
|
||||
impl CheckAuth for AccessToken {
|
||||
type Identity = ClientIdentity;
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
output: Self::Output,
|
||||
_request: &hyper::Request<B>,
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
// Check for appservice tokens first
|
||||
|
||||
let (sender_user, sender_device, appservice_info) = {
|
||||
if let Ok((sender_user, sender_device)) =
|
||||
services.users.find_from_token(&output).await
|
||||
) -> Result<Self::Identity> {
|
||||
if let Ok((sender_user, sender_device)) = services.users.find_from_token(&output).await {
|
||||
// Locked users can only use /logout and /logout/all
|
||||
if services
|
||||
.users
|
||||
.is_locked(&sender_user)
|
||||
.await
|
||||
.is_ok_and(std::convert::identity)
|
||||
{
|
||||
// Locked users can only use /logout and /logout/all
|
||||
if services
|
||||
.users
|
||||
.is_locked(&sender_user)
|
||||
.await
|
||||
.is_ok_and(std::convert::identity)
|
||||
if !(route == TypeId::of::<client::session::logout::v3::Request>()
|
||||
|| route == TypeId::of::<client::session::logout_all::v3::Request>())
|
||||
{
|
||||
if !(route == TypeId::of::<client::session::logout::v3::Request>()
|
||||
|| route == TypeId::of::<client::session::logout_all::v3::Request>())
|
||||
{
|
||||
return Err!(Request(Unauthorized("Your account is locked.")));
|
||||
}
|
||||
return Err!(Request(Unauthorized("Your account is locked.")));
|
||||
}
|
||||
}
|
||||
|
||||
(Some(sender_user), Some(sender_device), None)
|
||||
} else if let Ok(appservice_info) = services.appservice.find_from_token(&output).await
|
||||
{
|
||||
let Ok(sender_user) = query.user_id.clone().map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
appservice_info.registration.sender_localpart.as_str(),
|
||||
services.globals.server_name(),
|
||||
)
|
||||
},
|
||||
UserId::parse,
|
||||
) else {
|
||||
return Err!(Request(InvalidUsername("Username is invalid.")));
|
||||
Ok(ClientIdentity::User { sender_user, sender_device })
|
||||
} else if let Ok(appservice_info) = services.appservice.find_from_token(&output).await {
|
||||
let Ok(sender_user) = query.user_id.clone().map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
appservice_info.registration.sender_localpart.as_str(),
|
||||
services.globals.server_name(),
|
||||
)
|
||||
},
|
||||
UserId::parse,
|
||||
) else {
|
||||
return Err!(Request(InvalidUsername("Username is invalid.")));
|
||||
};
|
||||
|
||||
if !appservice_info.is_user_match(&sender_user) {
|
||||
return Err!(Request(Exclusive("User is not in namespace.")));
|
||||
}
|
||||
|
||||
// MSC3202/MSC4190: Handle device_id masquerading for appservices.
|
||||
// The device_id can be provided via `device_id` or
|
||||
// `org.matrix.msc3202.device_id` query parameter.
|
||||
let sender_device =
|
||||
if let Some(device_id) = query.device_id.as_deref().map(Into::into) {
|
||||
// Verify the device exists for this user
|
||||
if services
|
||||
.users
|
||||
.get_device_metadata(&sender_user, device_id)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Request(Forbidden(
|
||||
"Device does not exist for user or appservice cannot masquerade as \
|
||||
this device."
|
||||
)));
|
||||
}
|
||||
|
||||
Some(device_id.to_owned())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if !appservice_info.is_user_match(&sender_user) {
|
||||
return Err!(Request(Exclusive("User is not in namespace.")));
|
||||
}
|
||||
|
||||
// MSC3202/MSC4190: Handle device_id masquerading for appservices.
|
||||
// The device_id can be provided via `device_id` or
|
||||
// `org.matrix.msc3202.device_id` query parameter.
|
||||
let sender_device =
|
||||
if let Some(device_id) = query.device_id.as_deref().map(Into::into) {
|
||||
// Verify the device exists for this user
|
||||
if services
|
||||
.users
|
||||
.get_device_metadata(&sender_user, device_id)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Request(Forbidden(
|
||||
"Device does not exist for user or appservice cannot masquerade \
|
||||
as this device."
|
||||
)));
|
||||
}
|
||||
|
||||
Some(device_id.to_owned())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(Some(sender_user), sender_device, Some(appservice_info))
|
||||
} else {
|
||||
return Err!(Request(Unauthorized("Invalid access token.")));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Auth {
|
||||
sender_user,
|
||||
sender_device,
|
||||
appservice_info,
|
||||
..Default::default()
|
||||
})
|
||||
Ok(ClientIdentity::Appservice {
|
||||
sender_user,
|
||||
sender_device,
|
||||
appservice_info,
|
||||
})
|
||||
} else {
|
||||
return Err!(Request(Unauthorized("Invalid access token.")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckAuth for AccessTokenOptional {
|
||||
type Identity = Option<ClientIdentity>;
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
output: Self::Output,
|
||||
request: &hyper::Request<B>,
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
) -> Result<Self::Identity> {
|
||||
match output {
|
||||
| Some(token) =>
|
||||
<AccessToken as CheckAuth>::verify(services, token, request, query, route).await,
|
||||
| None => Ok(Auth::default()),
|
||||
<AccessToken as CheckAuth>::verify(services, token, request, query, route)
|
||||
.await
|
||||
.map(Some),
|
||||
| None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckAuth for AppserviceToken {
|
||||
type Identity = RegistrationInfo;
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
output: Self::Output,
|
||||
_request: &hyper::Request<B>,
|
||||
_query: AuthQueryParams,
|
||||
_route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
) -> Result<Self::Identity> {
|
||||
let Ok(appservice_info) = services.appservice.find_from_token(&output).await else {
|
||||
return Err!(Request(Unauthorized("Invalid appservice token.")));
|
||||
};
|
||||
|
||||
Ok(Auth {
|
||||
appservice_info: Some(appservice_info),
|
||||
..Default::default()
|
||||
})
|
||||
Ok(appservice_info)
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckAuth for AppserviceTokenOptional {
|
||||
type Identity = Option<RegistrationInfo>;
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
output: Self::Output,
|
||||
request: &hyper::Request<B>,
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
) -> Result<Self::Identity> {
|
||||
match output {
|
||||
| Some(token) =>
|
||||
<AppserviceToken as CheckAuth>::verify(services, token, request, query, route)
|
||||
.await,
|
||||
| None => Ok(Auth::default()),
|
||||
.await
|
||||
.map(Some),
|
||||
| None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckAuth for NoAuthentication {
|
||||
type Identity = ();
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
_services: &Services,
|
||||
_output: Self::Output,
|
||||
_request: &hyper::Request<B>,
|
||||
_query: AuthQueryParams,
|
||||
_route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
Ok(Auth::default())
|
||||
) -> Result<Self::Identity> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckAuth for NoAccessToken {
|
||||
type Identity = Option<ClientIdentity>;
|
||||
|
||||
async fn verify<B: AsRef<[u8]> + Sync>(
|
||||
services: &Services,
|
||||
_output: Self::Output,
|
||||
request: &hyper::Request<B>,
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> Result<Auth> {
|
||||
) -> Result<Self::Identity> {
|
||||
// We handle these the same as AccessTokenOptional
|
||||
let token = AccessTokenOptional::extract_authentication(request).map_err(|err| {
|
||||
err!(Request(Unauthorized(warn!("Failed to extract authorization: {}", err))))
|
||||
|
||||
Reference in New Issue
Block a user