refactor: Fix errors in api/router/

This commit is contained in:
Ginger
2026-04-13 16:20:47 -04:00
parent 0f64e6b49c
commit 0c7abd792d
8 changed files with 275 additions and 482 deletions
+70 -68
View File
@@ -1,15 +1,29 @@
use std::{mem, ops::Deref};
use std::ops::Deref;
use axum::{body::Body, extract::FromRequest};
use bytes::{BufMut, Bytes, BytesMut};
use conduwuit::{Error, Result, debug, debug_warn, err, trace};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
OwnedUserId, ServerName, UserId, api::IncomingRequest,
use axum::{
RequestExt, RequestPartsExt,
body::Body,
extract::{FromRequest, Path, Query},
};
use conduwuit::{Error, Result, err};
use ruma::{
CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName,
UserId, api::IncomingRequest,
};
use serde::Deserialize;
use super::{auth, request, request::Request};
use crate::{State, service::appservice::RegistrationInfo};
use crate::{State, router::auth::CheckAuth, service::appservice::RegistrationInfo};
/// Query parameters needed to authenticate requests
#[derive(Deserialize)]
pub(super) struct AuthQueryParams {
pub(super) access_token: Option<String>,
pub(super) user_id: Option<String>,
/// Device ID for appservice device masquerading (MSC3202/MSC4190).
/// Can be provided as `device_id` or `org.matrix.msc3202.device_id`.
#[serde(alias = "org.matrix.msc3202.device_id")]
pub(super) device_id: Option<String>,
}
/// Extractor for Ruma request structs
pub(crate) struct Args<T> {
@@ -77,9 +91,9 @@ where
fn deref(&self) -> &Self::Target { &self.body }
}
impl<T> FromRequest<State, Body> for Args<T>
impl<R> FromRequest<State, Body> for Args<R>
where
T: IncomingRequest + Send + Sync + 'static,
R: IncomingRequest<Authentication: CheckAuth> + Send + Sync + 'static,
{
type Rejection = Error;
@@ -87,27 +101,53 @@ where
request: hyper::Request<Body>,
services: &State,
) -> Result<Self, Self::Rejection> {
let mut request = request::from(services, request).await?;
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
let limited = request.with_limited_body();
let (mut parts, body) = limited.into_parts();
// Read the body
let body = {
let max_body_size = services.server.config.max_request_size;
// Check if the Content-Length header is present and valid, saves us streaming
// the response into memory
if let Some(content_length) = parts.headers.get(http::header::CONTENT_LENGTH) {
if let Ok(content_length) = content_length
.to_str()
.map(|s| s.parse::<usize>().unwrap_or_default())
{
if content_length > max_body_size {
return Err(err!(Request(TooLarge("Request body too large"))));
}
}
}
axum::body::to_bytes(body, max_body_size)
.await
.map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?
};
// Make a JSON copy of the body for use in handlers
let json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
// Extract the query parameters and path
let Path(path): Path<Vec<String>> = parts.extract().await?;
let Query(auth_query): Query<AuthQueryParams> = parts.extract().await?;
// Assemble a new request from the read body and parts
let request = hyper::Request::from_parts(parts, body);
// Check authentication
let auth =
R::Authentication::authenticate::<R, bytes::Bytes>(services, &request, auth_query)
.await?;
// Deserialize the body
let body = R::try_from_http_request(request, &path)
.map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))?;
// while very unusual and really shouldn't be recommended, Synapse accepts POST
// requests with a completely empty body. very old clients, libraries, and some
// appservices still call APIs like /join like this. so let's just default to
// empty object `{}` to copy synapse's behaviour
if json_body.is_none()
&& request.parts.method == http::Method::POST
&& !request.parts.uri.path().contains("/media/")
{
trace!("json_body from_request: {:?}", json_body.clone());
debug_warn!(
"received a POST request with an empty body, defaulting/assuming to {{}} like \
Synapse does"
);
json_body = Some(CanonicalJsonValue::Object(CanonicalJsonObject::new()));
}
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
Ok(Self {
body: make_body::<T>(&mut request, json_body.as_mut())?,
body,
origin: auth.origin,
sender_user: auth.sender_user,
sender_device: auth.sender_device,
@@ -116,41 +156,3 @@ where
})
}
}
fn make_body<T>(request: &mut Request, json_body: Option<&mut CanonicalJsonValue>) -> Result<T>
where
T: IncomingRequest,
{
let body = take_body(request, json_body);
let http_request = into_http_request(request, body);
T::try_from_http_request(http_request, &request.path)
.map_err(|e| err!(Request(BadJson(debug_warn!("{e}")))))
}
fn into_http_request(request: &Request, body: Bytes) -> hyper::Request<Bytes> {
let mut http_request = hyper::Request::builder()
.uri(request.parts.uri.clone())
.method(request.parts.method.clone());
*http_request.headers_mut().expect("mutable http headers") = request.parts.headers.clone();
let http_request = http_request.body(body).expect("http request body");
let headers = http_request.headers();
let method = http_request.method();
let uri = http_request.uri();
debug!("{method:?} {uri:?} {headers:?}");
http_request
}
#[allow(clippy::needless_pass_by_value)]
fn take_body(request: &mut Request, json_body: Option<&mut CanonicalJsonValue>) -> Bytes {
let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
return mem::take(&mut request.body);
};
let mut buf = BytesMut::new().writer();
serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
buf.into_inner().freeze()
}