refactor: Ruma upstreaming, half-baked edition

Co-authored-by: Jade Ellis <jade@ellis.link>
This commit is contained in:
Ginger
2026-03-29 12:25:42 -04:00
parent 1cc9dbf2a4
commit 204bc1367e
141 changed files with 2715 additions and 2279 deletions
+1
View File
@@ -108,6 +108,7 @@ rand.workspace = true
regex.workspace = true
reqwest.workspace = true
ruma.workspace = true
ruminuwuity.workspace = true
rustyline-async.workspace = true
rustyline-async.optional = true
serde_json.workspace = true
+12 -8
View File
@@ -7,17 +7,21 @@ use conduwuit::{
use database::{Deserialized, Handle, Ignore, Json, Map};
use futures::{Stream, StreamExt, TryFutureExt};
use ruma::{
RoomId, UserId,
events::{
AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent,
OwnedRoomId, OwnedUserId, RoomId, UserId, events::{
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent,
GlobalAccountDataEventType, RoomAccountDataEventType,
},
serde::Raw,
}, serde::Raw
};
use serde::Deserialize;
use crate::{Dep, globals};
#[derive(Debug)]
pub enum AnyRawAccountDataEvent {
Room(Raw<AnyRoomAccountDataEvent>),
Global(Raw<AnyGlobalAccountDataEvent>),
}
pub struct Service {
services: Services,
db: Data,
@@ -132,7 +136,7 @@ pub fn changes_since<'a>(
since: Option<u64>,
to: Option<u64>,
) -> impl Stream<Item = AnyRawAccountDataEvent> + Send + 'a {
type Key<'a> = (Option<&'a RoomId>, &'a UserId, u64, Ignore);
type Key = (Option<OwnedRoomId>, OwnedUserId, u64, Ignore);
// Skip the data that's exactly at since, because we sent that last time
// ...unless this is an initial sync, in which case send everything
@@ -142,8 +146,8 @@ pub fn changes_since<'a>(
.roomuserdataid_accountdata
.stream_from(&first_possible)
.ignore_err()
.ready_take_while(move |((room_id_, user_id_, count, _), _): &(Key<'_>, _)| {
room_id == *room_id_ && user_id == *user_id_ && to.is_none_or(|to| *count <= to)
.ready_take_while(move |((room_id_, user_id_, count, _), _): &(Key, _)| {
room_id == room_id_.as_deref() && user_id == user_id_ && to.is_none_or(|to| *count <= to)
})
.map(move |(_, v)| {
match room_id {
+18 -38
View File
@@ -13,7 +13,6 @@ use ruma::{
member::{MembershipState, RoomMemberEventContent},
name::RoomNameEventContent,
power_levels::RoomPowerLevelsEventContent,
preview_url::RoomPreviewUrlsEventContent,
topic::RoomTopicEventContent,
},
};
@@ -25,7 +24,7 @@ use crate::Services;
/// Users in this room are considered admins by conduwuit, and the room can be
/// used to issue admin commands by talking to the server user inside it.
pub async fn create_admin_room(services: &Services) -> Result {
let room_id = RoomId::new(services.globals.server_name());
let room_id = RoomId::new_v1(services.globals.server_name());
let room_version = &RoomVersionId::V11;
let _short_id = services
@@ -34,22 +33,24 @@ pub async fn create_admin_room(services: &Services) -> Result {
.get_or_create_shortroomid(&room_id)
.await;
let state_lock = services.rooms.state.mutex.lock(&room_id).await;
let state_lock = services.rooms.state.mutex.lock(room_id.as_str()).await;
// Create a user for the server
let server_user = services.globals.server_user.as_ref();
services.users.create(server_user, None, None).await?;
let create_content = {
let mut create_content = {
use RoomVersionId::*;
match room_version {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 =>
RoomCreateEventContent::new_v1(server_user.into()),
| V11 => RoomCreateEventContent::new_v11(),
| _ => RoomCreateEventContent::new_v12(),
| _ => RoomCreateEventContent::new_v11(),
}
};
create_content.federate = true;
create_content.room_version = room_version.clone();
info!("Creating admin room {} with version {}", room_id, room_version);
// 1. The room create event
@@ -57,12 +58,7 @@ pub async fn create_admin_room(services: &Services) -> Result {
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(String::new(), &RoomCreateEventContent {
federate: true,
predecessor: None,
room_version: room_version.clone(),
..create_content
}),
PduBuilder::state(String::new(), &create_content),
server_user,
Some(&room_id),
&state_lock,
@@ -89,14 +85,14 @@ pub async fn create_admin_room(services: &Services) -> Result {
// 3. Power levels
let users = BTreeMap::from_iter([(server_user.into(), 69420.into())]);
let mut power_levels_content = RoomPowerLevelsEventContent::new(&room_version.rules().unwrap().authorization);
power_levels_content.users = users;
services
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(String::new(), &RoomPowerLevelsEventContent {
users,
..Default::default()
}),
PduBuilder::state(String::new(), &power_levels_content),
server_user,
Some(&room_id),
&state_lock,
@@ -163,13 +159,12 @@ pub async fn create_admin_room(services: &Services) -> Result {
.boxed()
.await?;
let room_topic = format!("Manage {} | Run commands prefixed with `!admin` | Run `!admin -h` for help | Documentation: https://continuwuity.org/", services.config.server_name);
services
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(String::new(), &RoomTopicEventContent {
topic: format!("Manage {} | Run commands prefixed with `!admin` | Run `!admin -h` for help | Documentation: https://continuwuity.org/", services.config.server_name),
}),
PduBuilder::state(String::new(), &RoomTopicEventContent::markdown(room_topic)),
server_user,
Some(&room_id),
&state_lock,
@@ -178,16 +173,14 @@ pub async fn create_admin_room(services: &Services) -> Result {
.await?;
// 6. Room alias
let alias = &services.globals.admin_alias;
let mut alias_content = RoomCanonicalAliasEventContent::new();
alias_content.alias = Some(services.globals.admin_alias.clone());
services
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(String::new(), &RoomCanonicalAliasEventContent {
alias: Some(alias.clone()),
alt_aliases: Vec::new(),
}),
PduBuilder::state(String::new(), &alias_content),
server_user,
Some(&room_id),
&state_lock,
@@ -198,20 +191,7 @@ pub async fn create_admin_room(services: &Services) -> Result {
services
.rooms
.alias
.set_alias(alias, &room_id, server_user)?;
// 7. (ad-hoc) Disable room URL previews for everyone by default
services
.rooms
.timeline
.build_and_append_pdu(
PduBuilder::state(String::new(), &RoomPreviewUrlsEventContent { disabled: true }),
server_user,
Some(&room_id),
&state_lock,
)
.boxed()
.await?;
.set_alias(&services.globals.admin_alias, &room_id, server_user)?;
Ok(())
}
+9 -15
View File
@@ -27,7 +27,7 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result {
return Ok(());
};
let state_lock = self.services.state.mutex.lock(&room_id).await;
let state_lock = self.services.state.mutex.lock(room_id.as_str()).await;
if self.services.state_cache.is_joined(user_id, &room_id).await {
return Err!(debug_warn!("User is already joined in the admin room"));
@@ -100,7 +100,7 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result {
"",
)
.await
.unwrap_or_default();
.expect("admin room should have power levels");
room_power_levels
.users
@@ -135,9 +135,7 @@ async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> R
.account_data
.get_room(room_id, user_id, RoomAccountDataEventType::Tag)
.await
.unwrap_or_else(|_| TagEvent {
content: TagEventContent { tags: BTreeMap::new() },
});
.unwrap_or_else(|_| TagEvent::new(TagEventContent::new(BTreeMap::new())));
event
.content
@@ -177,9 +175,9 @@ pub async fn revoke_admin(&self, user_id: &UserId) -> Result {
return Err!(error!("No admin room available or created."));
};
let state_lock = self.services.state.mutex.lock(&room_id).await;
let state_lock = self.services.state.mutex.lock(room_id.as_str()).await;
let event = match self
let mut member_content = match self
.services
.state_accessor
.get_member(&room_id, user_id)
@@ -203,17 +201,13 @@ pub async fn revoke_admin(&self, user_id: &UserId) -> Result {
},
};
member_content.membership = Leave;
member_content.reason = Some("Admin Revoked".to_owned());
self.services
.timeline
.build_and_append_pdu(
PduBuilder::state(user_id.to_string(), &RoomMemberEventContent {
membership: Leave,
reason: Some("Admin Revoked".into()),
is_direct: None,
join_authorized_via_users_server: None,
third_party_invite: None,
..event
}),
PduBuilder::state(user_id.to_string(), &member_content),
self.services.globals.server_user.as_ref(),
Some(&room_id),
&state_lock,
+11 -15
View File
@@ -17,7 +17,7 @@ pub use create::create_admin_room;
use futures::{Future, FutureExt, StreamExt, TryFutureExt};
use loole::{Receiver, Sender};
use ruma::{
Mxc, OwnedEventId, OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId,
OwnedEventId, OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId,
events::{
Mentions,
room::{
@@ -30,7 +30,7 @@ use ruma::{
};
use tokio::sync::RwLock;
use crate::{Dep, account_data, globals, media::MXC_LENGTH, rooms, rooms::state::RoomMutexGuard};
use crate::{Dep, account_data, globals, media::{MXC_LENGTH, mxc::Mxc}, rooms::{self, state::RoomMutexGuard}};
pub struct Service {
services: Services,
@@ -200,19 +200,15 @@ impl Service {
.await
.expect("failed to create text file");
let size_u64: u64 = message_content.body().len().try_into().map_or(0, |n| n);
let metadata = FileInfo {
mimetype: Some("text/markdown".to_owned()),
size: Some(UInt::new_saturating(size_u64)),
thumbnail_info: None,
thumbnail_source: None,
};
let content = FileMessageEventContent {
body: "Output was too large to send as text.".to_owned(),
formatted: None,
filename: Some("output.md".to_owned()),
source: MediaSource::Plain(file),
info: Some(Box::new(metadata)),
};
let mut metadata = FileInfo::new();
metadata.mimetype = Some("text/markdown".to_owned());
metadata.size = Some(UInt::new_saturating(size_u64));
let mut content = FileMessageEventContent::plain("Output was too large to send as text.".to_owned(), file);
content.filename = Some("output.md".to_owned());
content.info = Some(Box::new(metadata));
RoomMessageEventContent::new(MessageType::File(content))
} else {
message_content
+3 -2
View File
@@ -2,7 +2,8 @@ use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use conduwuit::{Result, config::Antispam, debug};
use ruma::{OwnedRoomId, OwnedUserId, draupnir_antispam, meowlnir_antispam};
use ruma::{OwnedRoomId, OwnedUserId, api::{auth_scheme::AppserviceToken, path_builder::VersionHistory}};
use ruminuwuity::{draupnir_antispam, meowlnir_antispam};
use crate::{client, config, sending, service::Dep};
@@ -37,7 +38,7 @@ impl Service {
request: T,
) -> Result<T::IncomingResponse>
where
T: ruma::api::OutgoingRequest + Debug + Send,
T: ruma::api::OutgoingRequest<Authentication = AppserviceToken, PathBuilder = VersionHistory> + Debug + Send,
{
sending::antispam::send_antispam_request(
&self.services.client.appservice,
+1 -3
View File
@@ -72,9 +72,7 @@ impl Service {
None,
server_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent { global: ruleset },
})
&serde_json::to_value(&GlobalAccountDataEvent::new(PushRulesEventContent::new(ruleset)))
.expect("to json value always works"),
)
.await?;
+75 -119
View File
@@ -1,64 +1,85 @@
use std::{fmt::Debug, mem};
use std::{any::Any, borrow::Cow, fmt::Debug, mem, sync::LazyLock};
use bytes::Bytes;
use conduwuit::{
Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, implement,
trace, utils::response::LimitReadExt,
};
use http::{HeaderValue, header::AUTHORIZATION};
Err, Error, Result, debug, debug_error, debug_warn, err, implement, trace, utils::response::LimitReadExt, matrix::versions::{unstable_features, versions}, };
use ipaddress::IPAddress;
use reqwest::{Client, Method, Request, Response, Url};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId,
api::{
EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
client::error::Error as RumaError, federation::authentication::XMatrix,
},
serde::Base64,
CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId, api::{
EndpointError, IncomingResponse, Metadata, OutgoingRequest, SupportedVersions, auth_scheme::{AuthScheme, NoAuthentication, SendAccessToken}, client::error::Error as RumaError, federation::authentication::{ServerSignatures, ServerSignaturesInput, XMatrix}, path_builder::{PathBuilder, SinglePath, VersionHistory}
}, serde::Base64
};
use crate::resolver::actual::ActualDest;
use crate::{SUPPORTED_VERSIONS, resolver::actual::ActualDest};
/// Sends a request to a federation server
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "request", level = "debug")]
pub async fn execute<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
pub async fn execute<'i, T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest::<Authentication = ServerSignatures, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
let client = &self.services.client.federation;
self.execute_on(client, dest, request).await
self.execute_signed(client, dest, request).await
}
/// Like execute() but with a very large timeout
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "synapse", level = "debug")]
pub async fn execute_synapse<T>(
pub async fn execute_synapse<'i, T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest::<Authentication = ServerSignatures, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
let client = &self.services.client.synapse;
self.execute_on(client, dest, request).await
self.execute_signed(client, dest, request).await
}
#[implement(super::Service)]
pub async fn execute_unauthenticated<'i, T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest::<Authentication = NoAuthentication, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
let client = &self.services.client.federation;
let authentication = SendAccessToken::None;
self.execute_on(client, dest, request, authentication).await
}
#[implement(super::Service)]
pub async fn execute_signed<'i, T>(&self, client: &Client, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest::<Authentication = ServerSignatures, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Send,
{
let authentication = ServerSignaturesInput::new(
self.services.server.name.clone(),
dest.to_owned(),
self.services.server_keys.keypair(),
);
self.execute_on(client, dest, request, authentication).await
}
#[implement(super::Service)]
#[tracing::instrument(
name = "fed",
level = INFO_SPAN_LEVEL,
skip(self, client, request),
level = "info",
skip(self, client, request, authentication),
)]
pub async fn execute_on<T>(
pub async fn execute_on<'i, T, PathBuilderInput>(
&self,
client: &Client,
dest: &ServerName,
request: T,
authentication: <T::Authentication as AuthScheme>::Input<'_>,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
T: OutgoingRequest::<PathBuilder: PathBuilder<Input<'i> = PathBuilderInput>> + Send,
PathBuilderInput: FederationPathBuilderInput
{
if !self.services.server.config.allow_federation {
return Err!(Config("allow_federation", "Federation is disabled."));
@@ -69,8 +90,17 @@ where
}
let actual = self.services.resolver.get_actual_dest(dest).await?;
let request = into_http_request::<T>(&actual, request)?;
let request = self.prepare(dest, request)?;
let request = Request::try_from(
request.try_into_http_request::<Vec<u8>>(
actual.string().as_str(),
authentication,
PathBuilderInput::create(),
)?
)?;
self.validate_url(request.url())?;
self.services.server.check_running()?;
self.perform::<T>(dest, &actual, request, client).await
}
@@ -98,17 +128,6 @@ where
}
}
#[implement(super::Service)]
fn prepare(&self, dest: &ServerName, mut request: http::Request<Vec<u8>>) -> Result<Request> {
self.sign_request(&mut request, dest);
let request = Request::try_from(request)?;
self.validate_url(request.url())?;
self.services.server.check_running()?;
Ok(request)
}
#[implement(super::Service)]
fn validate_url(&self, url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
@@ -229,90 +248,27 @@ fn handle_error(
Err(e.into())
}
#[implement(super::Service)]
fn sign_request(&self, http_request: &mut http::Request<Vec<u8>>, dest: &ServerName) {
type Member = (String, Value);
type Value = CanonicalJsonValue;
type Object = CanonicalJsonObject;
let origin = &self.services.server.name;
let body = http_request.body();
let uri = http_request
.uri()
.path_and_query()
.expect("http::Request missing path_and_query");
let mut req: Object = if !body.is_empty() {
let content: CanonicalJsonValue =
serde_json::from_slice(body).expect("failed to serialize body");
let authorization: [Member; 5] = [
("content".into(), content),
("destination".into(), dest.as_str().into()),
("method".into(), http_request.method().as_str().into()),
("origin".into(), origin.as_str().into()),
("uri".into(), uri.to_string().into()),
];
authorization.into()
} else {
let authorization: [Member; 4] = [
("destination".into(), dest.as_str().into()),
("method".into(), http_request.method().as_str().into()),
("origin".into(), origin.as_str().into()),
("uri".into(), uri.to_string().into()),
];
authorization.into()
};
self.services
.server_keys
.sign_json(&mut req)
.expect("request signing failed");
let signatures = req["signatures"]
.as_object()
.and_then(|object| object[origin.as_str()].as_object())
.expect("origin signatures object");
let key: &ServerSigningKeyId = signatures
.keys()
.next()
.map(|k| k.as_str().try_into())
.expect("at least one signature from this origin")
.expect("keyid is json string");
let sig: Base64 = signatures
.values()
.next()
.map(|s| s.as_str().map(Base64::parse))
.expect("at least one signature from this origin")
.expect("signature is json string")
.expect("signature is valid base64");
let x_matrix = XMatrix::new(origin.into(), dest.into(), key.into(), sig);
let authorization = HeaderValue::from(&x_matrix);
let authorization = http_request
.headers_mut()
.insert(AUTHORIZATION, authorization);
debug_assert!(authorization.is_none(), "Authorization header already present");
/// A trait for the input types of acceptable path builders for outgoing federation requests.
///
/// Ruma uses Rust's type system to encode the versioning scheme of endpoints in the Matrix spec.
/// Every endpoint has a `PathBuilder` associated type, which has an `Input` associated type.
/// Endpoints with multiple versions have `VersionHistory` as their `PathBuilder`, which has `SupportedVersions`
/// as its `Input` type. Endpoints with no version have `SinglePath` as their `PathBuilder`, which has `()` as its `Input` type.
/// Both `SupportedVersions` and `()` can be created out of thin air using static data (or no data at all). This property
/// is what the `FederationPathBuilderInput` trait represents.
///
/// This trait allows the federation sender service's functions to accept requests for either versioned or unversioned endpoints,
/// by requiring that the `Input` of the `PathBuilder` of the endpoint implements `FederationPathBuilderInput`.
pub(crate) trait FederationPathBuilderInput {
fn create() -> Self;
}
fn into_http_request<T>(actual: &ActualDest, request: T) -> Result<http::Request<Vec<u8>>>
where
T: OutgoingRequest + Send,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11];
let http_request = request
.try_into_http_request::<Vec<u8>>(
actual.string().as_str(),
SendAccessToken::None,
&VERSIONS,
)
.map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?;
Ok(http_request)
impl FederationPathBuilderInput for () {
fn create() -> Self {}
}
impl FederationPathBuilderInput for Cow<'_, SupportedVersions> {
fn create() -> Self {
Cow::Borrowed(&SUPPORTED_VERSIONS)
}
}
+1
View File
@@ -1,4 +1,5 @@
mod execute;
pub(crate) use execute::FederationPathBuilderInput;
use std::sync::Arc;
+1 -1
View File
@@ -59,7 +59,7 @@ impl crate::Service for Service {
let (ber_count, ber_bytes) = self.bad_event_ratelimiter.read().iter().fold(
(0_usize, 0_usize),
|(mut count, mut bytes), (event_id, _)| {
bytes = bytes.saturating_add(event_id.capacity());
bytes = bytes.saturating_add(event_id.as_bytes().len());
bytes = bytes.saturating_add(size_of::<RateLimitState>());
count = count.saturating_add(1);
(count, bytes)
+5 -7
View File
@@ -7,9 +7,7 @@ use conduwuit::{
use database::{Deserialized, Ignore, Interfix, Json, Map};
use futures::StreamExt;
use ruma::{
OwnedRoomId, RoomId, UserId,
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
OwnedRoomId, OwnedUserId, RoomId, UserId, api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw
};
use crate::{Dep, globals};
@@ -103,7 +101,7 @@ pub async fn update_backup<'a>(
#[implement(Service)]
pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String> {
type Key<'a> = (&'a UserId, &'a str);
type Key<'a> = (OwnedUserId, &'a str);
let last_possible_key = (user_id, u64::MAX);
self.db
@@ -122,7 +120,7 @@ pub async fn get_latest_backup(
&self,
user_id: &UserId,
) -> Result<(String, Raw<BackupAlgorithm>)> {
type Key<'a> = (&'a UserId, &'a str);
type Key<'a> = (OwnedUserId, &'a str);
type KeyVal<'a> = (Key<'a>, Raw<BackupAlgorithm>);
let last_possible_key = (user_id, u64::MAX);
@@ -197,11 +195,11 @@ pub async fn get_all(
user_id: &UserId,
version: &str,
) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str);
type Key<'a> = (Ignore, Ignore, OwnedRoomId, &'a str);
type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
let default = || RoomKeyBackup { sessions: BTreeMap::new() };
let default = || RoomKeyBackup::new(BTreeMap::new());
let prefix = (user_id, version, Interfix);
self.db
+5 -3
View File
@@ -6,7 +6,9 @@ use conduwuit::{
};
use database::{Database, Interfix, Map};
use futures::StreamExt;
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
use ruma::{OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition};
use crate::media::mxc::Mxc;
use super::{preview::UrlPreviewData, thumbnail::Dim};
@@ -41,7 +43,7 @@ impl Data {
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let dim: &[u32] = &[dim.width, dim.height];
let key = (mxc, dim, content_disposition, content_type);
let key = (mxc, dim, content_disposition.map(ToString::to_string), content_type);
let key = database::serialize_key(key)?;
self.mediaid_file.insert(&key, []);
if let Some(user) = user {
@@ -146,7 +148,7 @@ impl Data {
self.mediaid_user
.stream()
.ignore_err()
.ready_filter_map(|(key, user): (&str, &UserId)| {
.ready_filter_map(|(key, user): (&str, OwnedUserId)| {
(user == user_id).then(|| key.into())
})
.collect()
+3 -2
View File
@@ -1,4 +1,5 @@
pub mod blurhash;
pub mod mxc;
mod data;
pub(super) mod migrations;
mod preview;
@@ -17,7 +18,7 @@ use conduwuit::{
},
warn,
};
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
use ruma::{OwnedMxcUri, UserId, http_headers::ContentDisposition};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
@@ -25,7 +26,7 @@ use tokio::{
use self::data::{Data, Metadata};
pub use self::thumbnail::Dim;
use crate::{Dep, client, globals, moderation, sending};
use crate::{Dep, client, globals, media::mxc::Mxc, moderation, sending};
#[derive(Debug)]
pub struct FileMeta {
+54
View File
@@ -0,0 +1,54 @@
use std::fmt;
use ruma::{MxcUri, MxcUriError, OwnedMxcUri, ServerName};
use serde::{Serialize, Serializer};
/// A structured, valid MXC URI
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Mxc<'a> {
/// ServerName part of the MXC URI
pub server_name: &'a ServerName,
/// MediaId part of the MXC URI
pub media_id: &'a str,
}
impl fmt::Display for Mxc<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "mxc://{}/{}", self.server_name, self.media_id)
}
}
impl<'a> TryFrom<&'a MxcUri> for Mxc<'a> {
type Error = MxcUriError;
fn try_from(s: &'a MxcUri) -> Result<Self, Self::Error> {
let (server_name, media_id) = s.parts()?;
Ok(Self { server_name, media_id })
}
}
impl<'a> TryFrom<&'a str> for Mxc<'a> {
type Error = MxcUriError;
fn try_from(s: &'a str) -> Result<Self, Self::Error> {
let s: &MxcUri = s.into();
s.try_into()
}
}
impl<'a> TryFrom<&'a OwnedMxcUri> for Mxc<'a> {
type Error = MxcUriError;
fn try_from(s: &'a OwnedMxcUri) -> Result<Self, Self::Error> {
let s: &MxcUri = s.as_ref();
s.try_into()
}
}
impl Serialize for Mxc<'_> {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
s.serialize_str(self.to_string().as_str())
}
}
+2 -1
View File
@@ -130,7 +130,8 @@ pub async fn download_image(
) -> Result<UrlPreviewData> {
use conduwuit::utils::random_string;
use image::ImageReader;
use ruma::Mxc;
use crate::media::mxc::Mxc;
let mut preview_data = preview_data.unwrap_or_default();
+60 -62
View File
@@ -6,18 +6,16 @@ use conduwuit::{
};
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE, HeaderValue};
use ruma::{
Mxc, ServerName, UserId,
api::{
OutgoingRequest,
client::{
ServerName, UserId, api::{
Metadata, OutgoingRequest, auth_scheme::NoAuthentication, client::{
error::ErrorKind::{NotFound, Unrecognized},
media,
},
federation,
federation::authenticated_media::{Content, FileOrLocation},
},
}, federation::{self, authenticated_media::{Content, FileOrLocation}, authentication::ServerSignatures}, path_builder::PathBuilder
}
};
use crate::{federation::FederationPathBuilderInput, media::mxc::Mxc};
use super::{Dim, FileMeta};
#[implement(super::Service)]
@@ -87,14 +85,10 @@ async fn fetch_thumbnail_authenticated(
) -> Result<FileMeta> {
use federation::authenticated_media::get_content_thumbnail::v1::{Request, Response};
let request = Request {
media_id: mxc.media_id.into(),
method: dim.method.clone().into(),
width: dim.width.into(),
height: dim.height.into(),
animated: true.into(),
timeout_ms,
};
let mut request = Request::new(mxc.media_id.into(), dim.width.into(), dim.height.into());
request.method = Some(dim.method.clone());
request.animated = Some(true);
request.timeout_ms = timeout_ms;
let Response { content, .. } = self.federation_request(mxc, server, request).await?;
@@ -102,6 +96,7 @@ async fn fetch_thumbnail_authenticated(
| FileOrLocation::File(content) =>
self.handle_thumbnail_file(mxc, user, dim, content).await,
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
| _ => Err!("Unknown content in response"),
}
}
@@ -115,16 +110,15 @@ async fn fetch_content_authenticated(
) -> Result<FileMeta> {
use federation::authenticated_media::get_content::v1::{Request, Response};
let request = Request {
media_id: mxc.media_id.into(),
timeout_ms,
};
let mut request = Request::new(mxc.media_id.into());
request.timeout_ms = timeout_ms;
let Response { content, .. } = self.federation_request(mxc, server, request).await?;
match content {
| FileOrLocation::File(content) => self.handle_content_file(mxc, user, content).await,
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
| _ => Err!("Unknown content in response"),
}
}
@@ -140,23 +134,18 @@ async fn fetch_thumbnail_unauthenticated(
) -> Result<FileMeta> {
use media::get_content_thumbnail::v3::{Request, Response};
let request = Request {
allow_remote: true,
allow_redirect: true,
animated: true.into(),
method: dim.method.clone().into(),
width: dim.width.into(),
height: dim.height.into(),
server_name: mxc.server_name.into(),
media_id: mxc.media_id.into(),
timeout_ms,
};
let mut request = Request::new(mxc.media_id.into(), mxc.server_name.into(), dim.width.into(), dim.height.into());
request.allow_redirect = true;
request.allow_remote = true;
request.animated = Some(true);
request.method = Some(dim.method.clone());
request.timeout_ms = timeout_ms;
let Response {
file, content_type, content_disposition, ..
} = self.federation_request(mxc, server, request).await?;
} = self.federation_request_unauthenticated(mxc, server, request).await?;
let content = Content { file, content_type, content_disposition };
let content = Content::new(file, content_type.unwrap(), content_disposition.unwrap());
self.handle_thumbnail_file(mxc, user, dim, content).await
}
@@ -172,19 +161,16 @@ async fn fetch_content_unauthenticated(
) -> Result<FileMeta> {
use media::get_content::v3::{Request, Response};
let request = Request {
allow_remote: true,
allow_redirect: true,
server_name: mxc.server_name.into(),
media_id: mxc.media_id.into(),
timeout_ms,
};
let mut request = Request::new(mxc.media_id.into(), mxc.server_name.into());
request.allow_remote = true;
request.allow_redirect = true;
request.timeout_ms = timeout_ms;
let Response {
file, content_type, content_disposition, ..
} = self.federation_request(mxc, server, request).await?;
} = self.federation_request_unauthenticated(mxc, server, request).await?;
let content = Content { file, content_type, content_disposition };
let content = Content::new(file, content_type.unwrap(), content_disposition.unwrap());
self.handle_content_file(mxc, user, content).await
}
@@ -307,14 +293,14 @@ async fn location_request(&self, location: &str) -> Result<FileMeta> {
}
#[implement(super::Service)]
async fn federation_request<Request>(
async fn federation_request<'i, Request>(
&self,
mxc: &Mxc<'_>,
server: Option<&ServerName>,
request: Request,
) -> Result<Request::IncomingResponse>
where
Request: OutgoingRequest + Send + Debug,
Request: OutgoingRequest::<Authentication = ServerSignatures, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
self.services
.sending
@@ -322,6 +308,22 @@ where
.await
}
#[implement(super::Service)]
async fn federation_request_unauthenticated<'i, Request>(
&self,
mxc: &Mxc<'_>,
server: Option<&ServerName>,
request: Request,
) -> Result<Request::IncomingResponse>
where
Request: OutgoingRequest::<Authentication = NoAuthentication, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
self.services
.sending
.send_unauthenticated_request(server.unwrap_or(mxc.server_name), request)
.await
}
#[implement(super::Service)]
#[allow(deprecated)]
pub async fn fetch_remote_thumbnail_legacy(
@@ -333,22 +335,19 @@ pub async fn fetch_remote_thumbnail_legacy(
media_id: &body.media_id,
};
let mut request = media::get_content_thumbnail::v3::Request::new(body.media_id.clone(), body.server_name.clone(), body.width, body.height);
request.method = body.method.clone();
request.allow_remote = body.allow_remote;
request.allow_redirect = body.allow_redirect;
request.animated = body.animated;
request.timeout_ms = body.timeout_ms;
self.check_legacy_freeze()?;
self.check_fetch_authorized(&mxc)?;
let response = self
.services
.sending
.send_federation_request(mxc.server_name, media::get_content_thumbnail::v3::Request {
allow_remote: body.allow_remote,
height: body.height,
width: body.width,
method: body.method.clone(),
server_name: body.server_name.clone(),
media_id: body.media_id.clone(),
timeout_ms: body.timeout_ms,
allow_redirect: body.allow_redirect,
animated: body.animated,
})
.send_unauthenticated_request(mxc.server_name, request)
.await?;
let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?;
@@ -373,18 +372,17 @@ pub async fn fetch_remote_content_legacy(
allow_redirect: bool,
timeout_ms: Duration,
) -> Result<media::get_content::v3::Response, Error> {
let mut request = media::get_content::v3::Request::new(mxc.media_id.into(), mxc.server_name.into());
request.allow_remote = true;
request.allow_redirect = allow_redirect;
request.timeout_ms = timeout_ms;
self.check_legacy_freeze()?;
self.check_fetch_authorized(mxc)?;
let response = self
.services
.sending
.send_federation_request(mxc.server_name, media::get_content::v3::Request {
allow_remote: true,
server_name: mxc.server_name.into(),
media_id: mxc.media_id.into(),
timeout_ms,
allow_redirect,
})
.send_unauthenticated_request(mxc.server_name, request)
.await?;
let content_disposition = make_content_disposition(
+3 -1
View File
@@ -8,12 +8,14 @@
use std::{cmp, num::Saturating as Sat};
use conduwuit::{Result, checked, err, implement};
use ruma::{Mxc, UInt, UserId, http_headers::ContentDisposition, media::Method};
use ruma::{UInt, UserId, http_headers::ContentDisposition, media::Method};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt},
};
use crate::media::mxc::Mxc;
use super::{FileMeta, data::Metadata};
/// Dimension specification for a thumbnail.
+23 -25
View File
@@ -243,7 +243,13 @@ async fn migrate(services: &Services) -> Result<()> {
services
.users
.stream()
.filter(|user_id| services.users.is_active_local(user_id))
.filter_map(async |user_id| {
if services.users.is_active_local(&user_id).await {
Some(user_id)
} else {
None
}
})
.ready_for_each(|user_id| {
let matches = patterns.matches(user_id.localpart());
if matches.matched_any() {
@@ -268,7 +274,6 @@ async fn migrate(services: &Services) -> Result<()> {
.rooms
.metadata
.iter_ids()
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await
{
@@ -305,7 +310,6 @@ async fn db_lt_12(services: &Services) -> Result<()> {
for username in &services
.users
.list_local_users()
.map(ToOwned::to_owned)
.collect::<Vec<OwnedUserId>>()
.await
{
@@ -385,7 +389,6 @@ async fn db_lt_13(services: &Services) -> Result<()> {
for username in &services
.users
.list_local_users()
.map(ToOwned::to_owned)
.collect::<Vec<OwnedUserId>>()
.await
{
@@ -480,7 +483,6 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
.rooms
.metadata
.iter_ids()
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
@@ -491,7 +493,6 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
.rooms
.state_cache
.room_members(room_id)
.map(ToOwned::to_owned)
.collect()
.await;
@@ -603,11 +604,8 @@ async fn fix_referencedevents_missing_sep(services: &Services) -> Result {
}
async fn fix_readreceiptid_readreceipt_duplicates(services: &Services) -> Result {
use conduwuit::arrayvec::ArrayString;
use ruma::identifiers_validation::MAX_BYTES;
type ArrayId = ArrayString<MAX_BYTES>;
type Key<'a> = (&'a RoomId, u64, &'a UserId);
type ArrayId = String;
type Key = (OwnedRoomId, u64, OwnedUserId);
info!("Fixing undeleted entries in readreceiptid_readreceipt...");
@@ -621,8 +619,8 @@ async fn fix_readreceiptid_readreceipt_duplicates(services: &Services) -> Result
readreceiptid_readreceipt
.keys()
.expect_ok()
.ready_for_each(|key: Key<'_>| {
let (room_id, _, user_id) = key;
.ready_for_each(|key: Key| {
let (ref room_id, _, ref user_id) = key;
let last_room = cur_room.replace(
room_id
.as_str()
@@ -715,8 +713,8 @@ async fn fix_corrupt_msc4133_fields(services: &Services) -> Result {
const POPULATED_USERROOMID_LEFTSTATE_TABLE_MARKER: &str = "populate_userroomid_leftstate_table";
async fn populate_userroomid_leftstate_table(services: &Services) -> Result {
type KeyVal<'a> = (Key<'a>, Raw<Option<Pdu>>);
type Key<'a> = (&'a UserId, &'a RoomId);
type KeyVal = (Key, Raw<Option<Pdu>>);
type Key = (OwnedUserId, OwnedRoomId);
let db = &services.db;
let cork = db.cork_and_sync();
@@ -731,16 +729,16 @@ async fn populate_userroomid_leftstate_table(services: &Services) -> Result {
usize,
HashMap<_, _>,
),
((user_id, room_id), state): KeyVal<'_>|
((user_id, room_id), state): KeyVal|
-> Result<(usize, usize, HashMap<_, _>)> {
if state.deserialize().is_err() {
let latest_shortstatehash =
if let Some(shortstatehash) = shortstatehash_cache.get(room_id) {
if let Some(shortstatehash) = shortstatehash_cache.get(&room_id) {
*shortstatehash
} else if let Ok(shortstatehash) =
services.rooms.state.get_room_shortstatehash(room_id).await
services.rooms.state.get_room_shortstatehash(&room_id).await
{
shortstatehash_cache.insert(room_id.to_owned(), shortstatehash);
shortstatehash_cache.insert(room_id.clone(), shortstatehash);
shortstatehash
} else {
warn!(%room_id, %user_id, "room has no shortstatehash");
@@ -792,8 +790,8 @@ const FIXED_LOCAL_INVITE_STATE_MARKER: &str = "fix_local_invite_state";
async fn fix_local_invite_state(services: &Services) -> Result {
// Clean up the effects of !1249 by caching stripped state for invites
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
type KeyVal = (Key, Raw<Vec<AnyStrippedStateEvent>>);
type Key = (OwnedUserId, OwnedRoomId);
let db = &services.db;
let cork = db.cork_and_sync();
@@ -802,9 +800,9 @@ async fn fix_local_invite_state(services: &Services) -> Result {
// for each user invited to a room
let fixed = userroomid_invitestate.stream()
// if they're a local user on this homeserver
.try_filter(|((user_id, _), _): &KeyVal<'_>| ready(services.globals.user_is_local(user_id)))
.and_then(async |((user_id, room_id), stripped_state): KeyVal<'_>| Ok::<_,
conduwuit::Error>((user_id.to_owned(), room_id.to_owned(), stripped_state.deserialize
.try_filter(|((user_id, _), _): &KeyVal| ready(services.globals.user_is_local(user_id)))
.and_then(async |((user_id, room_id), stripped_state): KeyVal| Ok::<_,
conduwuit::Error>((user_id.clone(), room_id.clone(), stripped_state.deserialize
().unwrap_or_else(|e| {
trace!("Failed to deserialize: {:?}", stripped_state.json());
warn!(
@@ -812,7 +810,7 @@ async fn fix_local_invite_state(services: &Services) -> Result {
%room_id,
"Failed to deserialize stripped state for invite, removing from db: {e}"
);
userroomid_invitestate.del((user_id, room_id));
userroomid_invitestate.del((&user_id, &room_id));
vec![]
}))))
.try_fold(0_usize, async |mut fixed, (user_id, room_id, stripped_state)| {
+8
View File
@@ -47,3 +47,11 @@ pub use crate::services::Services;
conduwuit::mod_ctor! {}
conduwuit::mod_dtor! {}
use std::sync::LazyLock;
use conduwuit::matrix::versions::{unstable_features, versions};
use ruma::api::SupportedVersions;
pub static SUPPORTED_VERSIONS: LazyLock<SupportedVersions> = LazyLock::new(|| {
SupportedVersions::from_parts(&versions(), &unstable_features())
});
+2 -3
View File
@@ -183,7 +183,6 @@ impl Service {
.services
.users
.list_local_users()
.map(ToOwned::to_owned)
.collect::<Vec<OwnedUserId>>()
.await
{
@@ -194,9 +193,9 @@ impl Service {
| _ => continue,
};
if !matches!(
if matches!(
presence.presence,
PresenceState::Unavailable | PresenceState::Online | PresenceState::Busy
PresenceState::Offline
) {
trace!(%user_id, ?presence, "Skipping user");
continue;
+7 -8
View File
@@ -47,17 +47,16 @@ impl Presence {
) -> PresenceEvent {
let now = utils::millis_since_unix_epoch();
let last_active_ago = Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts)));
let mut content = PresenceEventContent::new(self.state.clone());
content.status_msg = self.status_msg.clone();
content.currently_active = Some(self.currently_active);
content.last_active_ago = last_active_ago;
content.displayname = users.displayname(user_id).await.ok();
content.avatar_url = users.avatar_url(user_id).await.ok();
PresenceEvent {
sender: user_id.to_owned(),
content: PresenceEventContent {
presence: self.state.clone(),
status_msg: self.status_msg.clone(),
currently_active: Some(self.currently_active),
last_active_ago,
displayname: users.displayname(user_id).await.ok(),
avatar_url: users.avatar_url(user_id).await.ok(),
},
content,
}
}
}
+23 -39
View File
@@ -11,24 +11,17 @@ use conduwuit_database::{Deserialized, Ignore, Interfix, Json, Map};
use futures::{Stream, StreamExt};
use ipaddress::IPAddress;
use ruma::{
DeviceId, OwnedDeviceId, RoomId, UInt, UserId,
api::{
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
client::push::{Pusher, PusherKind, set_pusher},
push_gateway::send_event_notification::{
DeviceId, OwnedDeviceId, RoomId, UInt, UserId, api::{
IncomingResponse, MatrixVersion, OutgoingRequest, auth_scheme::{NoAuthentication, SendAccessToken}, client::push::{Pusher, PusherKind, set_pusher}, path_builder::SinglePath, push_gateway::send_event_notification::{
self,
v1::{Device, Notification, NotificationCounts, NotificationPriority},
},
},
events::{
}
}, events::{
AnySyncTimelineEvent, StateEventType, TimelineEventType,
room::power_levels::RoomPowerLevelsEventContent,
},
push::{
room::{create::RoomCreateEventContent, power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}},
}, push::{
Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak,
},
serde::Raw,
uint,
}, room_version_rules::{AuthorizationRules, RoomPowerLevelsRules, RoomVersionRules}, serde::Raw, uint
};
use crate::{Dep, client, config, globals, rooms, sending, users};
@@ -42,6 +35,7 @@ struct Services {
globals: Dep<globals::Service>,
config: Dep<config::Service>,
client: Dep<client::Service>,
state: Dep<rooms::state::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
users: Dep<users::Service>,
@@ -64,6 +58,7 @@ impl crate::Service for Service {
globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"),
config: args.depend::<config::Service>("config"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
@@ -137,6 +132,7 @@ impl Service {
| set_pusher::v3::PusherAction::Delete(ids) => {
self.delete_pusher(sender, ids.pushkey.as_str()).await;
},
| _ => return Err!(Request(InvalidParam("Unknown pusher action"))),
}
Ok(())
@@ -193,7 +189,7 @@ impl Service {
#[tracing::instrument(skip(self, dest, request))]
pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest<Authentication = NoAuthentication, PathBuilder = SinglePath> + Debug + Send,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0];
@@ -201,7 +197,7 @@ impl Service {
trace!("Push gateway destination: {dest}");
let http_request = request
.try_into_http_request::<BytesMut>(&dest, SendAccessToken::None, &VERSIONS)
.try_into_http_request::<BytesMut>(&dest, SendAccessToken::None, ())
.map_err(|e| {
err!(BadServerResponse(warn!(
"Failed to find destination {dest} for push gateway: {e}"
@@ -298,22 +294,20 @@ impl Service {
{
let mut notify = None;
let mut tweaks = Vec::new();
if event.room_id().is_none() {
// This only affects v12+ create events
let Some(room_id) = event.room_id() else {
// Only v12+ create events have no room ID
return Ok(());
}
};
let power_levels: RoomPowerLevelsEventContent = self
let power_levels = self
.services
.state_accessor
.room_state_get(event.room_id().unwrap(), &StateEventType::RoomPowerLevels, "")
.await
.and_then(|event| event.get_content())
.unwrap_or_default();
.get_room_power_levels(room_id)
.await;
let serialized = event.to_format();
for action in self
.get_actions(user, &ruleset, &power_levels, &serialized, event.room_id().unwrap())
.get_actions(user, &ruleset, power_levels.clone(), &serialized, event.room_id().unwrap())
.await
{
let n = match action {
@@ -347,15 +341,11 @@ impl Service {
&self,
user: &UserId,
ruleset: &'a Ruleset,
power_levels: &RoomPowerLevelsEventContent,
power_levels: RoomPowerLevels,
pdu: &Raw<AnySyncTimelineEvent>,
room_id: &RoomId,
) -> &'a [Action] {
let power_levels = PushConditionPowerLevelsCtx {
users: power_levels.users.clone(),
users_default: power_levels.users_default,
notifications: power_levels.notifications.clone(),
};
let power_levels = PushConditionPowerLevelsCtx::from(power_levels);
let room_joined_count = self
.services
@@ -373,15 +363,9 @@ impl Service {
.await
.unwrap_or_else(|_| user.localpart().to_owned());
let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(),
member_count: room_joined_count,
user_id: user.to_owned(),
user_display_name,
power_levels: Some(power_levels),
};
let ctx = PushConditionRoomCtx::new(room_id.to_owned(), room_joined_count, user.to_owned(), user_display_name).with_power_levels(power_levels);
ruleset.get_actions(pdu, &ctx)
ruleset.get_actions(pdu, &ctx).await
}
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
+5 -5
View File
@@ -8,7 +8,7 @@ use conduwuit::{
};
use database::{Cbor, Deserialized, Map};
use futures::{Stream, StreamExt, future::join};
use ruma::ServerName;
use ruma::{OwnedServerName, ServerName};
use serde::{Deserialize, Serialize};
use super::fed::FedDest;
@@ -107,19 +107,19 @@ pub async fn get_override(&self, name: &str) -> Result<CachedOverride> {
}
#[implement(Cache)]
pub fn destinations(&self) -> impl Stream<Item = (&ServerName, CachedDest)> + Send + '_ {
pub fn destinations(&self) -> impl Stream<Item = (OwnedServerName, CachedDest)> + Send + '_ {
self.destinations
.stream()
.ignore_err()
.map(|item: (&ServerName, Cbor<_>)| (item.0, item.1.0))
.map(|item: (OwnedServerName, Cbor<_>)| (item.0, item.1.0))
}
#[implement(Cache)]
pub fn overrides(&self) -> impl Stream<Item = (&ServerName, CachedOverride)> + Send + '_ {
pub fn overrides(&self) -> impl Stream<Item = (OwnedServerName, CachedOverride)> + Send + '_ {
self.overrides
.stream()
.ignore_err()
.map(|item: (&ServerName, Cbor<_>)| (item.0, item.1.0))
.map(|item: (OwnedServerName, Cbor<_>)| (item.0, item.1.0))
}
impl CachedDest {
+4 -1
View File
@@ -1,6 +1,9 @@
use std::str::FromStr;
use conduwuit::{
Result, debug, debug_error, debug_info, implement, trace, utils::response::LimitReadExt,
};
use ruma::{OwnedServerName, ServerName};
#[implement(super::Service)]
#[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))]
@@ -40,7 +43,7 @@ pub(super) async fn request_well_known(&self, dest: &str) -> Result<Option<Strin
.as_str()
.unwrap_or_default();
if ruma::identifiers_validation::server_name::validate(m_server).is_err() {
if ServerName::parse(m_server).is_err() {
debug_error!("response content missing or invalid");
return Ok(None);
}
+11 -33
View File
@@ -9,11 +9,10 @@ use conduwuit::{
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{Stream, StreamExt, TryFutureExt};
use ruma::{
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId,
events::{
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId, events::{
StateEventType,
room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent},
},
}
};
use crate::{Dep, admin, appservice, appservice::RegistrationInfo, globals, rooms, sending};
@@ -179,7 +178,6 @@ impl Service {
.services
.state_cache
.room_servers(&room_id)
.map(ToOwned::to_owned)
.collect()
.await;
return Ok((room_id, servers));
@@ -197,22 +195,22 @@ impl Service {
pub fn local_aliases_for_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a RoomAliasId> + Send + 'a {
) -> impl Stream<Item = OwnedRoomAliasId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.aliasid_alias
.stream_prefix(&prefix)
.ignore_err()
.map(|(_, alias): (Ignore, &RoomAliasId)| alias)
.map(|(_, alias): (Ignore, OwnedRoomAliasId)| alias)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn all_local_aliases(&self) -> impl Stream<Item = (&RoomId, &str)> + Send + '_ {
pub fn all_local_aliases(&self) -> impl Stream<Item = (OwnedRoomId, &str)> + Send + '_ {
self.db
.alias_roomid
.stream()
.ignore_err()
.map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart))
.map(|(alias_localpart, room_id): (&str, OwnedRoomId)| (room_id, alias_localpart))
}
async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<bool> {
@@ -236,34 +234,14 @@ impl Service {
}
// Checking whether the user is able to change canonical aliases of the room
if let Ok(power_levels) = self
let can_change_canonical_alias = self
.services
.state_accessor
.room_state_get_content::<RoomPowerLevelsEventContent>(
&room_id,
&StateEventType::RoomPowerLevels,
"",
)
.map_ok(RoomPowerLevels::from)
.get_room_power_levels(&room_id)
.await
{
return Ok(
power_levels.user_can_send_state(user_id, StateEventType::RoomCanonicalAlias)
);
}
.user_can_send_state(user_id, StateEventType::RoomCanonicalAlias);
// If there is no power levels event, only the room creator can change
// canonical aliases
if let Ok(event) = self
.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")
.await
{
return Ok(event.sender() == user_id);
}
Err!(Database("Room has no m.room.create event"))
Ok(can_change_canonical_alias)
}
async fn who_created_alias(&self, alias: &RoomAliasId) -> Result<OwnedUserId> {
@@ -299,7 +277,7 @@ impl Service {
.sending
.send_appservice_request(
appservice.registration.clone(),
query_room_alias::v1::Request { room_alias: room_alias.to_owned() },
query_room_alias::v1::Request::new(room_alias.to_owned()),
)
.await,
Ok(Some(_opt_result))
+2 -2
View File
@@ -16,7 +16,7 @@ pub(super) async fn remote_resolve(
error!("Unable to resolve remote room alias {}: {e}", room_alias);
Err(e)
},
| Ok(Response { room_id, servers }) => {
| Ok(Response { room_id, servers, .. }) => {
debug!("Remote resolved {room_alias:?} to {room_id:?} with servers {servers:?}");
Ok((room_id, servers))
},
@@ -31,7 +31,7 @@ async fn remote_request(
) -> Result<Response> {
use federation::query::get_room_information::v1::Request;
let request = Request { room_alias: room_alias.to_owned() };
let request = Request::new(room_alias.to_owned());
self.services
.sending
+2 -2
View File
@@ -3,7 +3,7 @@ use std::sync::Arc;
use conduwuit::{Result, implement, utils::stream::TryIgnore};
use database::Map;
use futures::Stream;
use ruma::{RoomId, api::client::room::Visibility};
use ruma::{OwnedRoomId, RoomId, api::client::room::Visibility};
pub struct Service {
db: Data,
@@ -32,7 +32,7 @@ pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_i
pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); }
#[implement(Service)]
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
pub fn public_rooms(&self) -> impl Stream<Item = OwnedRoomId> + Send {
self.db.publicroomids.keys().ignore_err()
}
@@ -12,7 +12,7 @@ use ruma::{
api::federation::event::get_event,
};
use super::get_room_version_id;
use super::get_room_version;
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
/// it is appended to the outliers Tree.
@@ -109,15 +109,12 @@ where
match self
.services
.sending
.send_federation_request(origin, get_event::v1::Request {
event_id: (*next_id).to_owned(),
include_unredacted_content: None,
})
.send_federation_request(origin, get_event::v1::Request::new((*next_id).to_owned()))
.await
{
| Ok(res) => {
debug!("Got {next_id} over federation from {origin}");
let Ok(room_version_id) = get_room_version_id(create_event) else {
let Ok(room_version_id) = get_room_version(create_event) else {
back_off((*next_id).to_owned());
continue;
};
@@ -31,10 +31,7 @@ where
let res = self
.services
.sending
.send_federation_request(origin, get_room_state_ids::v1::Request {
room_id: room_id.to_owned(),
event_id: event_id.to_owned(),
})
.send_federation_request(origin, get_room_state_ids::v1::Request::new(event_id.to_owned(), room_id.to_owned()))
.await
.inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?;
@@ -55,7 +55,7 @@ async fn should_rescind_invite(
// Does the target user have a pending invite?
let Ok(pending_invite_state) = services
.state_cache
.invite_state(target_user_id, room_id)
.invite_state(&target_user_id, room_id)
.await
else {
return Ok(None); // No pending invite, so nothing to rescind
@@ -146,9 +146,11 @@ pub async fn handle_incoming_pdu<'a>(
let origin_acl_check = self.acl_check(origin, room_id);
// 1.3.2 Check room ACL on sender's server name
let sender: &UserId = value
let sender: OwnedUserId = value
.get("sender")
.try_into()
.and_then(|v| v.as_str())
.ok_or_else(|| err!("No sender in object"))
.and_then(|v| Ok(UserId::parse(v)?))
.map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?;
let sender_acl_check: OptionFuture<_> = sender
@@ -180,7 +182,7 @@ pub async fn handle_incoming_pdu<'a>(
// copied from https://github.com/element-hq/synapse/blob/7e4588a/synapse/handlers/federation_event.py#L255-L300
if value.get("type").and_then(|t| t.as_str()) == Some("m.room.member") {
if let Some(pdu) =
should_rescind_invite(&self.services, &mut value.clone(), sender, room_id).await?
should_rescind_invite(&self.services, &mut value.clone(), &sender, room_id).await?
{
debug_info!(
"Invite to {room_id} appears to have been rescinded by {sender}, marking as \
@@ -188,7 +190,7 @@ pub async fn handle_incoming_pdu<'a>(
);
self.services
.state_cache
.mark_as_left(sender, room_id, Some(pdu))
.mark_as_left(&sender, room_id, Some(pdu))
.await;
return Ok(None);
}
@@ -10,7 +10,7 @@ use ruma::{
events::StateEventType,
};
use super::{check_room_id, get_room_version_id, to_room_version};
use super::{check_room_id, get_room_version};
use crate::rooms::timeline::pdu_fits;
#[implement(super::Service)]
@@ -41,18 +41,19 @@ where
// 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match
let room_version_id = get_room_version_id(create_event)?;
let room_version = get_room_version(create_event)?;
let room_rules = room_version.rules().expect("room version should have defined rules");
let mut incoming_pdu = match self
.services
.server_keys
.verify_event(&value, Some(&room_version_id))
.verify_event(&value, Some(&room_version))
.await
{
| Ok(ruma::signatures::Verified::All) => value,
| Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else {
let Ok(obj) = ruma::canonical_json::redact(value, &room_rules.redaction, None) else {
return Err!(Request(InvalidParam("Redaction failed")));
};
@@ -184,7 +185,7 @@ where
};
let auth_check = state_res::event_auth::auth_check(
&to_room_version(&room_version_id),
&room_rules,
&pdu_event,
None, // TODO: third party invite
state_fetch,
+3 -8
View File
@@ -14,7 +14,7 @@ mod upgrade_outlier_pdu;
use std::{collections::HashMap, fmt::Write, sync::Arc, time::Instant};
use async_trait::async_trait;
use conduwuit::{Err, Event, PduEvent, Result, RoomVersion, Server, SyncRwLock, utils::MutexMap};
use conduwuit::{Err, Event, PduEvent, Result, Server, SyncRwLock, utils::MutexMap};
use ruma::{
OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
events::room::create::RoomCreateEventContent,
@@ -114,14 +114,9 @@ fn check_room_id<Pdu: Event>(room_id: &RoomId, pdu: &Pdu) -> Result {
Ok(())
}
fn get_room_version_id<Pdu: Event>(create_event: &Pdu) -> Result<RoomVersionId> {
fn get_room_version<Pdu: Event>(create_event: &Pdu) -> Result<RoomVersionId> {
let content: RoomCreateEventContent = create_event.get_content()?;
let room_version = content.room_version;
Ok(room_version)
}
#[inline]
fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion {
RoomVersion::new(room_version_id).expect("room version is supported")
}
}
@@ -3,22 +3,24 @@
//! This module implements a check against a room-specific policy server, as
//! described in the relevant Matrix spec proposal (see: https://github.com/matrix-org/matrix-spec-proposals/pull/4284).
use std::{collections::BTreeMap, time::Duration};
use std::{collections::BTreeMap, sync::LazyLock, time::Duration};
use conduwuit::{
Err, Event, PduEvent, Result, debug, debug_error, debug_info, debug_warn, implement, trace,
warn,
};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, KeyId, RoomId, ServerName, SigningKeyId,
api::federation::room::{
policy_check::unstable::Request as PolicyCheckRequest,
policy_sign::unstable::Request as PolicySignRequest,
},
events::{StateEventType, room::policy::RoomPolicyEventContent},
CanonicalJsonObject, CanonicalJsonValue, KeyId, OwnedKeyId, RoomId, ServerName, SigningKeyId, events::StateEventType
};
use ruminuwuity::policy::{
policy_check::unstable::Request as PolicyCheckRequest,
policy_sign::unstable::Request as PolicySignRequest,
event::RoomPolicyEventContent
};
use serde_json::value::RawValue;
static POLICY_EVENT_TYPE_UNSTABLE: LazyLock<StateEventType> = LazyLock::new(|| StateEventType::from("org.matrix.msc4284.policy"));
/// Asks a remote policy server if the event is allowed.
///
/// If the event is the `org.matrix.msc4284.policy` configuration state event,
@@ -44,7 +46,7 @@ pub async fn ask_policy_server(
return Ok(true); // don't ever contact policy servers
}
if *pdu.event_type() == StateEventType::RoomPolicy.into() {
if *pdu.event_type() == POLICY_EVENT_TYPE_UNSTABLE.clone().into() {
debug!(
room_id = %room_id,
event_type = ?pdu.event_type(),
@@ -56,7 +58,7 @@ pub async fn ask_policy_server(
let Ok(policyserver) = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPolicy, "")
.room_state_get_content(room_id, &POLICY_EVENT_TYPE_UNSTABLE, "")
.await
.inspect_err(|e| {
if !e.is_not_found() {
@@ -86,11 +88,7 @@ pub async fn ask_policy_server(
return Ok(true);
},
};
if via.is_empty() {
trace!("Policy server is empty for room {room_id}, skipping spam check");
return Ok(true);
}
if !self.services.state_cache.server_in_room(via, room_id).await {
if !self.services.state_cache.server_in_room(&via, room_id).await {
debug!(
via = %via,
"Policy server is not in the room, skipping spam check"
@@ -110,14 +108,14 @@ pub async fn ask_policy_server(
"Getting policy server signature on event"
);
return self
.fetch_policy_server_signature(pdu, pdu_json, via, outgoing, room_id)
.fetch_policy_server_signature(pdu, pdu_json, &via, outgoing, room_id)
.await;
}
// for incoming events, is it signed by <via> with the key
// "ed25519:policy_server"?
if let Some(CanonicalJsonValue::Object(sigs)) = pdu_json.get("signatures") {
if let Some(CanonicalJsonValue::Object(server_sigs)) = sigs.get(via.as_str()) {
let wanted_key_id: &KeyId<ruma::SigningKeyAlgorithm, ruma::Base64PublicKey> =
let wanted_key_id: OwnedKeyId<ruma::SigningKeyAlgorithm, ruma::Base64PublicKey> =
SigningKeyId::parse("ed25519:policy_server")?;
if let Some(CanonicalJsonValue::String(_sig_value)) =
server_sigs.get(wanted_key_id.as_str())
@@ -134,14 +132,15 @@ pub async fn ask_policy_server(
via = %via,
"Checking event for spam with policy server via legacy check"
);
let mut request = PolicyCheckRequest::new(pdu.event_id().to_owned());
request.pdu = Some(outgoing);
let response = tokio::time::timeout(
Duration::from_secs(self.services.server.config.policy_server_request_timeout),
self.services
.sending
.send_federation_request(via, PolicyCheckRequest {
event_id: pdu.event_id().to_owned(),
pdu: Some(outgoing),
}),
.send_federation_request(&via, request),
)
.await;
let response = match response {
@@ -202,7 +201,7 @@ pub async fn fetch_policy_server_signature(
Duration::from_secs(self.services.server.config.policy_server_request_timeout),
self.services
.sending
.send_federation_request(via, PolicySignRequest { pdu: outgoing }),
.send_federation_request(via, PolicySignRequest::new(outgoing)),
)
.await;
@@ -250,7 +249,7 @@ pub async fn fetch_policy_server_signature(
}
let keypairs = sigs.get(via).unwrap();
let wanted_key_id = KeyId::parse("ed25519:policy_server")?;
if !keypairs.contains_key(wanted_key_id) {
if !keypairs.contains_key(&wanted_key_id) {
debug_warn!(
"Policy server returned signature, but did not use the key ID \
'ed25519:policy_server'."
@@ -262,7 +261,7 @@ pub async fn fetch_policy_server_signature(
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()));
if let CanonicalJsonValue::Object(signatures_map) = signatures_entry {
let sig_value = keypairs.get(wanted_key_id).unwrap().to_owned();
let sig_value = keypairs.get(&wanted_key_id).unwrap().to_owned();
match signatures_map.get_mut(via.as_str()) {
| Some(CanonicalJsonValue::Object(inner_map)) => {
@@ -11,7 +11,7 @@ use conduwuit::{
utils::stream::{IterStream, ReadyExt, TryWidebandExt, WidebandExt},
};
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
use ruma::{OwnedEventId, RoomId, RoomVersionId};
use ruma::{OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
use crate::rooms::state_compressor::CompressedState;
@@ -20,7 +20,7 @@ use crate::rooms::state_compressor::CompressedState;
pub async fn resolve_state(
&self,
room_id: &RoomId,
room_version_id: &RoomVersionId,
room_version_rules: &RoomVersionRules,
incoming_state: HashMap<u64, OwnedEventId>,
) -> Result<Arc<CompressedState>> {
trace!("Loading current room state ids");
@@ -71,7 +71,7 @@ pub async fn resolve_state(
trace!("Resolving state");
let state = self
.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.state_resolution(room_version_rules, fork_states.iter(), &auth_chain_sets)
.boxed()
.await?;
@@ -103,7 +103,7 @@ pub async fn resolve_state(
#[tracing::instrument(name = "ruma", level = "debug", skip_all)]
pub async fn state_resolution<'a, StateSets>(
&'a self,
room_version: &'a RoomVersionId,
room_version_rules: &'a RoomVersionRules,
state_sets: StateSets,
auth_chain_sets: &'a [HashSet<OwnedEventId>],
) -> Result<StateMap<OwnedEventId>>
@@ -112,7 +112,7 @@ where
{
let event_fetch = |event_id| self.event_fetch(event_id);
let event_exists = |event_id| self.event_exists(event_id);
state_res::resolve(room_version, state_sets, auth_chain_sets, &event_fetch, &event_exists)
state_res::resolve(room_version_rules, state_sets, auth_chain_sets, &event_fetch, &event_exists)
.map_err(|e| err!(error!("State resolution failed: {e:?}")))
.await
}
@@ -11,7 +11,7 @@ use conduwuit::{
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
};
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
use ruma::{OwnedEventId, RoomId, RoomVersionId};
use ruma::{OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
use crate::rooms::short::ShortStateHash;
@@ -77,7 +77,7 @@ pub(super) async fn state_at_incoming_resolved<Pdu>(
&self,
incoming_pdu: &Pdu,
room_id: &RoomId,
room_version_id: &RoomVersionId,
room_version_rules: &RoomVersionRules,
) -> Result<Option<HashMap<u64, OwnedEventId>>>
where
Pdu: Event + Send + Sync,
@@ -118,7 +118,7 @@ where
.await?;
let Ok(new_state) = self
.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.state_resolution(room_version_rules, fork_states.iter(), &auth_chain_sets)
.boxed()
.await
else {
@@ -10,7 +10,7 @@ use conduwuit::{
use futures::{FutureExt, StreamExt, future::ready};
use ruma::{CanonicalJsonValue, RoomId, ServerName, events::StateEventType};
use super::{get_room_version_id, to_room_version};
use super::get_room_version;
use crate::rooms::{
state_compressor::{CompressedState, HashSetCompressStateEvent},
timeline::RawPduId,
@@ -52,7 +52,8 @@ where
"Upgrading PDU from outlier to timeline"
);
let timer = Instant::now();
let room_version_id = get_room_version_id(create_event)?;
let room_version_id = get_room_version(create_event)?;
let room_version_rules = room_version_id.rules().expect("room version should have defined rules");
// 10. Fetch missing state and auth chain events by calling /state_ids at
// backwards extremities doing all the checks in this list starting at 1.
@@ -65,7 +66,7 @@ where
let mut state_at_incoming_event = if incoming_pdu.prev_events().count() == 1 {
self.state_at_incoming_degree_one(&incoming_pdu).await?
} else {
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id)
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_rules)
.await?
};
@@ -78,8 +79,6 @@ where
let state_at_incoming_event =
state_at_incoming_event.expect("we always set this to some above");
let room_version = to_room_version(&room_version_id);
debug!(
event_id = %incoming_pdu.event_id,
"Performing auth check to upgrade"
@@ -98,7 +97,7 @@ where
"Running initial auth check"
);
let auth_check = state_res::event_auth::auth_check(
&room_version,
&room_version_rules,
&incoming_pdu,
None, // TODO: third party invite
|ty, sk| state_fetch(ty.clone(), sk.into()),
@@ -124,7 +123,7 @@ where
incoming_pdu.sender(),
incoming_pdu.state_key(),
incoming_pdu.content(),
&room_version,
&room_version_rules,
)
.await?;
@@ -138,7 +137,7 @@ where
"Running auth check with claimed state auth"
);
let auth_check = state_res::event_auth::auth_check(
&room_version,
&room_version_rules,
&incoming_pdu,
None, // third-party invite
state_fetch,
@@ -179,7 +178,6 @@ where
.services
.state
.get_forward_extremities(room_id)
.map(ToOwned::to_owned)
.ready_filter(|event_id| {
// Remove any that are referenced by this incoming event's prev_events
!incoming_pdu.prev_events().any(is_equal_to!(event_id))
@@ -232,7 +230,7 @@ where
}
let new_room_state = self
.resolve_state(room_id, &room_version_id, state_after)
.resolve_state(room_id, &room_version_rules, state_after)
.await?;
// Set the new room state to the resolved state
+2 -2
View File
@@ -89,13 +89,13 @@ pub async fn retain_lazy_members(&self, senders: MemberSet, ctx: &Context<'_>) -
let mut senders = MemberSet::with_capacity(senders.len());
while let Some((status, sender)) = witness.next().await {
if include_redundant || status == Status::Unseen {
senders.insert(sender.into());
senders.insert(sender.clone());
continue;
}
if let Status::Seen(seen) = status {
if seen == 0 || ctx.token == Some(seen) {
senders.insert(sender.into());
senders.insert(sender.clone());
continue;
}
}
+3 -3
View File
@@ -3,7 +3,7 @@ use std::sync::Arc;
use conduwuit::{Result, implement, utils::stream::TryIgnore};
use database::Map;
use futures::{Stream, StreamExt};
use ruma::RoomId;
use ruma::{OwnedRoomId, RoomId};
use crate::{Dep, rooms};
@@ -58,7 +58,7 @@ pub async fn exists(&self, room_id: &RoomId) -> bool {
}
#[implement(Service)]
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
pub fn iter_ids(&self) -> impl Stream<Item = OwnedRoomId> + Send + '_ {
self.db.roomid_shortroomid.keys().ignore_err()
}
@@ -83,7 +83,7 @@ pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
}
#[implement(Service)]
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
pub fn list_banned_rooms(&self) -> impl Stream<Item = OwnedRoomId> + Send + '_ {
self.db.bannedroomids.keys().ignore_err()
}
+6 -8
View File
@@ -7,9 +7,7 @@ use conduwuit::{
use database::{Deserialized, Json, Map};
use futures::{Stream, StreamExt};
use ruma::{
CanonicalJsonObject, OwnedUserId, RoomId, UserId,
events::{AnySyncEphemeralRoomEvent, receipt::ReceiptEvent},
serde::Raw,
CanonicalJsonObject, OwnedRoomId, OwnedUserId, RoomId, UserId, events::{AnySyncEphemeralRoomEvent, receipt::ReceiptEvent}, serde::Raw
};
use crate::{Dep, globals};
@@ -66,8 +64,8 @@ impl Data {
room_id: &'a RoomId,
since: u64,
) -> impl Stream<Item = ReceiptItem> + Send + 'a {
type Key<'a> = (&'a RoomId, u64, &'a UserId);
type KeyVal<'a> = (Key<'a>, CanonicalJsonObject);
type Key = (OwnedRoomId, u64, OwnedUserId);
type KeyVal = (Key, CanonicalJsonObject);
let after_since = since.saturating_add(1); // +1 so we don't send the event at since
let first_possible_edu = (room_id, after_since);
@@ -75,13 +73,13 @@ impl Data {
self.readreceiptid_readreceipt
.stream_from(&first_possible_edu)
.ignore_err()
.ready_take_while(move |((r, ..), _): &KeyVal<'_>| *r == room_id)
.map(move |((_, count, user_id), mut json): KeyVal<'_>| {
.ready_take_while(move |((r, ..), _): &KeyVal| *r == room_id)
.map(move |((_, count, user_id), mut json): KeyVal| {
json.remove("room_id");
let event = serde_json::value::to_raw_value(&json)?;
Ok((user_id.to_owned(), count, Raw::from_json(event)))
Ok((user_id, count, Raw::from_json(event)))
})
.ignore_err()
}
+9 -6
View File
@@ -85,18 +85,21 @@ impl Service {
let event_id: OwnedEventId = pdu.event_id().to_owned();
let user_id: OwnedUserId = user_id.to_owned();
let mut receipt = ruma::events::receipt::Receipt::default();
// TODO: start storing the timestamp so we can return one
receipt.ts = None;
receipt.thread = ruma::events::receipt::ReceiptThread::Unthreaded;
let content: BTreeMap<OwnedEventId, Receipts> = BTreeMap::from_iter([(
event_id,
BTreeMap::from_iter([(
ruma::events::receipt::ReceiptType::ReadPrivate,
BTreeMap::from_iter([(user_id, ruma::events::receipt::Receipt {
ts: None, // TODO: start storing the timestamp so we can return one
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
})]),
BTreeMap::from_iter([(user_id, receipt)]),
)]),
)]);
let receipt_event_content = ReceiptEventContent(content);
let receipt_sync_event = SyncEphemeralRoomEvent { content: receipt_event_content };
let receipt_sync_event = SyncEphemeralRoomEvent::new(receipt_event_content);
let event = serde_json::value::to_raw_value(&receipt_sync_event)
.expect("receipt created manually");
@@ -165,7 +168,7 @@ where
conduwuit::trace!(?content);
Raw::from_json(
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent { content })
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent::new(content))
.expect("received valid json"),
)
}
+45 -160
View File
@@ -17,20 +17,16 @@ use conduwuit_core::{
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, pin_mut, stream::FuturesUnordered};
use lru_cache::LruCache;
use ruma::{
OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId,
api::{
OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, api::{
client::space::SpaceHierarchyRoomsChunk,
federation::{
self,
space::{SpaceHierarchyChildSummary, SpaceHierarchyParentSummary},
space::SpaceHierarchyParentSummary,
},
},
events::{
}, events::{
StateEventType,
space::child::{HierarchySpaceChildEvent, SpaceChildEventContent},
},
serde::Raw,
space::SpaceRoomJoinRule,
}, room::{JoinRuleSummary, RoomSummary}, serde::Raw,
};
use tokio::sync::{Mutex, MutexGuard};
@@ -121,13 +117,10 @@ pub async fn get_summary_and_children_local(
| None => (), // cache miss
| Some(None) => return Ok(None),
| Some(Some(cached)) => {
let allowed_rooms = cached.summary.allowed_room_ids.iter().map(AsRef::as_ref);
let is_accessible_child = self.is_accessible_child(
current_room,
&cached.summary.join_rule,
&cached.summary.summary.join_rule,
identifier,
allowed_rooms,
);
let accessibility = if is_accessible_child.await {
@@ -172,10 +165,8 @@ async fn get_summary_and_children_federation(
user_id: &UserId,
via: &[OwnedServerName],
) -> Result<Option<SummaryAccessibility>> {
let request = federation::space::get_hierarchy::v1::Request {
room_id: current_room.to_owned(),
suggested_only,
};
let mut request = federation::space::get_hierarchy::v1::Request::new(current_room.to_owned());
request.suggested_only = suggested_only;
let mut requests: FuturesUnordered<_> = via
.iter()
@@ -213,14 +204,13 @@ async fn get_summary_and_children_federation(
.ready_filter_map(|(child, mut cache)| {
(!cache.contains_key(current_room)).then_some((child, cache))
})
.for_each(|(child, cache)| self.cache_insert(cache, current_room, child))
.for_each(|(summary, cache)| self.cache_insert(cache, current_room, summary))
.await;
let identifier = Identifier::UserId(user_id);
let allowed_room_ids = summary.allowed_room_ids.iter().map(AsRef::as_ref);
let is_accessible_child = self
.is_accessible_child(current_room, &summary.join_rule, &identifier, allowed_room_ids)
.is_accessible_child(current_room, &summary.summary.join_rule, &identifier)
.await;
let accessibility = if is_accessible_child {
@@ -313,7 +303,6 @@ async fn get_room_summary(
room_id,
&join_rule.clone().into(),
identifier,
join_rule.allowed_rooms(),
)
.await;
@@ -381,38 +370,34 @@ async fn get_room_summary(
encryption,
);
let summary = SpaceHierarchyParentSummary {
canonical_alias,
name,
topic,
world_readable,
let mut summary = RoomSummary::new(
room_id.to_owned(),
join_rule.clone().into(),
guest_can_join,
avatar_url,
room_type,
children_state,
encryption,
room_version,
room_id: room_id.to_owned(),
num_joined_members: num_joined_members.try_into().unwrap_or_default(),
allowed_room_ids: join_rule.allowed_rooms().map(Into::into).collect(),
join_rule: join_rule.clone().into(),
};
num_joined_members.try_into().unwrap_or_default(),
world_readable
);
summary.canonical_alias = canonical_alias;
summary.name = name;
summary.topic = topic;
summary.avatar_url = avatar_url;
summary.encryption = encryption;
summary.room_type = room_type;
summary.room_version = room_version;
let summary = SpaceHierarchyParentSummary::new(summary, children_state);
Ok(summary)
}
/// With the given identifier, checks if a room is accessible
#[implement(Service)]
async fn is_accessible_child<'a, I>(
async fn is_accessible_child(
&self,
current_room: &RoomId,
join_rule: &SpaceRoomJoinRule,
join_rule: &JoinRuleSummary,
identifier: &Identifier<'_>,
allowed_rooms: I,
) -> bool
where
I: Iterator<Item = &'a RoomId> + Send,
{
) -> bool {
if let Identifier::ServerName(server_name) = identifier {
// Checks if ACLs allow for the server to participate
if self
@@ -437,23 +422,17 @@ where
}
}
match *join_rule {
| SpaceRoomJoinRule::Public
| SpaceRoomJoinRule::Knock
| SpaceRoomJoinRule::KnockRestricted => true,
| SpaceRoomJoinRule::Restricted =>
allowed_rooms
.stream()
.any(async |room| match identifier {
| Identifier::UserId(user) =>
self.services.state_cache.is_joined(user, room).await,
| Identifier::ServerName(server) =>
self.services.state_cache.server_in_room(server, room).await,
})
.await,
// Invite only, Private, or Custom join rule
| _ => false,
match join_rule {
| JoinRuleSummary::Public
| JoinRuleSummary::Knock
| JoinRuleSummary::KnockRestricted(_) => true,
| JoinRuleSummary::Restricted(restricted_summary) => {
(&restricted_summary.allowed_room_ids).stream().any(async |room| match identifier {
| Identifier::UserId(user) => self.services.state_cache.is_joined(user, room).await,
| Identifier::ServerName(server) => self.services.state_cache.server_in_room(server, room).await,
}).await
},
_ => false
}
}
@@ -481,44 +460,14 @@ async fn cache_insert(
&self,
mut cache: MutexGuard<'_, Cache>,
current_room: &RoomId,
child: SpaceHierarchyChildSummary,
summary: RoomSummary,
) {
let SpaceHierarchyChildSummary {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
allowed_room_ids,
encryption,
room_version,
} = child;
let summary = SpaceHierarchyParentSummary {
canonical_alias,
name,
num_joined_members,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
allowed_room_ids,
room_id: room_id.clone(),
children_state: self
.get_space_child_events(&room_id)
let children_state = self
.get_space_child_events(&summary.room_id)
.map(Event::into_format)
.collect()
.await,
encryption,
room_version,
};
.await;
let summary = SpaceHierarchyParentSummary::new(summary, children_state);
cache.insert(current_room.to_owned(), Some(CachedSpaceHierarchySummary { summary }));
}
@@ -527,39 +476,7 @@ async fn cache_insert(
// ruma-client-api types
impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk {
fn from(value: CachedSpaceHierarchySummary) -> Self {
let SpaceHierarchyParentSummary {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
allowed_room_ids,
encryption,
room_version,
} = value.summary;
Self {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
encryption,
room_version,
allowed_room_ids,
}
Self::new(value.summary.summary, value.summary.children_state)
}
}
@@ -567,37 +484,5 @@ impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk {
/// ruma-client-api types
#[must_use]
pub fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRoomsChunk {
let SpaceHierarchyParentSummary {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
allowed_room_ids,
encryption,
room_version,
} = summary;
SpaceHierarchyRoomsChunk {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
encryption,
room_version,
allowed_room_ids,
}
SpaceHierarchyRoomsChunk::new(summary.summary, summary.children_state)
}
+6 -15
View File
@@ -1,23 +1,16 @@
use std::str::FromStr;
use ruma::{
UInt,
api::federation::space::{SpaceHierarchyParentSummary, SpaceHierarchyParentSummaryInit},
owned_room_id, owned_server_name,
space::SpaceRoomJoinRule,
UInt, api::federation::space::SpaceHierarchyParentSummary, owned_room_id, owned_server_name, room::{JoinRuleSummary, RoomSummary},
};
use crate::rooms::spaces::{PaginationToken, get_parent_children_via};
#[test]
fn get_summary_children() {
let summary: SpaceHierarchyParentSummary = SpaceHierarchyParentSummaryInit {
num_joined_members: UInt::from(1_u32),
room_id: owned_room_id!("!root:example.org"),
world_readable: true,
guest_can_join: true,
join_rule: SpaceRoomJoinRule::Public,
children_state: vec![
let summary = SpaceHierarchyParentSummary::new(
RoomSummary::new(owned_room_id!("!root:example.org"), JoinRuleSummary::Public, true, UInt::from(1_u32), true),
vec![
serde_json::from_str(
r#"{
"content": {
@@ -62,10 +55,8 @@ fn get_summary_children() {
}"#,
)
.unwrap(),
],
allowed_room_ids: vec![],
}
.into();
]
);
assert_eq!(
get_parent_children_via(&summary, false)
+9 -11
View File
@@ -1,7 +1,7 @@
use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
use async_trait::async_trait;
use conduwuit::{RoomVersion, debug};
use conduwuit::debug;
use conduwuit_core::{
Event, PduEvent, Result, err,
result::FlatOk,
@@ -17,12 +17,10 @@ use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all, pin_mut,
};
use ruma::{
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
events::{
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, events::{
AnyStrippedStateEvent, StateEventType, TimelineEventType,
room::create::RoomCreateEventContent,
},
serde::Raw,
}, room_version_rules::RoomVersionRules, serde::Raw
};
use crate::{
@@ -128,7 +126,7 @@ impl Service {
self.services
.state_cache
.update_membership(room_id, user_id, &pdu, false)
.update_membership(room_id, &user_id, &pdu, false)
.await?;
},
| TimelineEventType::SpaceChild => {
@@ -381,13 +379,13 @@ impl Service {
pub fn get_forward_extremities<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a EventId> + Send + 'a {
) -> impl Stream<Item = OwnedEventId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomid_pduleaves
.keys_prefix(&prefix)
.map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
.map_ok(|(_, event_id): (Ignore, OwnedEventId)| event_id)
.ignore_err()
}
@@ -414,7 +412,7 @@ impl Service {
}
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self, content, room_version), level = "trace")]
#[tracing::instrument(skip(self, content, room_version_rules), level = "trace")]
pub async fn get_auth_events(
&self,
room_id: &RoomId,
@@ -422,14 +420,14 @@ impl Service {
sender: &UserId,
state_key: Option<&str>,
content: &serde_json::value::RawValue,
room_version: &RoomVersion,
room_version_rules: &RoomVersionRules,
) -> Result<StateMap<PduEvent>> {
let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
return Ok(HashMap::new());
};
let auth_types =
state_res::auth_types_for_event(kind, sender, state_key, content, room_version)?;
state_res::auth_types_for_event(kind, sender, state_key, content, room_version_rules)?;
debug!(?auth_types, "Auth types for event");
let sauthevents: HashMap<_, _> = auth_types
.iter()
+39 -16
View File
@@ -3,29 +3,18 @@ mod server_can;
mod state;
mod user_can;
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use async_trait::async_trait;
use conduwuit::{Result, err};
use conduwuit::{Event, Result, err};
use database::Map;
use ruma::{
EventEncryptionAlgorithm, JsOption, OwnedRoomAliasId, RoomId, UserId,
events::{
EventEncryptionAlgorithm, JsOption, OwnedRoomAliasId, OwnedUserId, RoomId, UserId, events::{
StateEventType,
room::{
avatar::RoomAvatarEventContent,
canonical_alias::RoomCanonicalAliasEventContent,
create::RoomCreateEventContent,
encryption::RoomEncryptionEventContent,
guest_access::{GuestAccess, RoomGuestAccessEventContent},
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::RoomMemberEventContent,
name::RoomNameEventContent,
topic::RoomTopicEventContent,
avatar::RoomAvatarEventContent, canonical_alias::RoomCanonicalAliasEventContent, create::{RoomCreateEvent, RoomCreateEventContent}, encryption::RoomEncryptionEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent}, member::RoomMemberEventContent, name::RoomNameEventContent, pinned_events::RoomPinnedEventsEventContent, power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, topic::RoomTopicEventContent
},
},
room::RoomType,
}, room::RoomType
};
use crate::{Dep, rooms};
@@ -162,4 +151,38 @@ impl Service {
.await
.is_ok()
}
/// Get a set of the room's creators. This will always contain a single user for room versions 11 and earlier.
pub async fn get_room_creators(&self, room_id: &RoomId) -> HashSet<OwnedUserId> {
let room_version_rules = self.services.state.get_room_version(room_id).await.expect("room should have a version").rules().expect("room version should be known");
let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "").await.expect("room should have a create event");
let create_content: RoomCreateEventContent = create_event.get_content().expect("create event content should be valid");
let mut creators = HashSet::new();
if room_version_rules.authorization.use_room_create_sender {
creators.insert(create_event.sender);
} else {
#[allow(deprecated)]
creators.insert(create_content.creator.unwrap());
}
if room_version_rules.authorization.additional_room_creators {
creators.extend(create_content.additional_creators);
}
creators
}
/// Get the room's power levels. This will never fail -- if the room has no power level state event,
/// the default power levels for the room's version will be returned.
pub async fn get_room_power_levels(&self, room_id: &RoomId) -> RoomPowerLevels {
let room_version_rules = self.services.state.get_room_version(room_id).await.expect("room should have a version").rules().expect("room version should be known");
let creators = self.get_room_creators(room_id).await;
let power_levels_event: RoomPowerLevelsEventContent = self.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.unwrap_or_else(|_| RoomPowerLevelsEventContent::new(&room_version_rules.authorization));
RoomPowerLevels::new(power_levels_event.into(), &room_version_rules.authorization, creators)
}
}
@@ -44,13 +44,17 @@ pub async fn server_can_see_event(
| HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
current_server_members
.any(|member| self.user_was_invited(shortstatehash, member))
.any(async |member| {
self.user_was_invited(shortstatehash, &member).await
})
.await
},
| HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
current_server_members
.any(|member| self.user_was_joined(shortstatehash, member))
.any(async |member| {
self.user_was_joined(shortstatehash, &member).await
})
.await
},
| HistoryVisibility::WorldReadable | HistoryVisibility::Shared | _ => true,
+1 -1
View File
@@ -316,7 +316,7 @@ pub fn state_full(
shortstatehash: ShortStateHash,
) -> impl Stream<Item = ((StateEventType, StateKey), impl Event)> + Send + '_ {
self.state_full_pdus(shortstatehash)
.ready_filter_map(|pdu| Some(((pdu.kind().clone().into(), pdu.state_key()?.into()), pdu)))
.ready_filter_map(|pdu| Some(((pdu.kind().to_string().into(), pdu.state_key()?.into()), pdu)))
}
#[implement(super::Service)]
+32 -41
View File
@@ -1,4 +1,4 @@
use conduwuit::{Err, Result, RoomVersion, implement, matrix::Event, pdu::PduBuilder};
use conduwuit::{Err, Result, implement, matrix::Event, pdu::PduBuilder};
use ruma::{
EventId, RoomId, UserId,
events::{
@@ -7,7 +7,6 @@ use ruma::{
create::RoomCreateEventContent,
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
member::{MembershipState, RoomMemberEventContent},
power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent},
},
},
};
@@ -45,53 +44,45 @@ pub async fn user_can_redact(
)));
}
let room_create = self
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await?;
let create_content: RoomCreateEventContent =
serde_json::from_str(room_create.content().get())?;
let room_features = RoomVersion::new(&create_content.room_version)?;
if room_features.explicitly_privilege_room_creators {
let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "").await?;
let create_event_content: RoomCreateEventContent = create_event.get_content().unwrap();
let room_version_rules = create_event_content.room_version.rules().expect("room version should have defined rules");
if room_version_rules.authorization.explicitly_privilege_room_creators {
let sender_owned = sender.to_owned();
if sender == room_create.sender()
|| create_content
// NOTE: we don't check the pre-v11 `creator` field because no room version has
// `explicitly_privilege_room_creators` and `use_room_create_sender` set at the same time
if sender == create_event.sender()
|| create_event_content
.additional_creators
.is_some_and(|cs| cs.contains(&sender_owned))
.contains(&sender_owned)
{
return Ok(true);
}
}
match self
.room_state_get_content::<RoomPowerLevelsEventContent>(
room_id,
&StateEventType::RoomPowerLevels,
"",
)
.await
{
| Ok(pl_event_content) => {
let pl_event: RoomPowerLevels = pl_event_content.into();
Ok(pl_event.user_can_redact_event_of_other(sender)
|| pl_event.user_can_redact_own_event(sender)
&& match redacting_event {
| Ok(redacting_event) =>
if federation {
redacting_event.sender().server_name() == sender.server_name()
} else {
redacting_event.sender() == sender
},
| _ => false,
})
},
| _ => {
// Falling back on m.room.create to judge power level
Ok(room_create.sender() == sender
|| redacting_event
.as_ref()
.is_ok_and(|redacting_event| redacting_event.sender() == sender))
},
let power_levels = self.get_room_power_levels(room_id).await;
if power_levels.user_can_redact_event_of_other(sender) {
return Ok(true);
}
if power_levels.user_can_redact_own_event(sender) {
let is_own_event = match redacting_event {
Ok(redacting_event) => {
if federation {
redacting_event.sender().server_name() == sender.server_name()
} else {
redacting_event.sender() == sender
}
},
_ => false
};
return Ok(is_own_event);
}
return Ok(false);
}
/// Whether a user is allowed to see an event, based on
+39 -35
View File
@@ -12,9 +12,7 @@ use conduwuit::{
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{Stream, StreamExt, future::join5, pin_mut};
use ruma::{
OwnedRoomId, OwnedUserId, RoomId, ServerName, UserId,
events::{AnyStrippedStateEvent, room::member::MembershipState},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, events::{AnyStrippedStateEvent, room::member::MembershipState}, serde::Raw
};
use crate::{Dep, account_data, appservice::RegistrationInfo, config, globals, rooms, users};
@@ -147,13 +145,13 @@ pub fn clear_appservice_in_room_cache(&self) { self.appservice_in_room_cache.wri
pub fn room_servers<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a ServerName> + Send + 'a {
) -> impl Stream<Item = OwnedServerName> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomserverids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, server): (Ignore, &ServerName)| server)
.map(|(_, server): (Ignore, OwnedServerName)| server)
}
#[implement(Service)]
@@ -170,13 +168,13 @@ pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a R
pub fn server_rooms<'a>(
&'a self,
server: &'a ServerName,
) -> impl Stream<Item = &'a RoomId> + Send + 'a {
) -> impl Stream<Item = OwnedRoomId> + Send + 'a {
let prefix = (server, Interfix);
self.db
.serverroomids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, room_id): (Ignore, &RoomId)| room_id)
.map(|(_, room_id): (Ignore, OwnedRoomId)| room_id)
}
/// Returns true if server can see user by sharing at least one room.
@@ -184,7 +182,7 @@ pub fn server_rooms<'a>(
#[tracing::instrument(skip(self), level = "trace")]
pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool {
self.server_rooms(server)
.any(|room_id| self.is_joined(user_id, room_id))
.any(async |room_id| self.is_joined(user_id, &room_id).await)
.await
}
@@ -205,7 +203,7 @@ pub fn get_shared_rooms<'a>(
&'a self,
user_a: &'a UserId,
user_b: &'a UserId,
) -> impl Stream<Item = &'a RoomId> + Send + 'a {
) -> impl Stream<Item = OwnedRoomId> + Send + 'a {
use conduwuit::utils::set;
let a = self.rooms_joined(user_a);
@@ -219,13 +217,13 @@ pub fn get_shared_rooms<'a>(
pub fn room_members<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_joined
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
.map(|(_, user_id): (Ignore, OwnedUserId)| user_id)
}
/// Returns the number of users which are currently in a room
@@ -242,7 +240,7 @@ pub async fn room_joined_count(&self, room_id: &RoomId) -> Result<u64> {
pub fn local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
self.room_members(room_id)
.ready_filter(|user| self.services.globals.user_is_local(user))
}
@@ -254,9 +252,15 @@ pub fn local_users_in_room<'a>(
pub fn active_local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
self.local_users_in_room(room_id)
.filter(|user| self.services.users.is_active(user))
.filter_map(async |user_id| {
if self.services.users.is_active(&user_id).await {
Some(user_id)
} else {
None
}
})
}
/// Returns the number of users which are currently invited to a room
@@ -276,13 +280,13 @@ pub async fn room_invited_count(&self, room_id: &RoomId) -> Result<u64> {
pub fn room_useroncejoined<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuseroncejoinedids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
.map(|(_, user_id): (Ignore, OwnedUserId)| user_id)
}
/// Returns an iterator over all invited members of a room.
@@ -291,13 +295,13 @@ pub fn room_useroncejoined<'a>(
pub fn room_members_invited<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_invitecount
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
.map(|(_, user_id): (Ignore, OwnedUserId)| user_id)
}
/// Returns an iterator over all knocked members of a room.
@@ -306,13 +310,13 @@ pub fn room_members_invited<'a>(
pub fn room_members_knocked<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_knockedcount
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
.map(|(_, user_id): (Ignore, OwnedUserId)| user_id)
}
#[implement(Service)]
@@ -350,12 +354,12 @@ pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result
pub fn rooms_joined<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = &'a RoomId> + Send + 'a {
) -> impl Stream<Item = OwnedRoomId> + Send + 'a {
self.db
.userroomid_joined
.keys_raw_prefix(user_id)
.ignore_err()
.map(|(_, room_id): (Ignore, &RoomId)| room_id)
.map(|(_, room_id): (Ignore, OwnedRoomId)| room_id)
}
/// Returns an iterator over all rooms a user was invited to.
@@ -365,16 +369,16 @@ pub fn rooms_invited<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
type KeyVal = (Key, Raw<Vec<AnyStrippedStateEvent>>);
type Key = (OwnedUserId, OwnedRoomId);
let prefix = (user_id, Interfix);
self.db
.userroomid_invitestate
.stream_prefix(&prefix)
.ignore_err()
.map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state))
.map(|(room_id, state)| Ok((room_id, state.deserialize_as()?)))
.map(|((_, room_id), state): KeyVal| (room_id, state))
.map(|(room_id, state)| Ok((room_id, state.deserialize_as_unchecked()?)))
.ignore_err()
}
@@ -385,16 +389,16 @@ pub fn rooms_knocked<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
type KeyVal = (Key, Raw<Vec<AnyStrippedStateEvent>>);
type Key = (OwnedUserId, OwnedRoomId);
let prefix = (user_id, Interfix);
self.db
.userroomid_knockedstate
.stream_prefix(&prefix)
.ignore_err()
.map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state))
.map(|(room_id, state)| Ok((room_id, state.deserialize_as()?)))
.map(|((_, room_id), state): KeyVal| (room_id, state))
.map(|(room_id, state)| Ok((room_id, state.deserialize_as_unchecked()?)))
.ignore_err()
}
@@ -411,7 +415,7 @@ pub async fn invite_state(
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as_unchecked().map_err(Into::into))
}
#[implement(Service)]
@@ -427,7 +431,7 @@ pub async fn knock_state(
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as_unchecked().map_err(Into::into))
}
#[implement(Service)]
@@ -444,15 +448,15 @@ pub fn rooms_left<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = (OwnedRoomId, Option<Pdu>)> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Option<Pdu>>);
type Key<'a> = (&'a UserId, &'a RoomId);
type KeyVal = (Key, Raw<Option<Pdu>>);
type Key = (OwnedUserId, OwnedRoomId);
let prefix = (user_id, Interfix);
self.db
.userroomid_leftstate
.stream_prefix(&prefix)
.ignore_err()
.map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state))
.map(|((_, room_id), state): KeyVal| (room_id, state))
.map(|(room_id, state)| Ok((room_id, state.deserialize()?)))
.ignore_err()
}
+3 -3
View File
@@ -9,7 +9,6 @@ use ruma::{
AnyStrippedStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType,
StateEventType,
direct::DirectEvent,
invite_permission_config::FilterLevel,
room::{
create::RoomCreateEventContent,
member::{MembershipState, RoomMemberEventContent},
@@ -17,6 +16,7 @@ use ruma::{
},
serde::Raw,
};
use ruminuwuity::invite_permission_config::FilterLevel;
/// Update current membership data.
#[implement(super::Service)]
@@ -174,12 +174,12 @@ pub async fn update_joined_count(&self, room_id: &RoomId) {
self.room_servers(room_id)
.ready_for_each(|old_joined_server| {
if joined_servers.remove(old_joined_server) {
if joined_servers.remove(&old_joined_server) {
return;
}
// Server not in room anymore
let roomserver_id = (room_id, old_joined_server);
let roomserver_id = (room_id, old_joined_server.clone());
let serverroom_id = (old_joined_server, room_id);
self.db.roomserverids.del(roomserver_id);
+3 -4
View File
@@ -17,7 +17,6 @@ use ruma::{
pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec<OwnedServerName>) {
let mut servers: Vec<_> = self
.servers_invite_via(room_id)
.map(ToOwned::to_owned)
.chain(iter(servers.into_iter()))
.collect()
.await;
@@ -81,12 +80,12 @@ pub async fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServe
pub fn servers_invite_via<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &'a ServerName> + Send + 'a {
type KeyVal<'a> = (Ignore, Vec<&'a ServerName>);
) -> impl Stream<Item = OwnedServerName> + Send + 'a {
type KeyVal = (Ignore, Vec<OwnedServerName>);
self.db
.roomid_inviteviaservers
.stream_raw_prefix(room_id)
.ignore_err()
.map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server"))
.map(|(_, mut servers): KeyVal| servers.pop().expect("at least one server"))
}
+3 -8
View File
@@ -11,8 +11,7 @@ use conduwuit_core::{
use conduwuit_database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{
CanonicalJsonValue, EventId, OwnedUserId, RoomId, UserId,
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint,
CanonicalJsonValue, EventId, OwnedUserId, RoomId, UserId, api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, serde::Raw, uint
};
use serde_json::json;
@@ -89,7 +88,7 @@ impl Service {
}) {
// Thread already existed
relations.count = relations.count.saturating_add(uint!(1));
relations.latest_event = event.to_format();
relations.latest_event = Raw::from_json(event.content().to_owned());
let content = serde_json::to_value(relations).expect("to_value always works");
@@ -101,11 +100,7 @@ impl Service {
);
} else {
// New thread
let relations = BundledThread {
latest_event: event.to_format(),
count: uint!(1),
current_user_participated: true,
};
let relations = BundledThread::new(Raw::from_json(event.content().to_owned()), uint!(1), true);
let content = serde_json::to_value(relations).expect("to_value always works");
+6 -13
View File
@@ -16,10 +16,10 @@ use futures::StreamExt;
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, EventId, RoomVersionId, UserId,
events::{
GlobalAccountDataEventType, StateEventType, TimelineEventType,
GlobalAccountDataEventType, TimelineEventType,
push_rules::PushRulesEvent,
room::{
encrypted::Relation, power_levels::RoomPowerLevelsEventContent,
encrypted::Relation,
redaction::RoomRedactionEventContent,
},
},
@@ -204,18 +204,11 @@ where
drop(insert_lock);
// See if the event matches any known pushers via power level
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.unwrap_or_default();
let power_levels = self.services.state_accessor.get_room_power_levels(room_id).await;
let mut push_target: HashSet<_> = self
.services
.state_cache
.active_local_users_in_room(room_id)
.map(ToOwned::to_owned)
// Don't notify the sender of their own events, and dont send from ignored users
.ready_filter(|user| *user != pdu.sender())
.filter_map(|recipient_user| async move { (!self.services.users.user_is_ignored(pdu.sender(), &recipient_user).await).then_some(recipient_user) })
@@ -229,7 +222,7 @@ where
if let Some(state_key) = pdu.state_key() {
let target_user_id = UserId::parse(state_key)?;
if self.services.users.is_active_local(target_user_id).await {
if self.services.users.is_active_local(&target_user_id).await {
push_target.insert(target_user_id.to_owned());
}
}
@@ -253,7 +246,7 @@ where
for action in self
.services
.pusher
.get_actions(user, &rules_for_user, &power_levels, &serialized, room_id)
.get_actions(user, &rules_for_user, power_levels.clone(), &serialized, room_id)
.await
{
match action {
@@ -346,7 +339,7 @@ where
// knock event for auth
self.services
.state_cache
.update_membership(room_id, target_user_id, pdu, true)
.update_membership(room_id, &target_user_id, pdu, true)
.await?;
}
},
+60 -155
View File
@@ -1,6 +1,6 @@
use std::iter::once;
use std::{collections::HashSet, iter::once};
use conduwuit::{Err, PduEvent, RoomVersion};
use conduwuit::{Err, PduEvent};
use conduwuit_core::{
Result, debug, debug_warn, err, implement, info,
matrix::{
@@ -10,15 +10,12 @@ use conduwuit_core::{
utils::{IterStream, ReadyExt},
validated, warn,
};
use futures::{FutureExt, StreamExt};
use futures::{FutureExt, Stream, StreamExt};
use ruma::{
CanonicalJsonObject, EventId, Int, RoomId, ServerName,
api::federation,
events::{
CanonicalJsonObject, EventId, Int, OwnedServerName, RoomId, ServerName, api::federation, events::{
StateEventType, TimelineEventType,
room::{create::RoomCreateEventContent, power_levels::RoomPowerLevelsEventContent},
},
uint,
room::{create::RoomCreateEventContent, power_levels::{RoomPowerLevelsEventContent, UserPowerLevel}},
}, uint
};
use serde_json::value::RawValue as RawJsonValue;
@@ -55,94 +52,12 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re
return Ok(());
}
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.unwrap_or_default();
let create_event_content: RoomCreateEventContent = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await?;
let create_event = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await?;
let room_version =
RoomVersion::new(&create_event_content.room_version).expect("supported room version");
let mut users = power_levels.users.clone();
if room_version.explicitly_privilege_room_creators {
users.insert(create_event.sender().to_owned(), Int::MAX);
if let Some(additional_creators) = &create_event_content.additional_creators {
for user_id in additional_creators {
users.insert(user_id.to_owned(), Int::MAX);
}
}
}
let room_mods = users.iter().filter_map(|(user_id, level)| {
let remote_powered =
level > &power_levels.users_default && !self.services.globals.user_is_local(user_id);
let creator = if room_version.explicitly_privilege_room_creators {
create_event.sender() == user_id
|| create_event_content
.additional_creators
.as_ref()
.is_some_and(|c| c.contains(user_id))
} else {
false
};
if remote_powered || creator {
debug!(%remote_powered, %creator, "User {user_id} can backfill in room {room_id}");
Some(user_id.server_name())
} else {
debug!(%remote_powered, %creator, "User {user_id} cannot backfill in room {room_id}");
None
}
});
let canonical_room_alias_server = once(
self.services
.state_accessor
.get_canonical_alias(room_id)
.await,
)
.filter_map(Result::ok)
.map(|alias| alias.server_name().to_owned())
.stream();
let mut servers = room_mods
.stream()
.map(ToOwned::to_owned)
.chain(canonical_room_alias_server)
.chain(
self.services
.server
.config
.trusted_servers
.iter()
.map(ToOwned::to_owned)
.stream(),
)
.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name))
.filter_map(|server_name| async move {
self.services
.state_cache
.server_in_room(&server_name, room_id)
.await
.then_some(server_name)
})
.boxed();
let mut servers = self.candidate_backfill_servers(room_id).await;
let mut federated_room = false;
while let Some(ref backfill_server) = servers.next().await {
if !self.services.globals.server_is_ours(backfill_server) {
for backfill_server in servers {
if !self.services.globals.server_is_ours(&backfill_server) {
federated_room = true;
}
info!("Asking {backfill_server} for backfill in {room_id}");
@@ -150,18 +65,14 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re
.services
.sending
.send_federation_request(
backfill_server,
federation::backfill::get_backfill::v1::Request {
room_id: room_id.to_owned(),
v: vec![first_pdu.1.event_id().to_owned()],
limit: uint!(100),
},
&backfill_server,
federation::backfill::get_backfill::v1::Request::new(room_id.to_owned(), vec![first_pdu.1.event_id().to_owned()], uint!(100))
)
.await;
match response {
| Ok(response) => {
for pdu in response.pdus {
if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await {
if let Err(e) = self.backfill_pdu(&backfill_server, pdu).boxed().await {
debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}");
}
}
@@ -207,62 +118,14 @@ pub async fn get_remote_pdu(&self, room_id: &RoomId, event_id: &EventId) -> Resu
return Err!(Request(NotFound("No one can backfill this PDU, room is empty.")));
}
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.unwrap_or_default();
let mut servers = self.candidate_backfill_servers(room_id).await;
let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| {
if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) {
Some(user_id.server_name())
} else {
None
}
});
let canonical_room_alias_server = once(
self.services
.state_accessor
.get_canonical_alias(room_id)
.await,
)
.filter_map(Result::ok)
.map(|alias| alias.server_name().to_owned())
.stream();
let mut servers = room_mods
.stream()
.map(ToOwned::to_owned)
.chain(canonical_room_alias_server)
.chain(
self.services
.server
.config
.trusted_servers
.iter()
.map(ToOwned::to_owned)
.stream(),
)
.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name))
.filter_map(|server_name| async move {
self.services
.state_cache
.server_in_room(&server_name, room_id)
.await
.then_some(server_name)
})
.boxed();
while let Some(ref backfill_server) = servers.next().await {
for backfill_server in servers {
info!("Asking {backfill_server} for event {}", event_id);
let value = self
.services
.sending
.send_federation_request(backfill_server, federation::event::get_event::v1::Request {
event_id: event_id.to_owned(),
include_unredacted_content: Some(false),
})
.send_federation_request(&backfill_server, federation::event::get_event::v1::Request::new(event_id.to_owned()))
.await
.and_then(|response| {
serde_json::from_str::<CanonicalJsonObject>(response.pdu.get()).map_err(|e| {
@@ -275,7 +138,7 @@ pub async fn get_remote_pdu(&self, room_id: &RoomId, event_id: &EventId) -> Resu
| Ok(value) => {
self.services
.event_handler
.handle_incoming_pdu(backfill_server, room_id, event_id, value, false)
.handle_incoming_pdu(&backfill_server, room_id, event_id, value, false)
.boxed()
.await?;
debug!("Successfully backfilled {event_id} from {backfill_server}");
@@ -305,7 +168,7 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) ->
.services
.event_handler
.mutex_federation
.lock(&room_id)
.lock(room_id.as_str())
.await;
// Skip the PDU if we already have it as a timeline event
@@ -326,7 +189,7 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) ->
let shortroomid = self.services.short.get_shortroomid(&room_id).await?;
let insert_lock = self.mutex_insert.lock(&room_id).await;
let insert_lock = self.mutex_insert.lock(room_id.as_str()).await;
let count: i64 = self.services.globals.next_count().unwrap().try_into()?;
@@ -352,3 +215,45 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) ->
debug!("Prepended backfill pdu");
Ok(())
}
#[implement(super::Service)]
async fn candidate_backfill_servers(&self, room_id: &RoomId) -> HashSet<OwnedServerName> {
let mut candidate_backfill_servers = HashSet::new();
let power_levels = self.services.state_accessor.get_room_power_levels(room_id).await;
// Insert servers of room creators
if let Some(creators) = &power_levels.rules.privileged_creators {
for creator in creators {
candidate_backfill_servers.insert(creator.server_name().to_owned());
}
}
// Insert servers of remote users with higher-than-default PL
for (user_id, level) in &power_levels.users {
if !self.services.globals.user_is_local(user_id) && *level > power_levels.users_default {
candidate_backfill_servers.insert(user_id.server_name().to_owned());
}
}
// Insert the canonical room alias server
if let Ok(canonical_alias) = self.services.state_accessor.get_canonical_alias(room_id).await {
candidate_backfill_servers.insert(canonical_alias.server_name().to_owned());
}
// Insert all trusted servers in the config
candidate_backfill_servers.extend(self.services.server.config.trusted_servers.iter().cloned());
// Remove our own name, we can't request backfill from ourselves
candidate_backfill_servers.remove(self.services.globals.server_name());
// Remove all servers that aren't in the room
for server in candidate_backfill_servers.clone() {
if !self.services.state_cache.server_in_room(&server, room_id).await {
candidate_backfill_servers.remove(&server);
}
}
debug!(?candidate_backfill_servers, "Found candidate servers for backfill");
candidate_backfill_servers
}
+1 -2
View File
@@ -157,7 +157,6 @@ pub async fn build_and_append_pdu(
.services
.state_cache
.room_servers(&room_id)
.map(ToOwned::to_owned)
.collect()
.await;
@@ -180,7 +179,7 @@ pub async fn build_and_append_pdu(
trace!("Sending PDU {} to {} servers", pdu.event_id(), servers.len());
self.services
.sending
.send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id)
.send_pdu_servers(servers.stream(), &pdu_id)
.await?;
trace!("Event {} in room {:?} has been appended", pdu.event_id(), room_id);
+9 -7
View File
@@ -6,7 +6,7 @@ use conduwuit_core::{
matrix::{
event::{Event, gen_event_id},
pdu::{EventHash, PduBuilder, PduEvent},
state_res::{self, RoomVersion},
state_res,
},
utils::{self, IterStream, ReadyExt, stream::TryIgnore},
warn,
@@ -90,7 +90,7 @@ pub async fn create_event(
redacts,
timestamp,
} = pdu_builder;
// If there was no create event yet, assume we are creating a room
trace!(
"Creating event of type {} in room {}",
event_type,
@@ -121,7 +121,9 @@ pub async fn create_event(
},
};
let room_version = RoomVersion::new(&room_version_id).expect("room version is supported");
let Some(room_version_rules) = room_version.rules() else {
return Err!(Request(UnsupportedRoomVersion("Unsupported room version")));
};
let prev_events: Vec<OwnedEventId> = match room_id {
| Some(room_id) =>
@@ -145,7 +147,7 @@ pub async fn create_event(
sender,
state_key.as_deref(),
&content,
&room_version,
&room_version_rules,
)
.await?,
| None => HashMap::new(),
@@ -242,7 +244,7 @@ pub async fn create_event(
};
let auth_check = state_res::auth_check(
&room_version,
&room_version_rules,
&pdu,
None, // TODO: third_party_invite
auth_fetch,
@@ -287,7 +289,7 @@ pub async fn create_hash_and_sign_event(
if let Err(e) = self
.services
.server_keys
.hash_and_sign_event(&mut pdu_json, &room_version_id)
.hash_and_sign_event(&mut pdu_json, &room_version)
{
return match e {
| Error::Signatures(ruma::signatures::Error::PduSize) => {
@@ -297,7 +299,7 @@ pub async fn create_hash_and_sign_event(
};
}
// Generate event id
pdu.event_id = gen_event_id(&pdu_json, &room_version_id)?;
pdu.event_id = gen_event_id(&pdu_json, &room_version)?;
pdu_json.insert("event_id".into(), CanonicalJsonValue::String(pdu.event_id.clone().into()));
// Verify that the *full* PDU isn't over 64KiB.
// Ruma only validates that it's under 64KiB before signing and hashing.
+4 -7
View File
@@ -8,7 +8,7 @@ use futures::StreamExt;
use ruma::{
OwnedRoomId, OwnedUserId, RoomId, UserId,
api::federation::transactions::edu::{Edu, TypingContent},
events::SyncEphemeralRoomEvent,
events::{SyncEphemeralRoomEvent, typing::TypingEventContent},
};
use tokio::sync::{RwLock, broadcast};
@@ -212,12 +212,9 @@ impl Service {
&self,
room_id: &RoomId,
sender_user: &UserId,
) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> {
Ok(SyncEphemeralRoomEvent {
content: ruma::events::typing::TypingEventContent {
user_ids: self.typing_users_for_user(room_id, sender_user).await?,
},
})
) -> Result<SyncEphemeralRoomEvent<TypingEventContent>> {
let user_ids = self.typing_users_for_user(room_id, sender_user).await?;
Ok(SyncEphemeralRoomEvent::new(TypingEventContent::new(user_ids)))
}
async fn federation_send(
+10 -4
View File
@@ -1,9 +1,11 @@
use std::{fmt::Debug, mem};
use std::{borrow::Cow, fmt::Debug, mem};
use bytes::BytesMut;
use conduwuit::{Err, Result, debug_error, err, utils, utils::response::LimitReadExt, warn};
use reqwest::Client;
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, auth_scheme::{AppserviceToken, SendAccessToken}, path_builder::VersionHistory};
use crate::SUPPORTED_VERSIONS;
/// Sends a request to an antispam service
pub(crate) async fn send_antispam_request<T>(
@@ -13,11 +15,15 @@ pub(crate) async fn send_antispam_request<T>(
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest::<Authentication = AppserviceToken, PathBuilder = VersionHistory> + Debug + Send,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_15];
let http_request = request
.try_into_http_request::<BytesMut>(base_url, SendAccessToken::Always(secret), &VERSIONS)?
.try_into_http_request::<BytesMut>(
base_url,
SendAccessToken::Always(secret),
Cow::Borrowed(&SUPPORTED_VERSIONS),
)?
.map(BytesMut::freeze);
let reqwest_request = reqwest::Request::try_from(http_request)?;
+3 -5
View File
@@ -5,7 +5,7 @@ use conduwuit::{
Err, Result, debug_error, err, implement, trace, utils, utils::response::LimitReadExt, warn,
};
use ruma::api::{
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, appservice::Registration,
IncomingResponse, MatrixVersion, OutgoingRequest, appservice::Registration, auth_scheme::{AccessToken, SendAccessToken}, path_builder::SinglePath,
};
/// Sends a request to an appservice
@@ -19,10 +19,8 @@ pub async fn send_appservice_request<T>(
request: T,
) -> Result<Option<T::IncomingResponse>>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest<Authentication = AccessToken, PathBuilder = SinglePath> + Debug + Send,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_7];
let Some(dest) = registration.url else {
return Ok(None);
};
@@ -38,7 +36,7 @@ where
.try_into_http_request::<BytesMut>(
&dest,
SendAccessToken::Appservice(hs_token),
&VERSIONS,
(),
)
.map_err(|e| {
err!(BadServerResponse(
+1 -1
View File
@@ -245,7 +245,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
})?;
(
Destination::Federation(OwnedServerName::parse(&server).map_err(|_| {
Destination::Federation(ServerName::parse(&server).map_err(|_| {
Error::bad_database("Invalid server string in server_currenttransaction")
})?),
if value.is_empty() {
+30 -15
View File
@@ -19,7 +19,7 @@ use conduwuit::{
warn,
};
use futures::{FutureExt, Stream, StreamExt};
use ruma::{RoomId, ServerName, UserId, api::OutgoingRequest};
use ruma::{OwnedServerName, RoomId, ServerName, UserId, api::{OutgoingRequest, auth_scheme::NoAuthentication, federation::authentication::ServerSignatures, path_builder::PathBuilder}};
use tokio::{task, task::JoinSet};
use self::data::Data;
@@ -28,8 +28,8 @@ pub use self::{
sender::{EDU_LIMIT, PDU_LIMIT},
};
use crate::{
Dep, account_data, client, federation, globals, presence, pusher, rooms,
rooms::timeline::RawPduId, users,
Dep, account_data, client, federation::{self, FederationPathBuilderInput}, globals, presence, pusher, rooms::{self, timeline::RawPduId},
users,
};
pub struct Service {
@@ -187,9 +187,9 @@ impl Service {
}
#[tracing::instrument(skip(self, servers, pdu_id), level = "debug")]
pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result
pub async fn send_pdu_servers<S>(&self, servers: S, pdu_id: &RawPduId) -> Result
where
S: Stream<Item = &'a ServerName> + Send + 'a,
S: Stream<Item = OwnedServerName> + Send,
{
let requests = servers
.map(|server| {
@@ -233,14 +233,14 @@ impl Service {
}
#[tracing::instrument(skip(self, servers, serialized), level = "debug")]
pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: EduBuf) -> Result
pub async fn send_edu_servers<S>(&self, servers: S, serialized: EduBuf) -> Result
where
S: Stream<Item = &'a ServerName> + Send + 'a,
S: Stream<Item = OwnedServerName> + Send,
{
let requests = servers
.map(|server| {
(
Destination::Federation(server.to_owned()),
Destination::Federation(server),
SendingEvent::Edu(serialized.clone()),
)
})
@@ -269,12 +269,11 @@ impl Service {
}
#[tracing::instrument(skip(self, servers), level = "debug")]
pub async fn flush_servers<'a, S>(&self, servers: S) -> Result<()>
pub async fn flush_servers<S>(&self, servers: S) -> Result<()>
where
S: Stream<Item = &'a ServerName> + Send + 'a,
S: Stream<Item = OwnedServerName> + Send,
{
servers
.map(ToOwned::to_owned)
.map(Destination::Federation)
.map(Ok)
.ready_try_for_each(|dest| {
@@ -289,26 +288,26 @@ impl Service {
/// Sends a request to a federation server
#[inline]
pub async fn send_federation_request<T>(
pub async fn send_federation_request<'i, T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest::<Authentication = ServerSignatures, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
self.services.federation.execute(dest, request).await
}
/// Like send_federation_request() but with a very large timeout
#[inline]
pub async fn send_synapse_request<T>(
pub async fn send_synapse_request<'i, T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
T: OutgoingRequest::<Authentication = ServerSignatures, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
self.services
.federation
@@ -316,6 +315,22 @@ impl Service {
.await
}
/// Send an unauthenticated federation request with no X-Matrix header.
#[inline]
pub async fn send_unauthenticated_request<'i, T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest::<Authentication = NoAuthentication, PathBuilder: PathBuilder<Input<'i>: FederationPathBuilderInput>> + Debug + Send,
{
self.services
.federation
.execute_unauthenticated(dest, request)
.await
}
/// Clean up queued sending event data
///
/// Used after we remove an appservice registration or a user deletes a push
+22 -48
View File
@@ -422,7 +422,7 @@ impl Service {
let keys_changed = self
.services
.users
.room_keys_changed(room_id, Some(since.0), None)
.room_keys_changed(&room_id, Some(since.0), None)
.ready_filter(|(user_id, _)| self.services.globals.user_is_local(user_id));
pin_mut!(keys_changed);
@@ -432,21 +432,13 @@ impl Service {
}
max_edu_count.fetch_max(count, Ordering::Relaxed);
if !device_list_changes.insert(user_id.into()) {
if !device_list_changes.insert(user_id.clone()) {
continue;
}
// Empty prev id forces synapse to resync; because synapse resyncs,
// we can just insert placeholder data
let edu = Edu::DeviceListUpdate(DeviceListUpdateContent {
user_id: user_id.into(),
device_id: device_id!("placeholder").to_owned(),
device_display_name: Some("Placeholder".to_owned()),
stream_id: uint!(1),
prev_id: Vec::new(),
deleted: None,
keys: None,
});
let edu = Edu::DeviceListUpdate(DeviceListUpdateContent::new(user_id, device_id!("placeholder").to_owned(), uint!(1)));
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &edu)
@@ -479,7 +471,6 @@ impl Service {
.services
.state_cache
.server_rooms(server_name)
.map(ToOwned::to_owned)
.broad_filter_map(|room_id| async move {
let receipt_map = self
.select_edus_receipts_room(&room_id, since, max_edu_count, &mut num)
@@ -498,7 +489,7 @@ impl Service {
return None;
}
let receipt_content = Edu::Receipt(ReceiptContent { receipts });
let receipt_content = Edu::Receipt(ReceiptContent::new(receipts));
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &receipt_content)
@@ -556,10 +547,7 @@ impl Service {
.remove(&user_id)
.expect("our read receipts always have the user here");
let receipt_data = ReceiptData {
data: receipt,
event_ids: vec![event_id.clone()],
};
let receipt_data = ReceiptData::new(receipt, vec![event_id.clone()]);
if read.insert(user_id, receipt_data).is_none() {
*num = num.saturating_add(1);
@@ -569,7 +557,7 @@ impl Service {
}
}
ReceiptMap { read }
ReceiptMap::new(read)
}
/// Look for presence
@@ -617,16 +605,9 @@ impl Service {
continue;
};
let update = PresenceUpdate {
user_id: user_id.into(),
presence: presence_event.content.presence,
currently_active: presence_event.content.currently_active.unwrap_or(false),
status_msg: presence_event.content.status_msg,
last_active_ago: presence_event
.content
.last_active_ago
.unwrap_or_else(|| uint!(0)),
};
let mut update = PresenceUpdate::new(user_id.to_owned(), presence_event.content.presence, presence_event.content.last_active_ago.unwrap_or_else(|| uint!(0)));
update.currently_active = presence_event.content.currently_active.unwrap_or_default();
update.status_msg = presence_event.content.status_msg;
presence_updates.insert(user_id.into(), update);
if presence_updates.len() >= SELECT_PRESENCE_LIMIT {
@@ -638,9 +619,7 @@ impl Service {
return None;
}
let presence_content = Edu::Presence(PresenceContent {
push: presence_updates.into_values().collect(),
});
let presence_content = Edu::Presence(PresenceContent::new(presence_updates.into_values().collect()));
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &presence_content)
@@ -686,7 +665,7 @@ impl Service {
.filter(|event| matches!(event, SendingEvent::Pdu(_)))
.count(),
);
let mut edu_jsons: Vec<EphemeralData> = Vec::with_capacity(
let mut edu_jsons: Vec<Raw<EphemeralData>> = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Edu(_)))
@@ -702,7 +681,7 @@ impl Service {
| SendingEvent::Edu(edu) =>
if appservice.receive_ephemeral {
if let Ok(edu) = serde_json::from_slice(edu) {
edu_jsons.push(edu);
edu_jsons.push(Raw::from_json(edu));
}
},
| SendingEvent::Flush => {}, // flush only; no new content
@@ -720,15 +699,14 @@ impl Service {
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty
// transaction");
let mut request = ruma::api::appservice::event::push_events::v1::Request::new(txn_id.into(), pdu_jsons);
request.ephemeral = edu_jsons;
request.to_device = Vec::new(); // TODO
match self
.send_appservice_request(
appservice,
ruma::api::appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: txn_id.into(),
ephemeral: edu_jsons,
to_device: Vec::new(), // TODO
},
request,
)
.await
{
@@ -851,18 +829,14 @@ impl Service {
let txn_hash = calculate_hash(preimage);
let txn_id = &*URL_SAFE_NO_PAD.encode(txn_hash);
let request = send_transaction_message::v1::Request {
transaction_id: txn_id.into(),
origin: self.server.name.clone(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdus,
edus,
};
let mut request = send_transaction_message::v1::Request::new(txn_id.into(), self.server.name.clone(), MilliSecondsSinceUnixEpoch::now());
request.pdus = pdus;
request.edus = edus;
let result = self
.services
.federation
.execute_on(&self.services.client.sender, &server, request)
.execute_signed(&self.services.client.sender, &server, request)
.await;
for (event_id, result) in result.iter().flat_map(|resp| resp.pdus.iter()) {
@@ -900,7 +874,7 @@ impl Service {
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match self.services.state.get_room_version(room_id).await {
match self.services.state.get_room_version(&room_id).await {
| Ok(room_version_id) => match room_version_id {
| RoomVersionId::V1 | RoomVersionId::V2 => {},
| _ => _ = pdu_json.remove("event_id"),
+2 -2
View File
@@ -6,6 +6,8 @@ use ruma::{
api::federation::discovery::VerifyKey,
};
use crate::server_keys::util::required_keys;
use super::{PubKeyMap, PubKeys, extract_key};
#[implement(super::Service)]
@@ -14,8 +16,6 @@ pub async fn get_event_keys(
object: &CanonicalJsonObject,
version: &RoomVersionId,
) -> Result<PubKeyMap> {
use ruma::signatures::required_keys;
let required = match required_keys(object, version) {
| Ok(required) => required,
| Err(e) => {
+1 -3
View File
@@ -12,9 +12,7 @@ pub(super) fn init(db: &Arc<Database>) -> Result<(Box<Ed25519KeyPair>, VerifyKey
remove(db);
})?;
let verify_key = VerifyKey {
key: Base64::new(keypair.public_key().to_vec()),
};
let verify_key = VerifyKey::new(Base64::new(keypair.public_key().to_vec()));
let id = format!("ed25519:{}", keypair.version());
let verify_keys: VerifyKeys = [(id.try_into()?, verify_key)].into();
+5 -6
View File
@@ -4,6 +4,7 @@ mod keypair;
mod request;
mod sign;
mod verify;
mod util;
use std::{collections::BTreeMap, sync::Arc, time::Duration};
@@ -22,7 +23,7 @@ use ruma::{
};
use serde_json::value::RawValue as RawJsonValue;
use crate::{Dep, globals, sending};
use crate::{Dep, globals, sending, server_keys::util::required_keys};
pub struct Service {
keypair: Box<Ed25519KeyPair>,
@@ -118,8 +119,6 @@ pub async fn required_keys_exist(
object: &CanonicalJsonObject,
version: &RoomVersionId,
) -> bool {
use ruma::signatures::required_keys;
trace!(?object, "Checking required keys exist");
let Ok(required_keys) = required_keys(object, version) else {
debug_error!("Failed to determine required keys");
@@ -137,7 +136,7 @@ pub async fn required_keys_exist(
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> bool {
type KeysMap<'a> = BTreeMap<&'a ServerSigningKeyId, &'a RawJsonValue>;
type KeysMap = BTreeMap<OwnedServerSigningKeyId, Box<RawJsonValue>>;
let Ok(keys) = self
.db
@@ -150,13 +149,13 @@ pub async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSignin
return false;
};
if let Ok(Some(verify_keys)) = keys.get_field::<KeysMap<'_>>("verify_keys") {
if let Ok(Some(verify_keys)) = keys.get_field::<KeysMap>("verify_keys") {
if verify_keys.contains_key(key_id) {
return true;
}
}
if let Ok(Some(old_verify_keys)) = keys.get_field::<KeysMap<'_>>("old_verify_keys") {
if let Ok(Some(old_verify_keys)) = keys.get_field::<KeysMap>("old_verify_keys") {
if old_verify_keys.contains_key(key_id) {
return true;
}
+7 -13
View File
@@ -23,9 +23,8 @@ where
use get_remote_server_keys_batch::v2::Request;
type RumaBatch = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>;
let criteria = QueryCriteria {
minimum_valid_until_ts: Some(self.minimum_valid_ts()),
};
let mut criteria = QueryCriteria::new();
criteria.minimum_valid_until_ts = Some(self.minimum_valid_ts());
let mut server_keys = batch.fold(RumaBatch::new(), |mut batch, (server, key_ids)| {
batch
@@ -46,9 +45,7 @@ where
.next_back()
.cloned()
{
let request = Request {
server_keys: server_keys.split_off(&batch),
};
let request = Request::new(server_keys.split_off(&batch));
debug!(
?notary,
@@ -61,7 +58,7 @@ where
let response = self
.services
.sending
.send_synapse_request(notary, request)
.send_unauthenticated_request(notary, request)
.await?
.server_keys
.into_iter()
@@ -82,15 +79,12 @@ pub async fn notary_request(
) -> Result<impl Iterator<Item = ServerSigningKeys> + Clone + Debug + Send + use<>> {
use get_remote_server_keys::v2::Request;
let request = Request {
server_name: target.into(),
minimum_valid_until_ts: self.minimum_valid_ts(),
};
let request = Request::new(target.into(), self.minimum_valid_ts());
let response = self
.services
.sending
.send_federation_request(notary, request)
.send_unauthenticated_request(notary, request)
.await?
.server_keys
.into_iter()
@@ -107,7 +101,7 @@ pub async fn server_request(&self, target: &ServerName) -> Result<ServerSigningK
let server_signing_key = self
.services
.sending
.send_federation_request(target, Request::new())
.send_unauthenticated_request(target, Request::new())
.await
.map(|response| response.server_key)
.and_then(|key| key.deserialize().map_err(Into::into))?;
+1 -1
View File
@@ -18,5 +18,5 @@ pub fn hash_and_sign_event(
use ruma::signatures::hash_and_sign_event;
let server_name = self.services.globals.server_name().as_str();
hash_and_sign_event(server_name, self.keypair(), object, room_version).map_err(Into::into)
hash_and_sign_event(server_name, self.keypair(), object, &room_version.rules().unwrap().redaction).map_err(Into::into)
}
+135
View File
@@ -0,0 +1,135 @@
use std::collections::{BTreeMap, BTreeSet};
use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedServerName, OwnedServerSigningKeyId, RoomVersionId, UserId, canonical_json::JsonType, signatures::{Error, JsonError, ParseError}};
/// Whether the given event is an `m.room.member` invite that was created as the result of a
/// third-party invite.
///
/// Returns an error if the object has not the expected format of an `m.room.member` event.
pub(super) fn is_invite_via_third_party_id(object: &CanonicalJsonObject) -> Result<bool, Error> {
let Some(CanonicalJsonValue::String(raw_type)) = object.get("type") else {
return Err(JsonError::NotOfType { target: "type".to_owned(), of_type: JsonType::String }.into());
};
if raw_type != "m.room.member" {
return Ok(false);
}
let Some(CanonicalJsonValue::Object(content)) = object.get("content") else {
return Err(JsonError::NotOfType { target: "content".to_owned(), of_type: JsonType::Object }.into());
};
let Some(CanonicalJsonValue::String(membership)) = content.get("membership") else {
return Err(JsonError::NotOfType { target: "membership".to_owned(), of_type: JsonType::String }.into());
};
if membership != "invite" {
return Ok(false);
}
match content.get("third_party_invite") {
Some(CanonicalJsonValue::Object(_)) => Ok(true),
None => Ok(false),
_ => Err(JsonError::NotOfType { target: "third_party_invite".to_owned(), of_type: JsonType::Object }.into()),
}
}
/// Extracts the server names to check signatures for given event.
///
/// Respects the rules for [validating signatures on received events] for populating the result:
///
/// - Add the server of the sender, except if it's an invite event that results from a third-party
/// invite.
/// - For room versions 1 and 2, add the server of the `event_id`.
/// - For room versions that support restricted join rules, if it's a join event with a
/// `join_authorised_via_users_server`, add the server of that user.
///
/// [validating signatures on received events]: https://spec.matrix.org/latest/server-server-api/#validating-hashes-and-signatures-on-received-events
pub fn servers_to_check_signatures(
object: &CanonicalJsonObject,
version: &RoomVersionId,
) -> Result<BTreeSet<OwnedServerName>, Error> {
let mut servers_to_check = BTreeSet::new();
if !is_invite_via_third_party_id(object)? {
match object.get("sender") {
Some(CanonicalJsonValue::String(raw_sender)) => {
let user_id = <&UserId>::try_from(raw_sender.as_str())
.map_err(|e| Error::from(ParseError::UserId(e)))?;
servers_to_check.insert(user_id.server_name().to_owned());
}
_ => return Err(JsonError::NotOfType { target: "sender".to_owned(), of_type: JsonType::String }.into()),
};
}
match version {
RoomVersionId::V1 | RoomVersionId::V2 => match object.get("event_id") {
Some(CanonicalJsonValue::String(raw_event_id)) => {
let event_id: OwnedEventId =
raw_event_id.parse().map_err(|e| Error::from(ParseError::EventId(e)))?;
let server_name = event_id
.server_name()
.ok_or_else(|| ParseError::ServerNameFromEventId(event_id.to_owned()))?
.to_owned();
servers_to_check.insert(server_name);
}
_ => {
return Err(JsonError::JsonFieldMissingFromObject("event_id".to_owned()).into());
}
},
RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7 => {}
// TODO: And for all future versions that have join_authorised_via_users_server
RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 | RoomVersionId::V12 => {
if let Some(authorized_user) = object
.get("content")
.and_then(|c| c.as_object())
.and_then(|c| c.get("join_authorised_via_users_server"))
{
let authorized_user = authorized_user.as_str().ok_or_else(|| -> Error {
JsonError::NotOfType { target: "join_authorised_via_users_server".to_owned(), of_type: JsonType::String }.into()
})?;
let authorized_user = <&UserId>::try_from(authorized_user)
.map_err(|e| Error::from(ParseError::UserId(e)))?;
servers_to_check.insert(authorized_user.server_name().to_owned());
}
}
_ => unimplemented!(),
}
Ok(servers_to_check)
}
/// Extracts the server names and key ids to check signatures for given event.
pub fn required_keys(
object: &CanonicalJsonObject,
version: &RoomVersionId,
) -> Result<BTreeMap<OwnedServerName, Vec<OwnedServerSigningKeyId>>, Error> {
use CanonicalJsonValue::Object;
let mut map = BTreeMap::<OwnedServerName, Vec<OwnedServerSigningKeyId>>::new();
let Some(Object(signatures)) = object.get("signatures") else {
return Ok(map);
};
for server in servers_to_check_signatures(object, version)? {
let Some(Object(set)) = signatures.get(server.as_str()) else {
continue;
};
let entry = map.entry(server.clone()).or_default();
set.iter()
.map(|(k, _)| k.clone())
.map(TryInto::try_into)
.filter_map(Result::ok)
.for_each(|key_id| entry.push(key_id));
}
Ok(map)
}
+2 -2
View File
@@ -63,7 +63,7 @@ pub async fn verify_event(
) -> Result<Verified> {
let room_version = room_version.unwrap_or(&RoomVersionId::V12);
let keys = self.get_event_keys(event, room_version).await?;
ruma::signatures::verify_event(&keys, event, room_version).map_err(Into::into)
ruma::signatures::verify_event(&keys, event, &room_version.rules().unwrap()).map_err(Into::into)
}
#[implement(super::Service)]
@@ -74,5 +74,5 @@ pub async fn verify_json(
) -> Result {
let room_version = room_version.unwrap_or(&RoomVersionId::V12);
let keys = self.get_event_keys(event, room_version).await?;
ruma::signatures::verify_json(&keys, event.clone()).map_err(Into::into)
ruma::signatures::verify_json(&keys, event).map_err(Into::into)
}
+3 -162
View File
@@ -10,8 +10,6 @@ use database::Map;
use ruma::{
OwnedDeviceId, OwnedRoomId, OwnedUserId,
api::client::sync::sync_events::{
self,
v4::{ExtensionsConfig, SyncRequestList},
v5,
},
};
@@ -47,11 +45,11 @@ struct Services {
}
struct SlidingSyncCache {
lists: BTreeMap<String, SyncRequestList>,
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
lists: BTreeMap<String, v5::request::List>,
subscriptions: BTreeMap<OwnedRoomId, v5::request::RoomSubscription>,
// For every room, the roomsince number
known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>,
extensions: ExtensionsConfig,
extensions: v5::request::Extensions,
}
#[derive(Default)]
@@ -136,7 +134,6 @@ impl Service {
&mut list.room_details.required_state,
&cached_list.room_details.required_state,
);
some_or_sticky(&mut list.include_heroes, cached_list.include_heroes);
match (&mut list.filters, cached_list.filters.clone()) {
| (Some(filters), Some(cached_filters)) => {
@@ -219,162 +216,6 @@ impl Service {
cached.known_rooms.clone()
}
pub fn update_sync_request_with_cache(
&self,
key: &SnakeConnectionsKey,
request: &mut sync_events::v4::Request,
) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
let Some(conn_id) = request.conn_id.clone() else {
return BTreeMap::new();
};
let key = into_db_key(key.0.clone(), key.1.clone(), conn_id);
let mut cache = self.connections.lock();
let cached = Arc::clone(cache.entry(key).or_insert_with(|| {
Arc::new(SyncMutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}));
let cached = &mut cached.lock();
drop(cache);
for (list_id, list) in &mut request.lists {
if let Some(cached_list) = cached.lists.get(list_id) {
list_or_sticky(&mut list.sort, &cached_list.sort);
list_or_sticky(
&mut list.room_details.required_state,
&cached_list.room_details.required_state,
);
some_or_sticky(
&mut list.room_details.timeline_limit,
cached_list.room_details.timeline_limit,
);
some_or_sticky(
&mut list.include_old_rooms,
cached_list.include_old_rooms.clone(),
);
match (&mut list.filters, cached_list.filters.clone()) {
| (Some(filter), Some(cached_filter)) => {
some_or_sticky(&mut filter.is_dm, cached_filter.is_dm);
list_or_sticky(&mut filter.spaces, &cached_filter.spaces);
some_or_sticky(&mut filter.is_encrypted, cached_filter.is_encrypted);
some_or_sticky(&mut filter.is_invite, cached_filter.is_invite);
list_or_sticky(&mut filter.room_types, &cached_filter.room_types);
// Should be made possible to change
list_or_sticky(&mut filter.not_room_types, &cached_filter.not_room_types);
some_or_sticky(&mut filter.room_name_like, cached_filter.room_name_like);
list_or_sticky(&mut filter.tags, &cached_filter.tags);
list_or_sticky(&mut filter.not_tags, &cached_filter.not_tags);
},
| (_, Some(cached_filters)) => list.filters = Some(cached_filters),
| (Some(list_filters), _) => list.filters = Some(list_filters.clone()),
| (..) => {},
}
list_or_sticky(&mut list.bump_event_types, &cached_list.bump_event_types);
}
cached.lists.insert(list_id.clone(), list.clone());
}
cached
.subscriptions
.extend(request.room_subscriptions.clone());
request
.room_subscriptions
.extend(cached.subscriptions.clone());
request.extensions.e2ee.enabled = request
.extensions
.e2ee
.enabled
.or(cached.extensions.e2ee.enabled);
request.extensions.to_device.enabled = request
.extensions
.to_device
.enabled
.or(cached.extensions.to_device.enabled);
request.extensions.account_data.enabled = request
.extensions
.account_data
.enabled
.or(cached.extensions.account_data.enabled);
request.extensions.account_data.lists = request
.extensions
.account_data
.lists
.clone()
.or_else(|| cached.extensions.account_data.lists.clone());
request.extensions.account_data.rooms = request
.extensions
.account_data
.rooms
.clone()
.or_else(|| cached.extensions.account_data.rooms.clone());
cached.extensions = request.extensions.clone();
cached.known_rooms.clone()
}
pub fn update_sync_subscriptions(
&self,
key: &DbConnectionsKey,
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
) {
let mut cache = self.connections.lock();
let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| {
Arc::new(SyncMutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}));
let cached = &mut cached.lock();
drop(cache);
cached.subscriptions = subscriptions;
}
pub fn update_sync_known_rooms(
&self,
key: &DbConnectionsKey,
list_id: String,
new_cached_rooms: BTreeSet<OwnedRoomId>,
globalsince: u64,
) {
let mut cache = self.connections.lock();
let cached = Arc::clone(cache.entry(key.clone()).or_insert_with(|| {
Arc::new(SyncMutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}));
let cached = &mut cached.lock();
drop(cache);
for (room_id, lastsince) in cached
.known_rooms
.entry(list_id.clone())
.or_default()
.iter_mut()
{
if !new_cached_rooms.contains(room_id) {
*lastsince = 0;
}
}
let list = cached.known_rooms.entry(list_id).or_default();
for room_id in new_cached_rooms {
list.insert(room_id, globalsince);
}
}
pub fn update_snake_sync_known_rooms(
&self,
key: &SnakeConnectionsKey,
+1 -1
View File
@@ -38,7 +38,7 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
pin_mut!(rooms_joined);
while let Some(room_id) = rooms_joined.next().await {
let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else {
let Ok(short_roomid) = self.services.short.get_shortroomid(&room_id).await else {
continue;
};
+28 -31
View File
@@ -24,11 +24,11 @@ use ruma::{
events::{
AnyToDeviceEvent, GlobalAccountDataEventType,
ignored_user_list::IgnoredUserListEvent,
invite_permission_config::{FilterLevel, InvitePermissionConfigEvent},
},
serde::Raw,
uint,
};
use ruminuwuity::invite_permission_config::{FilterLevel, InvitePermissionConfigEvent};
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -206,7 +206,7 @@ impl Service {
pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
// Remove all associated devices
self.all_device_ids(user_id)
.for_each(|device_id| self.remove_device(user_id, device_id))
.for_each(async |device_id| self.remove_device(user_id, &device_id).await)
.await;
// Set the password to "" to indicate a deactivated account. Hashes will never
@@ -342,15 +342,8 @@ impl Service {
self.db.token_userdeviceid.get(token).await.deserialized()
}
/// Returns an iterator over all users on this homeserver (offered for
/// compatibility)
#[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)]
pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ {
self.stream().map(ToOwned::to_owned)
}
/// Returns an iterator over all users on this homeserver.
pub fn stream(&self) -> impl Stream<Item = &UserId> + Send {
pub fn stream(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db.userid_password.keys().ignore_err()
}
@@ -358,12 +351,12 @@ impl Service {
///
/// A user account is considered `local` if the length of it's password is
/// greater then zero.
pub fn list_local_users(&self) -> impl Stream<Item = &UserId> + Send + '_ {
pub fn list_local_users(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ {
self.db
.userid_password
.stream()
.ignore_err()
.ready_filter_map(|(u, p): (&UserId, &[u8])| (!p.is_empty()).then_some(u))
.ready_filter_map(|(u, p): (OwnedUserId, &[u8])| (!p.is_empty()).then_some(u))
}
/// Returns the origin of the user (password/LDAP/...).
@@ -470,15 +463,13 @@ impl Service {
}
let key = (user_id, device_id);
let val = Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: client_ip,
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
};
let mut device = Device::new(device_id.into());
device.display_name = initial_device_display_name;
device.last_seen_ip = client_ip;
device.last_seen_ts = Some(MilliSecondsSinceUnixEpoch::now());
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.put(key, Json(val));
self.db.userdeviceid_metadata.put(key, Json(device));
self.set_token(user_id, device_id, token).await
}
@@ -518,13 +509,13 @@ impl Service {
pub fn all_device_ids<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = &'a DeviceId> + Send + 'a {
) -> impl Stream<Item = OwnedDeviceId> + Send + 'a {
let prefix = (user_id, Interfix);
self.db
.userdeviceid_metadata
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
.map(|(_, device_id): (Ignore, OwnedDeviceId)| device_id)
}
pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
@@ -866,7 +857,7 @@ impl Service {
user_id: &'a UserId,
from: Option<u64>,
to: Option<u64>,
) -> impl Stream<Item = &'a UserId> + Send + 'a {
) -> impl Stream<Item = OwnedUserId> + Send + 'a {
self.keys_changed_user_or_room(user_id.as_str(), from, to)
.map(|(user_id, ..)| user_id)
}
@@ -877,7 +868,7 @@ impl Service {
room_id: &'a RoomId,
from: Option<u64>,
to: Option<u64>,
) -> impl Stream<Item = (&'a UserId, u64)> + Send + 'a {
) -> impl Stream<Item = (OwnedUserId, u64)> + Send + 'a {
self.keys_changed_user_or_room(room_id.as_str(), from, to)
}
@@ -886,8 +877,8 @@ impl Service {
user_or_room_id: &'a str,
from: Option<u64>,
to: Option<u64>,
) -> impl Stream<Item = (&'a UserId, u64)> + Send + 'a {
type KeyVal<'a> = ((&'a str, u64), &'a UserId);
) -> impl Stream<Item = (OwnedUserId, u64)> + Send + 'a {
type KeyVal<'a> = ((&'a str, u64), OwnedUserId);
let from = from.unwrap_or(0);
let to = to.unwrap_or(u64::MAX);
@@ -909,7 +900,13 @@ impl Service {
.state_cache
.rooms_joined(user_id)
// Don't send key updates to unencrypted rooms
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
.filter_map(async |room_id| {
if self.services.state_accessor.is_encrypted_room(&room_id).await {
Some(room_id)
} else {
None
}
})
.ready_for_each(|room_id| {
let key = (room_id, count);
self.db.keychangeid_userid.put_raw(key, user_id);
@@ -1013,7 +1010,7 @@ impl Service {
since: Option<u64>,
to: Option<u64>,
) -> impl Stream<Item = (u64, Raw<AnyToDeviceEvent>)> + Send + 'a {
type Key<'a> = (&'a UserId, &'a DeviceId, u64);
type Key = (OwnedUserId, OwnedDeviceId, u64);
let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
@@ -1021,7 +1018,7 @@ impl Service {
.todeviceid_events
.stream_from(&from)
.ignore_err()
.ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| {
.ready_take_while(move |((user_id_, device_id_, count), _): &(Key, _)| {
user_id == *user_id_
&& device_id == *device_id_
&& to.is_none_or(|to| *count <= to)
@@ -1037,7 +1034,7 @@ impl Service {
) where
Until: Into<Option<u64>> + Send,
{
type Key<'a> = (&'a UserId, &'a DeviceId, u64);
type Key = (OwnedUserId, OwnedDeviceId, u64);
let until = until.into().unwrap_or(u64::MAX);
let from = (user_id, device_id, until);
@@ -1045,10 +1042,10 @@ impl Service {
.todeviceid_events
.rev_keys_from(&from)
.ignore_err()
.ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| {
.ready_take_while(move |(user_id_, device_id_, _): &Key| {
user_id == *user_id_ && device_id == *device_id_
})
.ready_for_each(|key: Key<'_>| {
.ready_for_each(|key: Key| {
self.db.todeviceid_events.del(key);
})
.await;