refactor: Ruma upstreaming, half-baked edition

Co-authored-by: Jade Ellis <jade@ellis.link>
This commit is contained in:
Ginger
2026-03-29 12:25:42 -04:00
parent 1cc9dbf2a4
commit 204bc1367e
141 changed files with 2715 additions and 2279 deletions
@@ -12,7 +12,7 @@ use ruma::{
api::federation::event::get_event,
};
use super::get_room_version_id;
use super::get_room_version;
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
/// it is appended to the outliers Tree.
@@ -109,15 +109,12 @@ where
match self
.services
.sending
.send_federation_request(origin, get_event::v1::Request {
event_id: (*next_id).to_owned(),
include_unredacted_content: None,
})
.send_federation_request(origin, get_event::v1::Request::new((*next_id).to_owned()))
.await
{
| Ok(res) => {
debug!("Got {next_id} over federation from {origin}");
let Ok(room_version_id) = get_room_version_id(create_event) else {
let Ok(room_version_id) = get_room_version(create_event) else {
back_off((*next_id).to_owned());
continue;
};
@@ -31,10 +31,7 @@ where
let res = self
.services
.sending
.send_federation_request(origin, get_room_state_ids::v1::Request {
room_id: room_id.to_owned(),
event_id: event_id.to_owned(),
})
.send_federation_request(origin, get_room_state_ids::v1::Request::new(event_id.to_owned(), room_id.to_owned()))
.await
.inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?;
@@ -55,7 +55,7 @@ async fn should_rescind_invite(
// Does the target user have a pending invite?
let Ok(pending_invite_state) = services
.state_cache
.invite_state(target_user_id, room_id)
.invite_state(&target_user_id, room_id)
.await
else {
return Ok(None); // No pending invite, so nothing to rescind
@@ -146,9 +146,11 @@ pub async fn handle_incoming_pdu<'a>(
let origin_acl_check = self.acl_check(origin, room_id);
// 1.3.2 Check room ACL on sender's server name
let sender: &UserId = value
let sender: OwnedUserId = value
.get("sender")
.try_into()
.and_then(|v| v.as_str())
.ok_or_else(|| err!("No sender in object"))
.and_then(|v| Ok(UserId::parse(v)?))
.map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?;
let sender_acl_check: OptionFuture<_> = sender
@@ -180,7 +182,7 @@ pub async fn handle_incoming_pdu<'a>(
// copied from https://github.com/element-hq/synapse/blob/7e4588a/synapse/handlers/federation_event.py#L255-L300
if value.get("type").and_then(|t| t.as_str()) == Some("m.room.member") {
if let Some(pdu) =
should_rescind_invite(&self.services, &mut value.clone(), sender, room_id).await?
should_rescind_invite(&self.services, &mut value.clone(), &sender, room_id).await?
{
debug_info!(
"Invite to {room_id} appears to have been rescinded by {sender}, marking as \
@@ -188,7 +190,7 @@ pub async fn handle_incoming_pdu<'a>(
);
self.services
.state_cache
.mark_as_left(sender, room_id, Some(pdu))
.mark_as_left(&sender, room_id, Some(pdu))
.await;
return Ok(None);
}
@@ -10,7 +10,7 @@ use ruma::{
events::StateEventType,
};
use super::{check_room_id, get_room_version_id, to_room_version};
use super::{check_room_id, get_room_version};
use crate::rooms::timeline::pdu_fits;
#[implement(super::Service)]
@@ -41,18 +41,19 @@ where
// 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match
let room_version_id = get_room_version_id(create_event)?;
let room_version = get_room_version(create_event)?;
let room_rules = room_version.rules().expect("room version should have defined rules");
let mut incoming_pdu = match self
.services
.server_keys
.verify_event(&value, Some(&room_version_id))
.verify_event(&value, Some(&room_version))
.await
{
| Ok(ruma::signatures::Verified::All) => value,
| Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else {
let Ok(obj) = ruma::canonical_json::redact(value, &room_rules.redaction, None) else {
return Err!(Request(InvalidParam("Redaction failed")));
};
@@ -184,7 +185,7 @@ where
};
let auth_check = state_res::event_auth::auth_check(
&to_room_version(&room_version_id),
&room_rules,
&pdu_event,
None, // TODO: third party invite
state_fetch,
+3 -8
View File
@@ -14,7 +14,7 @@ mod upgrade_outlier_pdu;
use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
use async_trait::async_trait;
use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, SyncRwLock, utils::MutexMap};
use conduwuit::{Err, Event, PduEvent, Result, Server, SyncRwLock, utils::MutexMap};
use ruma::{
OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
events::room::create::RoomCreateEventContent,
@@ -114,14 +114,9 @@ fn check_room_id<Pdu: Event>(room_id: &RoomId, pdu: &Pdu) -> Result {
Ok(())
}
fn get_room_version_id<Pdu: Event>(create_event: &Pdu) -> Result<RoomVersionId> {
fn get_room_version<Pdu: Event>(create_event: &Pdu) -> Result<RoomVersionId> {
let content: RoomCreateEventContent = create_event.get_content()?;
let room_version = content.room_version;
Ok(room_version)
}
#[inline]
fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion {
RoomVersion::new(room_version_id).expect("room version is supported")
}
}
@@ -3,22 +3,24 @@
//! This module implements a check against a room-specific policy server, as
//! described in the relevant Matrix spec proposal (see: https://github.com/matrix-org/matrix-spec-proposals/pull/4284).
use std::{collections::BTreeMap, time::Duration};
use std::{collections::BTreeMap, sync::LazyLock, time::Duration};
use conduwuit::{
Err, Event, PduEvent, Result, debug, debug_error, debug_info, debug_warn, implement, trace,
warn,
};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, KeyId, RoomId, ServerName, SigningKeyId,
api::federation::room::{
policy_check::unstable::Request as PolicyCheckRequest,
policy_sign::unstable::Request as PolicySignRequest,
},
events::{StateEventType, room::policy::RoomPolicyEventContent},
CanonicalJsonObject, CanonicalJsonValue, KeyId, OwnedKeyId, RoomId, ServerName, SigningKeyId, events::StateEventType
};
use ruminuwuity::policy::{
policy_check::unstable::Request as PolicyCheckRequest,
policy_sign::unstable::Request as PolicySignRequest,
event::RoomPolicyEventContent
};
use serde_json::value::RawValue;
static POLICY_EVENT_TYPE_UNSTABLE: LazyLock<StateEventType> = LazyLock::new(|| StateEventType::from("org.matrix.msc4284.policy"));
/// Asks a remote policy server if the event is allowed.
///
/// If the event is the `org.matrix.msc4284.policy` configuration state event,
@@ -44,7 +46,7 @@ pub async fn ask_policy_server(
return Ok(true); // don't ever contact policy servers
}
if *pdu.event_type() == StateEventType::RoomPolicy.into() {
if *pdu.event_type() == POLICY_EVENT_TYPE_UNSTABLE.clone().into() {
debug!(
room_id = %room_id,
event_type = ?pdu.event_type(),
@@ -56,7 +58,7 @@ pub async fn ask_policy_server(
let Ok(policyserver) = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPolicy, "")
.room_state_get_content(room_id, &POLICY_EVENT_TYPE_UNSTABLE, "")
.await
.inspect_err(|e| {
if !e.is_not_found() {
@@ -86,11 +88,7 @@ pub async fn ask_policy_server(
return Ok(true);
},
};
if via.is_empty() {
trace!("Policy server is empty for room {room_id}, skipping spam check");
return Ok(true);
}
if !self.services.state_cache.server_in_room(via, room_id).await {
if !self.services.state_cache.server_in_room(&via, room_id).await {
debug!(
via = %via,
"Policy server is not in the room, skipping spam check"
@@ -110,14 +108,14 @@ pub async fn ask_policy_server(
"Getting policy server signature on event"
);
return self
.fetch_policy_server_signature(pdu, pdu_json, via, outgoing, room_id)
.fetch_policy_server_signature(pdu, pdu_json, &via, outgoing, room_id)
.await;
}
// for incoming events, is it signed by <via> with the key
// "ed25519:policy_server"?
if let Some(CanonicalJsonValue::Object(sigs)) = pdu_json.get("signatures") {
if let Some(CanonicalJsonValue::Object(server_sigs)) = sigs.get(via.as_str()) {
let wanted_key_id: &KeyId<ruma::SigningKeyAlgorithm, ruma::Base64PublicKey> =
let wanted_key_id: OwnedKeyId<ruma::SigningKeyAlgorithm, ruma::Base64PublicKey> =
SigningKeyId::parse("ed25519:policy_server")?;
if let Some(CanonicalJsonValue::String(_sig_value)) =
server_sigs.get(wanted_key_id.as_str())
@@ -134,14 +132,15 @@ pub async fn ask_policy_server(
via = %via,
"Checking event for spam with policy server via legacy check"
);
let mut request = PolicyCheckRequest::new(pdu.event_id().to_owned());
request.pdu = Some(outgoing);
let response = tokio::time::timeout(
Duration::from_secs(self.services.server.config.policy_server_request_timeout),
self.services
.sending
.send_federation_request(via, PolicyCheckRequest {
event_id: pdu.event_id().to_owned(),
pdu: Some(outgoing),
}),
.send_federation_request(&via, request),
)
.await;
let response = match response {
@@ -202,7 +201,7 @@ pub async fn fetch_policy_server_signature(
Duration::from_secs(self.services.server.config.policy_server_request_timeout),
self.services
.sending
.send_federation_request(via, PolicySignRequest { pdu: outgoing }),
.send_federation_request(via, PolicySignRequest::new(outgoing)),
)
.await;
@@ -250,7 +249,7 @@ pub async fn fetch_policy_server_signature(
}
let keypairs = sigs.get(via).unwrap();
let wanted_key_id = KeyId::parse("ed25519:policy_server")?;
if !keypairs.contains_key(wanted_key_id) {
if !keypairs.contains_key(&wanted_key_id) {
debug_warn!(
"Policy server returned signature, but did not use the key ID \
'ed25519:policy_server'."
@@ -262,7 +261,7 @@ pub async fn fetch_policy_server_signature(
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()));
if let CanonicalJsonValue::Object(signatures_map) = signatures_entry {
let sig_value = keypairs.get(wanted_key_id).unwrap().to_owned();
let sig_value = keypairs.get(&wanted_key_id).unwrap().to_owned();
match signatures_map.get_mut(via.as_str()) {
| Some(CanonicalJsonValue::Object(inner_map)) => {
@@ -11,7 +11,7 @@ use conduwuit::{
utils::stream::{IterStream, ReadyExt, TryWidebandExt, WidebandExt},
};
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
use ruma::{OwnedEventId, RoomId, RoomVersionId};
use ruma::{OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
use crate::rooms::state_compressor::CompressedState;
@@ -20,7 +20,7 @@ use crate::rooms::state_compressor::CompressedState;
pub async fn resolve_state(
&self,
room_id: &RoomId,
room_version_id: &RoomVersionId,
room_version_rules: &RoomVersionRules,
incoming_state: HashMap<u64, OwnedEventId>,
) -> Result<Arc<CompressedState>> {
trace!("Loading current room state ids");
@@ -71,7 +71,7 @@ pub async fn resolve_state(
trace!("Resolving state");
let state = self
.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.state_resolution(room_version_rules, fork_states.iter(), &auth_chain_sets)
.boxed()
.await?;
@@ -103,7 +103,7 @@ pub async fn resolve_state(
#[tracing::instrument(name = "ruma", level = "debug", skip_all)]
pub async fn state_resolution<'a, StateSets>(
&'a self,
room_version: &'a RoomVersionId,
room_version_rules: &'a RoomVersionRules,
state_sets: StateSets,
auth_chain_sets: &'a [HashSet<OwnedEventId>],
) -> Result<StateMap<OwnedEventId>>
@@ -112,7 +112,7 @@ where
{
let event_fetch = |event_id| self.event_fetch(event_id);
let event_exists = |event_id| self.event_exists(event_id);
state_res::resolve(room_version, state_sets, auth_chain_sets, &event_fetch, &event_exists)
state_res::resolve(room_version_rules, state_sets, auth_chain_sets, &event_fetch, &event_exists)
.map_err(|e| err!(error!("State resolution failed: {e:?}")))
.await
}
@@ -11,7 +11,7 @@ use conduwuit::{
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
};
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
use ruma::{OwnedEventId, RoomId, RoomVersionId};
use ruma::{OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
use crate::rooms::short::ShortStateHash;
@@ -77,7 +77,7 @@ pub(super) async fn state_at_incoming_resolved<Pdu>(
&self,
incoming_pdu: &Pdu,
room_id: &RoomId,
room_version_id: &RoomVersionId,
room_version_rules: &RoomVersionRules,
) -> Result<Option<HashMap<u64, OwnedEventId>>>
where
Pdu: Event + Send + Sync,
@@ -118,7 +118,7 @@ where
.await?;
let Ok(new_state) = self
.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.state_resolution(room_version_rules, fork_states.iter(), &auth_chain_sets)
.boxed()
.await
else {
@@ -10,7 +10,7 @@ use conduwuit::{
use futures::{FutureExt, StreamExt, future::ready};
use ruma::{CanonicalJsonValue, RoomId, ServerName, events::StateEventType};
use super::{get_room_version_id, to_room_version};
use super::get_room_version;
use crate::rooms::{
state_compressor::{CompressedState, HashSetCompressStateEvent},
timeline::RawPduId,
@@ -52,7 +52,8 @@ where
"Upgrading PDU from outlier to timeline"
);
let timer = Instant::now();
let room_version_id = get_room_version_id(create_event)?;
let room_version_id = get_room_version(create_event)?;
let room_version_rules = room_version_id.rules().expect("room version should have defined rules");
// 10. Fetch missing state and auth chain events by calling /state_ids at
// backwards extremities doing all the checks in this list starting at 1.
@@ -65,7 +66,7 @@ where
let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
self.state_at_incoming_degree_one(&incoming_pdu).await?
} else {
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id)
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_rules)
.await?
};
@@ -78,8 +79,6 @@ where
let state_at_incoming_event =
state_at_incoming_event.expect("we always set this to some above");
let room_version = to_room_version(&room_version_id);
debug!(
event_id = %incoming_pdu.event_id,
"Performing auth check to upgrade"
@@ -98,7 +97,7 @@ where
"Running initial auth check"
);
let auth_check = state_res::event_auth::auth_check(
&room_version,
&room_version_rules,
&incoming_pdu,
None, // TODO: third party invite
|ty, sk| state_fetch(ty.clone(), sk.into()),
@@ -124,7 +123,7 @@ where
incoming_pdu.sender(),
incoming_pdu.state_key(),
incoming_pdu.content(),
&room_version,
&room_version_rules,
)
.await?;
@@ -138,7 +137,7 @@ where
"Running auth check with claimed state auth"
);
let auth_check = state_res::event_auth::auth_check(
&room_version,
&room_version_rules,
&incoming_pdu,
None, // third-party invite
state_fetch,
@@ -179,7 +178,6 @@ where
.services
.state
.get_forward_extremities(room_id)
.map(ToOwned::to_owned)
.ready_filter(|event_id| {
// Remove any that are referenced by this incoming event's prev_events
!incoming_pdu.prev_events().any(is_equal_to!(event_id))
@@ -232,7 +230,7 @@ where
}
let new_room_state = self
.resolve_state(room_id, &room_version_id, state_after)
.resolve_state(room_id, &room_version_rules, state_after)
.await?;
// Set the new room state to the resolved state