Compare commits

..

49 Commits

Author SHA1 Message Date
timedout 69a08ffbf6 style: Make "empty timeline" warning only warn in debug mode 2026-02-23 21:36:13 +00:00
timedout 123aa2456c perf: Delayed drop is probably unnecessary 2026-02-23 21:18:46 +00:00
timedout c3f6c73ac4 feat: Add compile flag to instantly drop mutex for A/B testing 2026-02-23 20:58:22 +00:00
timedout 1730cc9ff3 fix: Write-lock individual rooms when building sync for them 2026-02-23 20:45:44 +00:00
Niklas Wojtkowiak cb9786466b chore(changelog): add news fragment for #1441 2026-02-23 17:59:13 +00:00
Niklas Wojtkowiak 18d2662b01 fix(config): remove allow_public_room_directory_without_auth 2026-02-23 17:59:13 +00:00
timedout 558262dd1f chore: Refactor transaction_ids -> transactions 2026-02-23 17:44:35 +00:00
timedout d311b87579 chore: Fix incorrect capitalisation
I didn't realise I agreed to take an English class with @ginger while
working on this server lol
2026-02-23 17:25:12 +00:00
timedout 8702f55cf5 fix: Don't panic if nobody's listening 2026-02-23 17:22:37 +00:00
timedout d4481b07ac chore: Add news frag 2026-02-23 16:54:54 +00:00
Jade Ellis 92351df925 refactor: Make federation transaction handle errors correctly
We have a dedicated error type that's then matched.
Event sorting is now infallible.
Could probably be cleaned up in a bit.
2026-02-23 16:36:46 +00:00
Jade Ellis 47e2733ea1 refactor: Make stream utils generic over the error type 2026-02-23 16:36:46 +00:00
Jade Ellis 6637e4c6a7 fix: Clean up cache, prevent several race conditions
We use one map which is only ever held for a short time.
2026-02-23 16:36:46 +00:00
nexy7574 35e441452f feat: Attempt to build localised DAG before processing PDUs 2026-02-23 16:36:46 +00:00
nexy7574 66bbb655bf feat: Warn when server is overloaded 2026-02-23 16:36:45 +00:00
nexy7574 81b202ce51 chore: Decrease transaction log verbosity 2026-02-23 16:36:45 +00:00
nexy7574 4657844d46 feat: Show active transaction handle count in !admin federation incoming-federation 2026-02-23 16:36:45 +00:00
nexy7574 9016cd11a6 chore: Run pre-commit and clippy to fix inherited CI errs 2026-02-23 16:36:45 +00:00
nexy7574 dd70094719 feat: Make max_active_txns actually configurable 2026-02-23 16:36:45 +00:00
nexy7574 fcd49b7ab3 fix: Remove duplicate fields from logs 2026-02-23 16:36:45 +00:00
nexy7574 470c9b52dd feat: Instrument process_inbound_transaction 2026-02-23 16:36:45 +00:00
nexy7574 0d8cafc329 feat: Support casting transaction processing to the background 2026-02-23 16:36:44 +00:00
nexy7574 2f9956ddca feat: Add helper functions for federation channels 2026-02-23 16:36:44 +00:00
nexy7574 21a97cdd0b chore: Refactor existing references to transaction service 2026-02-23 16:36:44 +00:00
nexy7574 e986cd4536 feat(federation): Restructure transaction_ids service
Adds two new in-memory maps to the service in to prepare for better handlers
2026-02-23 16:36:40 +00:00
Shane Jaroch 526d862296 fix: more aggressive user agent for URL preview
adding "facebookexternalhit" alongside "embedbot" fixes many errors, such as YouTube Music's:
    "Your browser is deprecated. Please upgrade."

add admin command to clear URL stuck and broken data (per URL currently)

    add command to clear all saved URL previews.
    sync resolver docs.
2026-02-23 15:24:14 +00:00
Ben Botwin fbeb5bf186 report permission denied errors 2026-02-23 15:22:18 +00:00
Ben Botwin a336f2df44 fixed formatting 2026-02-23 15:22:18 +00:00
Ben Botwin 19b78ec73e made error handling more concise 2026-02-23 15:22:18 +00:00
Ben Botwin 27ff2d9363 added more granular error handling for other file fetch function 2026-02-23 15:22:18 +00:00
Ben Botwin 50fa8c3abf ran format 2026-02-23 15:22:18 +00:00
Ben Botwin 18c4be869f added handling for other potential errors 2026-02-23 15:22:18 +00:00
Ben Botwin fc00b96d8b Added proper 404 for not found media and fixed devshell for running tests 2026-02-23 15:22:18 +00:00
Jade Ellis fa4156d8a6 docs: Changelog 2026-02-22 21:19:20 +00:00
Jade Ellis 23638cd714 feat(appservices): MSC3202 Device masquerading for appservices 2026-02-22 21:19:20 +00:00
Raven 9f1a483e76 docs: Add information about partnered homeservers to the introduction page & update README.md
Includes step-by-step directions to ease the lift for those who have ended up
here and who have never created a matrix account or used matrix before in the
past.

Also updates the information in README.md to match, as these should generally be identical.
2026-02-21 18:51:56 -08:00
Renovate Bot 688ef727e5 chore(deps): update rust crate nix to 0.31.0 2026-02-21 16:33:05 +00:00
Shannon Sterz 3de026160e docs: express forbidden_remote_server_names as valid regex
this field expects a regex not a glob, so the correct value should be
".*" if one wants to block all remote server names. otherwise, setting
"*" as documented results in an error on start because the configuration
could not be properly parsed.
2026-02-21 16:27:59 +00:00
Ginger 9fe761513d chore: Clippy & prek fixes 2026-02-21 11:27:39 -05:00
Renovate Bot abf1e1195a chore(deps): update rust crate libloading to 0.9.0 2026-02-21 01:55:48 +00:00
Ginger d9537e9b55 fix: Forbid registering users with a non-local localpart 2026-02-20 20:54:19 -05:00
Jade Ellis 0d1de70d8f fix(deps): Update lockfile 2026-02-21 00:22:42 +00:00
Ben Botwin 4aa03a71eb fix(nix): Added unstable flag to buildDeps 2026-02-21 00:15:53 +00:00
aviac f847918575 fix(nix): Fix all-features build
The build was broken since we started using an unstable reqwest version
which requires setting an extra feature flag
2026-02-21 00:15:53 +00:00
Renovate Bot 7569a0545b chore(deps): update dependency lddtree to 0.5.0 2026-02-20 22:59:34 +00:00
Jade Ellis b6c5991e1f chore(deps): Update rand
A couple indirect deps are still on rand_core 0.6 but we can deal
2026-02-20 22:57:45 +00:00
Katie Kloss efd879fcd8 docs: Add news fragment 2026-02-20 10:13:54 +00:00
Katie Kloss 92a848f74d fix: Crash before starting on OpenBSD
core_affinity doesn't return any cores on OpenBSD, so we try to
clamp(1, 0). This is Less Good than fixing that crate, but at
least allows the server to start up.
2026-02-20 10:13:54 +00:00
Renovate Bot 776b5865ba chore(deps): update sentry-rust monorepo to 0.46.0 2026-02-19 14:56:25 +00:00
59 changed files with 3952 additions and 2539 deletions
Generated
+532 -480
View File
File diff suppressed because it is too large Load Diff
+8 -8
View File
@@ -68,7 +68,7 @@ default-features = false
version = "0.1.3"
[workspace.dependencies.rand]
version = "0.8.5"
version = "0.10.0"
# Used for the http request / response body type for Ruma endpoints used with reqwest
[workspace.dependencies.bytes]
@@ -253,7 +253,7 @@ features = [
version = "0.4.0"
[workspace.dependencies.libloading]
version = "0.8.6"
version = "0.9.0"
# Validating urls in config, was already a transitive dependency
[workspace.dependencies.url]
@@ -298,7 +298,7 @@ default-features = false
features = ["env", "toml"]
[workspace.dependencies.hickory-resolver]
version = "0.25.1"
version = "0.25.2"
default-features = false
features = [
"serde",
@@ -343,7 +343,7 @@ version = "0.1.2"
[workspace.dependencies.ruma]
git = "https://forgejo.ellis.link/continuwuation/ruwuma"
#branch = "conduwuit-changes"
rev = "3126cb5eea991ec40590e54d8c9d75637650641a"
rev = "e087ff15888156942ca2ffe6097d1b4c3fd27628"
features = [
"compat",
"rand",
@@ -425,7 +425,7 @@ features = ["http", "grpc-tonic", "trace", "logs", "metrics"]
# optional sentry metrics for crash/panic reporting
[workspace.dependencies.sentry]
version = "0.45.0"
version = "0.46.0"
default-features = false
features = [
"backtrace",
@@ -441,9 +441,9 @@ features = [
]
[workspace.dependencies.sentry-tracing]
version = "0.45.0"
version = "0.46.0"
[workspace.dependencies.sentry-tower]
version = "0.45.0"
version = "0.46.0"
# jemalloc usage
[workspace.dependencies.tikv-jemalloc-sys]
@@ -472,7 +472,7 @@ features = ["use_std"]
version = "0.5"
[workspace.dependencies.nix]
version = "0.30.1"
version = "0.31.0"
default-features = false
features = ["resource"]
+8 -3
View File
@@ -57,10 +57,15 @@ Continuwuity aims to:
### Can I try it out?
Check out the [documentation](https://continuwuity.org) for installation instructions, or join one of these vetted public homeservers running Continuwuity to get a feel for things!
Check out the [documentation](https://continuwuity.org) for installation instructions.
- https://continuwuity.rocks -- A public demo server operated by the Continuwuity Team.
- https://federated.nexus -- Federated Nexus is a community resource hosting multiple FOSS (especially federated) services, including Matrix and Forgejo.
If you want to try it out as a user, we have some partnered homeservers you can use:
* You can head over to [https://federated.nexus](https://federated.nexus/) in your browser.
* Hit the `Apply to Join` button. Once your request has been accepted, you will receive an email with your username and password.
* Head over to [https://app.federated.nexus](https://app.federated.nexus/) and you can sign in there, or use any other matrix chat client you wish elsewhere.
* Your username for matrix will be in the form of `@username:federated.nexus`, however you can simply use the `username` part to log in. Your password is your password.
* There's also [https://continuwuity.rocks/](https://continuwuity.rocks/). You can register a new account using Cinny via [this convenient link](https://app.cinny.in/register/continuwuity.rocks), or you can use Element or another matrix client *that supports registration*.
### What are we working on?
+1
View File
@@ -0,0 +1 @@
Fixed a startup crash in the sender service if we can't detect the number of CPU cores, even if the `sender_workers' config option is set correctly. Contributed by @katie.
+1
View File
@@ -0,0 +1 @@
Improved the concurrency handling of federation transactions, vastly improving performance and reliability by more accurately handling inbound transactions and reducing the amount of repeated wasted work. Contributed by @nex and @Jade.
+1
View File
@@ -0,0 +1 @@
Added MSC3202 Device masquerading (not all of MSC3202). This should fix issues with enabling MSC4190 for some Mautrix bridges. Contributed by @Jade
+1
View File
@@ -0,0 +1 @@
Removed the `allow_public_room_directory_without_auth` config option. Contributed by @0xnim.
+1
View File
@@ -0,0 +1 @@
Improved URL preview fetching with a more compatible user agent for sites like YouTube Music. Added `!admin media delete-url-preview <url>` command to clear cached URL previews that were stuck and broken.
-1
View File
@@ -9,7 +9,6 @@ address = "0.0.0.0"
allow_device_name_federation = true
allow_guest_registration = true
allow_public_room_directory_over_federation = true
allow_public_room_directory_without_auth = true
allow_registration = true
database_path = "/database"
log = "trace,h2=debug,hyper=debug"
+20 -7
View File
@@ -290,6 +290,25 @@
#
#max_fetch_prev_events = 192
# How many incoming federation transactions the server is willing to be
# processing at any given time before it becomes overloaded and starts
# rejecting further transactions until some slots become available.
#
# Setting this value too low or too high may result in unstable
# federation, and setting it too high may cause runaway resource usage.
#
#max_concurrent_inbound_transactions = 150
# Maximum age (in seconds) for cached federation transaction responses.
# Entries older than this will be removed during cleanup.
#
#transaction_id_cache_max_age_secs = 7200 (2 hours)
# Maximum number of cached federation transaction responses.
# When the cache exceeds this limit, older entries will be removed.
#
#transaction_id_cache_max_entries = 8192
# Default/base connection timeout (seconds). This is used only by URL
# previews and update/news endpoint checks.
#
@@ -527,12 +546,6 @@
#
#allow_public_room_directory_over_federation = false
# Set this to true to allow your server's public room directory to be
# queried without client authentication (access token) through the Client
# APIs. Set this to false to protect against /publicRooms spiders.
#
#allow_public_room_directory_without_auth = false
# Allow guests/unauthenticated users to access TURN credentials.
#
# This is the equivalent of Synapse's `turn_allow_guests` config option.
@@ -1325,7 +1338,7 @@
# sender user's server name, inbound federation X-Matrix origin, and
# outbound federation handler.
#
# You can set this to ["*"] to block all servers by default, and then
# You can set this to [".*"] to block all servers by default, and then
# use `allowed_remote_server_names` to allow only specific servers.
#
# example: ["badserver\\.tld$", "badphrase", "19dollarfortnitecards"]
+1 -1
View File
@@ -52,7 +52,7 @@ ENV BINSTALL_VERSION=1.17.5
# renovate: datasource=github-releases depName=psastras/sbom-rs
ENV CARGO_SBOM_VERSION=0.9.1
# renovate: datasource=crate depName=lddtree
ENV LDDTREE_VERSION=0.4.0
ENV LDDTREE_VERSION=0.5.0
# renovate: datasource=crate depName=timelord-cli
ENV TIMELORD_VERSION=3.0.1
+1 -1
View File
@@ -22,7 +22,7 @@ ENV BINSTALL_VERSION=1.17.5
# renovate: datasource=github-releases depName=psastras/sbom-rs
ENV CARGO_SBOM_VERSION=0.9.1
# renovate: datasource=crate depName=lddtree
ENV LDDTREE_VERSION=0.4.0
ENV LDDTREE_VERSION=0.5.0
# Install unpackaged tools
RUN <<EOF
+7 -1
View File
@@ -51,7 +51,13 @@ continuwuity aims to:
Check out the [documentation](https://continuwuity.org) for installation instructions.
There are currently no open registration continuwuity instances available.
If you want to try it out as a user, we have some partnered homeservers you can use:
* You can head over to [https://federated.nexus](https://federated.nexus/) in your browser.
* Hit the `Apply to Join` button. Once your request has been accepted, you will receive an email with your username and password.
* Head over to [https://app.federated.nexus](https://app.federated.nexus/) and you can sign in there, or use any other matrix chat client you wish elsewhere.
* Your username for matrix will be in the form of `@username:federated.nexus`, however you can simply use the `username` part to log in. Your password is your password.
* There's also [https://continuwuity.rocks/](https://continuwuity.rocks/). You can register a new account using Cinny via [this convenient link](https://app.cinny.in/register/continuwuity.rocks), or you can use Element or another matrix client *that supports registration*.
## What are we working on?
+4
View File
@@ -36,3 +36,7 @@ Deletes all the local media from a local user on our server. This will always ig
## `!admin media delete-all-from-server`
Deletes all remote media from the specified remote server. This will always ignore errors by default
## `!admin media delete-url-preview`
Deletes a cached URL preview, forcing it to be re-fetched. Use --all to purge all cached URL previews
+13 -2
View File
@@ -77,7 +77,12 @@ rec {
craneLib.buildDepsOnly (
(commonAttrs commonAttrsArgs)
// {
env = uwuenv.buildDepsOnlyEnv // (makeRocksDBEnv { inherit rocksdb; });
env = uwuenv.buildDepsOnlyEnv
// (makeRocksDBEnv { inherit rocksdb; })
// {
# required since we started using unstable reqwest apparently ... otherwise the all-features build will fail
RUSTFLAGS = "--cfg reqwest_unstable";
};
inherit (features) cargoExtraArgs;
}
@@ -102,7 +107,13 @@ rec {
'';
cargoArtifacts = deps;
doCheck = true;
env = uwuenv.buildPackageEnv // rocksdbEnv;
env =
uwuenv.buildPackageEnv
// rocksdbEnv
// {
# required since we started using unstable reqwest apparently ... otherwise the all-features build will fail
RUSTFLAGS = "--cfg reqwest_unstable";
};
passthru.env = uwuenv.buildPackageEnv // rocksdbEnv;
meta.mainProgram = crateInfo.pname;
inherit (features) cargoExtraArgs;
+5 -2
View File
@@ -30,12 +30,15 @@ pub(super) async fn incoming_federation(&self) -> Result {
.federation_handletime
.read();
let mut msg = format!("Handling {} incoming pdus:\n", map.len());
let mut msg = format!(
"Handling {} incoming PDUs across {} active transactions:\n",
map.len(),
self.services.transactions.txn_active_handle_count()
);
for (r, (e, i)) in map.iter() {
let elapsed = i.elapsed();
writeln!(msg, "{} {}: {}m{}s", r, e, elapsed.as_secs() / 60, elapsed.as_secs() % 60)?;
}
msg
};
+16
View File
@@ -388,3 +388,19 @@ pub(super) async fn get_remote_thumbnail(
self.write_str(&format!("```\n{result:#?}\nreceived {len} bytes for file content.\n```"))
.await
}
#[admin_command]
pub(super) async fn delete_url_preview(&self, url: Option<String>, all: bool) -> Result {
if all {
self.services.media.clear_url_previews().await;
return self.write_str("Deleted all cached URL previews.").await;
}
let url = url.expect("clap enforces url is required unless --all");
self.services.media.remove_url_preview(&url).await?;
self.write_str(&format!("Deleted cached URL preview for: {url}"))
.await
}
+12
View File
@@ -108,4 +108,16 @@ pub enum MediaCommand {
#[arg(long, default_value("800"))]
height: u32,
},
/// Deletes a cached URL preview, forcing it to be re-fetched.
/// Use --all to purge all cached URL previews.
DeleteUrlPreview {
/// The URL to clear from the saved preview data
#[arg(required_unless_present = "all")]
url: Option<String>,
/// Purge all cached URL previews
#[arg(long, conflicts_with = "url")]
all: bool,
},
}
+1 -1
View File
@@ -209,7 +209,7 @@ pub(super) async fn compact(
let parallelism = parallelism.unwrap_or(1);
let results = maps
.into_iter()
.try_stream()
.try_stream::<conduwuit::Error>()
.paralleln_and_then(runtime, parallelism, move |map| {
map.compact_blocking(options.clone())?;
Ok(map.name().to_owned())
+11 -1
View File
@@ -20,7 +20,17 @@ pub enum ResolverCommand {
name: Option<String>,
},
/// Flush a specific server from the resolver caches or everything
/// Flush a given server from the resolver caches or flush them completely
///
/// * Examples:
/// * Flush a specific server:
///
/// `!admin query resolver flush-cache matrix.example.com`
///
/// * Flush all resolver caches completely:
///
/// `!admin query resolver flush-cache --all`
#[command(verbatim_doc_comment)]
FlushCache {
name: Option<OwnedServerName>,
+7
View File
@@ -252,6 +252,13 @@ pub(crate) async fn register_route(
}
}
// Don't allow registration with user IDs that aren't local
if !services.globals.user_is_local(&user_id) {
return Err!(Request(InvalidUsername(
"Username {body_username} is not local to this server"
)));
}
user_id
},
| Err(e) => {
+25 -3
View File
@@ -6,6 +6,7 @@ use conduwuit::{
Err, Result, err,
utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize},
};
use conduwuit_core::error;
use conduwuit_service::{
Services,
media::{CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, Dim, FileMeta, MXC_LENGTH},
@@ -144,12 +145,22 @@ pub(crate) async fn get_content_route(
server_name: &body.server_name,
media_id: &body.media_id,
};
let FileMeta {
content,
content_type,
content_disposition,
} = fetch_file(&services, &mxc, user, body.timeout_ms, None).await?;
} = match fetch_file(&services, &mxc, user, body.timeout_ms, None).await {
| Ok(meta) => meta,
| Err(conduwuit::Error::Io(e)) => match e.kind() {
| std::io::ErrorKind::NotFound => return Err!(Request(NotFound("Media not found."))),
| std::io::ErrorKind::PermissionDenied => {
error!("Permission denied when trying to read file: {e:?}");
return Err!(Request(Unknown("Unknown error when fetching file.")));
},
| _ => return Err!(Request(Unknown("Unknown error when fetching file."))),
},
| Err(_) => return Err!(Request(Unknown("Unknown error when fetching file."))),
};
Ok(get_content::v1::Response {
file: content.expect("entire file contents"),
@@ -185,7 +196,18 @@ pub(crate) async fn get_content_as_filename_route(
content,
content_type,
content_disposition,
} = fetch_file(&services, &mxc, user, body.timeout_ms, Some(&body.filename)).await?;
} = match fetch_file(&services, &mxc, user, body.timeout_ms, None).await {
| Ok(meta) => meta,
| Err(conduwuit::Error::Io(e)) => match e.kind() {
| std::io::ErrorKind::NotFound => return Err!(Request(NotFound("Media not found."))),
| std::io::ErrorKind::PermissionDenied => {
error!("Permission denied when trying to read file: {e:?}");
return Err!(Request(Unknown("Unknown error when fetching file.")));
},
| _ => return Err!(Request(Unknown("Unknown error when fetching file."))),
},
| Err(_) => return Err!(Request(Unknown("Unknown error when fetching file."))),
};
Ok(get_content_as_filename::v1::Response {
file: content.expect("entire file contents"),
+1 -2
View File
@@ -4,7 +4,6 @@ use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{Err, Event, Result, debug_info, info, matrix::pdu::PduEvent, utils::ReadyExt};
use conduwuit_service::Services;
use rand::Rng;
use ruma::{
EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
api::client::{
@@ -244,7 +243,7 @@ fn build_report(report: Report) -> RoomMessageEventContent {
/// random delay sending a response per spec suggestion regarding
/// enumerating for potential events existing in our server.
async fn delay_response() {
let time_to_wait = rand::thread_rng().gen_range(2..5);
let time_to_wait = rand::random_range(2..5);
debug_info!(
"Got successful /report request, waiting {time_to_wait} seconds before sending \
successful response."
+3 -3
View File
@@ -50,8 +50,8 @@ pub(crate) async fn send_message_event_route(
// Check if this is a new transaction id
if let Ok(response) = services
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)
.transactions
.get_client_txn(sender_user, sender_device, &body.txn_id)
.await
{
// The client might have sent a txnid of the /sendToDevice endpoint
@@ -92,7 +92,7 @@ pub(crate) async fn send_message_event_route(
)
.await?;
services.transaction_ids.add_txnid(
services.transactions.add_client_txnid(
sender_user,
sender_device,
&body.txn_id,
+3 -1
View File
@@ -65,6 +65,8 @@ pub(super) async fn load_joined_room(
and `join*` functions are used to perform steps in parallel which do not depend on each other.
*/
let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await;
drop(insert_lock);
let (
account_data,
ephemeral,
@@ -270,7 +272,7 @@ async fn build_state_and_timeline(
// joined since the last sync, that being the syncing user's join event. if
// it's empty something is wrong.
if joined_since_last_sync && timeline.pdus.is_empty() {
warn!("timeline for newly joined room is empty");
debug_warn!("timeline for newly joined room is empty");
}
let (summary, device_list_updates) = try_join(
+4 -4
View File
@@ -26,8 +26,8 @@ pub(crate) async fn send_event_to_device_route(
// Check if this is a new transaction id
if services
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)
.transactions
.get_client_txn(sender_user, sender_device, &body.txn_id)
.await
.is_ok()
{
@@ -104,8 +104,8 @@ pub(crate) async fn send_event_to_device_route(
// Save transaction id with empty data
services
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[]);
.transactions
.add_client_txnid(sender_user, sender_device, &body.txn_id, &[]);
Ok(send_event_to_device::v3::Response {})
}
+37 -19
View File
@@ -14,7 +14,8 @@ use futures::{
pin_mut,
};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
OwnedUserId, UserId,
api::{
AuthScheme, IncomingRequest, Metadata,
client::{
@@ -66,23 +67,17 @@ pub(super) async fn auth(
if metadata.authentication == AuthScheme::None {
match metadata {
| &get_public_rooms::v3::Request::METADATA => {
if !services
.server
.config
.allow_public_room_directory_without_auth
{
match token {
| Token::Appservice(_) | Token::User(_) => {
// we should have validated the token above
// already
},
| Token::None | Token::Invalid => {
return Err(Error::BadRequest(
ErrorKind::MissingToken,
"Missing or invalid access token.",
));
},
}
match token {
| Token::Appservice(_) | Token::User(_) => {
// we should have validated the token above
// already
},
| Token::None | Token::Invalid => {
return Err(Error::BadRequest(
ErrorKind::MissingToken,
"Missing or invalid access token.",
));
},
}
},
| &get_profile::v3::Request::METADATA
@@ -234,10 +229,33 @@ async fn auth_appservice(
return Err!(Request(Exclusive("User is not in namespace.")));
}
// MSC3202/MSC4190: Handle device_id masquerading for appservices.
// The device_id can be provided via `device_id` or
// `org.matrix.msc3202.device_id` query parameter.
let sender_device = if let Some(ref device_id_str) = request.query.device_id {
let device_id: &DeviceId = device_id_str.as_str().into();
// Verify the device exists for this user
if services
.users
.get_device_metadata(&user_id, device_id)
.await
.is_err()
{
return Err!(Request(Forbidden(
"Device does not exist for user or appservice cannot masquerade as this device."
)));
}
Some(device_id.to_owned())
} else {
None
};
Ok(Auth {
origin: None,
sender_user: Some(user_id),
sender_device: None,
sender_device,
appservice_info: Some(*info),
})
}
+4
View File
@@ -11,6 +11,10 @@ use service::Services;
pub(super) struct QueryParams {
pub(super) access_token: Option<String>,
pub(super) user_id: Option<String>,
/// Device ID for appservice device masquerading (MSC3202/MSC4190).
/// Can be provided as `device_id` or `org.matrix.msc3202.device_id`.
#[serde(alias = "org.matrix.msc3202.device_id")]
pub(super) device_id: Option<String>,
}
pub(super) struct Request {
+1 -1
View File
@@ -40,7 +40,7 @@ pub(crate) async fn get_room_information_route(
servers.sort_unstable();
servers.dedup();
servers.shuffle(&mut rand::thread_rng());
servers.shuffle(&mut rand::rng());
// insert our server as the very first choice if in list
if let Some(server_index) = servers
+214 -64
View File
@@ -1,27 +1,33 @@
use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use std::{
collections::{BTreeMap, HashMap, HashSet},
net::IpAddr,
time::{Duration, Instant},
};
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{
Err, Error, Result, debug, debug_warn, err, error,
result::LogErr,
state_res::lexicographical_topological_sort,
trace,
utils::{
IterStream, ReadyExt, millis_since_unix_epoch,
stream::{BroadbandExt, TryBroadbandExt, automatic_width},
},
warn,
};
use conduwuit_service::{
Services,
sending::{EDU_LIMIT, PDU_LIMIT},
};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use http::StatusCode;
use itertools::Itertools;
use ruma::{
CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName, UserId,
CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId,
RoomId, ServerName, UserId,
api::{
client::error::ErrorKind,
client::error::{ErrorKind, ErrorKind::LimitExceeded},
federation::transactions::{
edu::{
DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent,
@@ -32,9 +38,16 @@ use ruma::{
},
},
events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
int,
serde::Raw,
to_device::DeviceIdOrAllDevices,
uint,
};
use service::transactions::{
FederationTxnState, TransactionError, TxnKey, WrappedTransactionResponse,
};
use tokio::sync::watch::{Receiver, Sender};
use tracing::instrument;
use crate::Ruma;
@@ -44,15 +57,6 @@ type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.
#[tracing::instrument(
name = "txn",
level = "debug",
skip_all,
fields(
%client,
origin = body.origin().as_str()
),
)]
pub(crate) async fn send_transaction_message_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
@@ -76,16 +80,73 @@ pub(crate) async fn send_transaction_message_route(
)));
}
let txn_start_time = Instant::now();
trace!(
pdus = body.pdus.len(),
edus = body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = %body.transaction_id,
origin = %body.origin(),
"Starting txn",
);
let txn_key = (body.origin().to_owned(), body.transaction_id.clone());
// Atomically check cache, join active, or start new transaction
match services
.transactions
.get_or_start_federation_txn(txn_key.clone())?
{
| FederationTxnState::Cached(response) => {
// Already responded
Ok(response)
},
| FederationTxnState::Active(receiver) => {
// Another thread is processing
wait_for_result(receiver).await
},
| FederationTxnState::Started { receiver, sender } => {
// We're the first, spawn the processing task
services
.server
.runtime()
.spawn(process_inbound_transaction(services, body, client, txn_key, sender));
// and wait for it
wait_for_result(receiver).await
},
}
}
async fn wait_for_result(
mut recv: Receiver<WrappedTransactionResponse>,
) -> Result<send_transaction_message::v1::Response> {
if tokio::time::timeout(Duration::from_secs(50), recv.changed())
.await
.is_err()
{
// Took too long, return 429 to encourage the sender to try again
return Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Transaction is being still being processed. Please try again later.",
));
}
let value = recv.borrow_and_update();
match value.clone() {
| Some(Ok(response)) => Ok(response),
| Some(Err(err)) => Err(transaction_error_to_response(&err)),
| None => Err(Error::Request(
ErrorKind::Unknown,
"Transaction processing failed unexpectedly".into(),
StatusCode::INTERNAL_SERVER_ERROR,
)),
}
}
#[instrument(
skip_all,
fields(
id = ?body.transaction_id.as_str(),
origin = ?body.origin()
)
)]
async fn process_inbound_transaction(
services: crate::State,
body: Ruma<send_transaction_message::v1::Request>,
client: IpAddr,
txn_key: TxnKey,
sender: Sender<WrappedTransactionResponse>,
) {
let txn_start_time = Instant::now();
let pdus = body
.pdus
.iter()
@@ -102,40 +163,79 @@ pub(crate) async fn send_transaction_message_route(
.filter_map(Result::ok)
.stream();
let results = handle(&services, &client, body.origin(), txn_start_time, pdus, edus).await?;
debug!(pdus = body.pdus.len(), edus = body.edus.len(), "Processing transaction",);
let results = match handle(&services, &client, body.origin(), pdus, edus).await {
| Ok(results) => results,
| Err(err) => {
fail_federation_txn(services, &txn_key, &sender, err);
return;
},
};
for (id, result) in &results {
if let Err(e) = result {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
debug_warn!("Incoming PDU failed {id}: {e:?}");
}
}
}
debug!(
pdus = body.pdus.len(),
edus = body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = %body.transaction_id,
origin = %body.origin(),
"Finished txn",
"Finished processing transaction"
);
for (id, result) in &results {
if let Err(e) = result {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!("Incoming PDU failed {id}: {e:?}");
}
}
}
Ok(send_transaction_message::v1::Response {
let response = send_transaction_message::v1::Response {
pdus: results
.into_iter()
.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
.collect(),
})
};
services
.transactions
.finish_federation_txn(txn_key, sender, response);
}
/// Handles a failed federation transaction by sending the error through
/// the channel and cleaning up the transaction state. This allows waiters to
/// receive an appropriate error response.
fn fail_federation_txn(
services: crate::State,
txn_key: &TxnKey,
sender: &Sender<WrappedTransactionResponse>,
err: TransactionError,
) {
debug!("Transaction failed: {err}");
// Remove from active state so the transaction can be retried
services.transactions.remove_federation_txn(txn_key);
// Send the error to any waiters
if let Err(e) = sender.send(Some(Err(err))) {
debug_warn!("Failed to send transaction error to receivers: {e}");
}
}
/// Converts a TransactionError into an appropriate HTTP error response.
fn transaction_error_to_response(err: &TransactionError) -> Error {
match err {
| TransactionError::ShuttingDown => Error::Request(
ErrorKind::Unknown,
"Server is shutting down, please retry later".into(),
StatusCode::SERVICE_UNAVAILABLE,
),
}
}
async fn handle(
services: &Services,
client: &IpAddr,
origin: &ServerName,
started: Instant,
pdus: impl Stream<Item = Pdu> + Send,
edus: impl Stream<Item = Edu> + Send,
) -> Result<ResolvedMap> {
) -> std::result::Result<ResolvedMap, TransactionError> {
// group pdus by room
let pdus = pdus
.collect()
@@ -152,7 +252,7 @@ async fn handle(
.into_iter()
.try_stream()
.broad_and_then(|(room_id, pdus): (_, Vec<_>)| {
handle_room(services, client, origin, started, room_id, pdus.into_iter())
handle_room(services, client, origin, room_id, pdus.into_iter())
.map_ok(Vec::into_iter)
.map_ok(IterStream::try_stream)
})
@@ -169,14 +269,51 @@ async fn handle(
Ok(results)
}
/// Attempts to build a localised directed acyclic graph out of the given PDUs,
/// returning them in a topologically sorted order.
///
/// This is used to attempt to process PDUs in an order that respects their
/// dependencies, however it is ultimately the sender's responsibility to send
/// them in a processable order, so this is just a best effort attempt. It does
/// not account for power levels or other tie breaks.
async fn build_local_dag(
pdu_map: &HashMap<OwnedEventId, CanonicalJsonObject>,
) -> Result<Vec<OwnedEventId>> {
debug_assert!(pdu_map.len() >= 2, "needless call to build_local_dag with less than 2 PDUs");
let mut dag: HashMap<OwnedEventId, HashSet<OwnedEventId>> = HashMap::new();
for (event_id, value) in pdu_map {
let prev_events = value
.get("prev_events")
.expect("pdu must have prev_events")
.as_array()
.expect("prev_events must be an array")
.iter()
.map(|v| {
OwnedEventId::parse(v.as_str().expect("prev_events values must be strings"))
.expect("prev_events must be valid event IDs")
})
.collect::<HashSet<OwnedEventId>>();
dag.insert(event_id.clone(), prev_events);
}
lexicographical_topological_sort(&dag, &|_| async {
// Note: we don't bother fetching power levels because that would massively slow
// this function down. This is a best-effort attempt to order events correctly
// for processing, however ultimately that should be the sender's job.
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
.map_err(|e| err!("failed to resolve local graph: {e}"))
}
async fn handle_room(
services: &Services,
_client: &IpAddr,
origin: &ServerName,
txn_start_time: Instant,
room_id: OwnedRoomId,
pdus: impl Iterator<Item = Pdu> + Send,
) -> Result<Vec<(OwnedEventId, Result)>> {
) -> std::result::Result<Vec<(OwnedEventId, Result)>, TransactionError> {
let _room_lock = services
.rooms
.event_handler
@@ -185,27 +322,40 @@ async fn handle_room(
.await;
let room_id = &room_id;
pdus.try_stream()
.and_then(|(_, event_id, value)| async move {
services.server.check_running()?;
let pdu_start_time = Instant::now();
let result = services
.rooms
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.await
.map(|_| ());
debug!(
pdu_elapsed = ?pdu_start_time.elapsed(),
txn_elapsed = ?txn_start_time.elapsed(),
"Finished PDU {event_id}",
);
Ok((event_id, result))
let pdu_map: HashMap<OwnedEventId, CanonicalJsonObject> = pdus
.into_iter()
.map(|(_, event_id, value)| (event_id, value))
.collect();
// Try to sort PDUs by their dependencies, but fall back to arbitrary order on
// failure (e.g., cycles). This is best-effort; proper ordering is the sender's
// responsibility.
let sorted_event_ids = if pdu_map.len() >= 2 {
build_local_dag(&pdu_map).await.unwrap_or_else(|e| {
debug_warn!("Failed to build local DAG for room {room_id}: {e}");
pdu_map.keys().cloned().collect()
})
.try_collect()
.await
} else {
pdu_map.keys().cloned().collect()
};
let mut results = Vec::with_capacity(sorted_event_ids.len());
for event_id in sorted_event_ids {
let value = pdu_map
.get(&event_id)
.expect("sorted event IDs must be from the original map")
.clone();
services
.server
.check_running()
.map_err(|_| TransactionError::ShuttingDown)?;
let result = services
.rooms
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.await
.map(|_| ());
results.push((event_id, result));
}
Ok(results)
}
async fn handle_edu(services: &Services, client: &IpAddr, origin: &ServerName, edu: Edu) {
@@ -478,8 +628,8 @@ async fn handle_edu_direct_to_device(
// Check if this is a new transaction id
if services
.transaction_ids
.existing_txnid(sender, None, message_id)
.transactions
.get_client_txn(sender, None, message_id)
.await
.is_ok()
{
@@ -498,8 +648,8 @@ async fn handle_edu_direct_to_device(
// Save transaction id with empty data
services
.transaction_ids
.add_txnid(sender, None, message_id, &[]);
.transactions
.add_client_txnid(sender, None, message_id, &[]);
}
async fn handle_edu_direct_to_device_user<Event: Send + Sync>(
+1
View File
@@ -86,6 +86,7 @@ libloading.optional = true
log.workspace = true
num-traits.workspace = true
rand.workspace = true
rand_core = { version = "0.6.4", features = ["getrandom"] }
regex.workspace = true
reqwest.workspace = true
ring.workspace = true
+32 -7
View File
@@ -368,6 +368,31 @@ pub struct Config {
#[serde(default = "default_max_fetch_prev_events")]
pub max_fetch_prev_events: u16,
/// How many incoming federation transactions the server is willing to be
/// processing at any given time before it becomes overloaded and starts
/// rejecting further transactions until some slots become available.
///
/// Setting this value too low or too high may result in unstable
/// federation, and setting it too high may cause runaway resource usage.
///
/// default: 150
#[serde(default = "default_max_concurrent_inbound_transactions")]
pub max_concurrent_inbound_transactions: usize,
/// Maximum age (in seconds) for cached federation transaction responses.
/// Entries older than this will be removed during cleanup.
///
/// default: 7200 (2 hours)
#[serde(default = "default_transaction_id_cache_max_age_secs")]
pub transaction_id_cache_max_age_secs: u64,
/// Maximum number of cached federation transaction responses.
/// When the cache exceeds this limit, older entries will be removed.
///
/// default: 8192
#[serde(default = "default_transaction_id_cache_max_entries")]
pub transaction_id_cache_max_entries: usize,
/// Default/base connection timeout (seconds). This is used only by URL
/// previews and update/news endpoint checks.
///
@@ -653,12 +678,6 @@ pub struct Config {
#[serde(default)]
pub allow_public_room_directory_over_federation: bool,
/// Set this to true to allow your server's public room directory to be
/// queried without client authentication (access token) through the Client
/// APIs. Set this to false to protect against /publicRooms spiders.
#[serde(default)]
pub allow_public_room_directory_without_auth: bool,
/// Allow guests/unauthenticated users to access TURN credentials.
///
/// This is the equivalent of Synapse's `turn_allow_guests` config option.
@@ -1525,7 +1544,7 @@ pub struct Config {
/// sender user's server name, inbound federation X-Matrix origin, and
/// outbound federation handler.
///
/// You can set this to ["*"] to block all servers by default, and then
/// You can set this to [".*"] to block all servers by default, and then
/// use `allowed_remote_server_names` to allow only specific servers.
///
/// example: ["badserver\\.tld$", "badphrase", "19dollarfortnitecards"]
@@ -2540,6 +2559,12 @@ fn default_pusher_idle_timeout() -> u64 { 15 }
fn default_max_fetch_prev_events() -> u16 { 192_u16 }
fn default_max_concurrent_inbound_transactions() -> usize { 150 }
fn default_transaction_id_cache_max_age_secs() -> u64 { 60 * 60 * 2 }
fn default_transaction_id_cache_max_entries() -> usize { 8192 }
fn default_tracing_flame_filter() -> String {
cfg!(debug_assertions)
.then_some("trace,h2=off")
+9
View File
@@ -14,6 +14,7 @@ static SEMANTIC: &str = env!("CARGO_PKG_VERSION");
static VERSION: OnceLock<String> = OnceLock::new();
static VERSION_UA: OnceLock<String> = OnceLock::new();
static USER_AGENT: OnceLock<String> = OnceLock::new();
static USER_AGENT_MEDIA: OnceLock<String> = OnceLock::new();
#[inline]
#[must_use]
@@ -21,14 +22,22 @@ pub fn name() -> &'static str { BRANDING }
#[inline]
pub fn version() -> &'static str { VERSION.get_or_init(init_version) }
#[inline]
pub fn version_ua() -> &'static str { VERSION_UA.get_or_init(init_version_ua) }
#[inline]
pub fn user_agent() -> &'static str { USER_AGENT.get_or_init(init_user_agent) }
#[inline]
pub fn user_agent_media() -> &'static str { USER_AGENT_MEDIA.get_or_init(init_user_agent_media) }
fn init_user_agent() -> String { format!("{}/{} (bot; +{WEBSITE})", name(), version_ua()) }
fn init_user_agent_media() -> String {
format!("{}/{} (embedbot; facebookexternalhit/1.1; +{WEBSITE})", name(), version_ua())
}
fn init_version_ua() -> String {
conduwuit_build_metadata::version_tag()
.map_or_else(|| SEMANTIC.to_owned(), |extra| format!("{SEMANTIC}+{extra}"))
+552
View File
@@ -0,0 +1,552 @@
#[cfg(conduwuit_bench)]
extern crate test;
use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
sync::atomic::{AtomicU64, Ordering::SeqCst},
};
use futures::{future, future::ready};
use maplit::{btreemap, hashmap, hashset};
use ruma::{
EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, RoomVersionId, Signatures, UserId,
events::{
StateEventType, TimelineEventType,
room::{
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent},
},
},
int, room_id, uint, user_id,
};
use serde_json::{
json,
value::{RawValue as RawJsonValue, to_raw_value as to_raw_json_value},
};
use crate::{
matrix::{Event, Pdu, pdu::EventHash},
state_res::{self as state_res, Error, Result, StateMap},
};
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
#[cfg(conduwuit_bench)]
#[cfg_attr(conduwuit_bench, bench)]
fn lexico_topo_sort(c: &mut test::Bencher) {
let graph = hashmap! {
event_id("l") => hashset![event_id("o")],
event_id("m") => hashset![event_id("n"), event_id("o")],
event_id("n") => hashset![event_id("o")],
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => hashset![event_id("o")],
};
c.iter(|| {
let _ = state_res::lexicographical_topological_sort(&graph, &|_| {
future::ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
});
});
}
#[cfg(conduwuit_bench)]
#[cfg_attr(conduwuit_bench, bench)]
fn resolution_shallow_auth_chain(c: &mut test::Bencher) {
let mut store = TestStore(hashmap! {});
// build up the DAG
let (state_at_bob, state_at_charlie, _) = store.set_up();
c.iter(|| async {
let ev_map = store.0.clone();
let state_sets = [&state_at_bob, &state_at_charlie];
let fetch = |id: OwnedEventId| ready(ev_map.get(&id).map(ToOwned::to_owned));
let exists = |id: OwnedEventId| ready(ev_map.get(&id).is_some());
let auth_chain_sets: Vec<HashSet<_>> = state_sets
.iter()
.map(|map| {
store
.auth_event_ids(room_id(), map.values().cloned().collect())
.unwrap()
})
.collect();
let _ = match state_res::resolve(
&RoomVersionId::V6,
state_sets.into_iter(),
&auth_chain_sets,
&fetch,
&exists,
)
.await
{
| Ok(state) => state,
| Err(e) => panic!("{e}"),
};
});
}
#[cfg(conduwuit_bench)]
#[cfg_attr(conduwuit_bench, bench)]
fn resolve_deeper_event_set(c: &mut test::Bencher) {
let mut inner = INITIAL_EVENTS();
let ban = BAN_STATE_SET();
inner.extend(ban);
let store = TestStore(inner.clone());
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("MB")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
let state_set_b = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("IME")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
c.iter(|| async {
let state_sets = [&state_set_a, &state_set_b];
let auth_chain_sets: Vec<HashSet<_>> = state_sets
.iter()
.map(|map| {
store
.auth_event_ids(room_id(), map.values().cloned().collect())
.unwrap()
})
.collect();
let fetch = |id: OwnedEventId| ready(inner.get(&id).map(ToOwned::to_owned));
let exists = |id: OwnedEventId| ready(inner.get(&id).is_some());
let _ = match state_res::resolve(
&RoomVersionId::V6,
state_sets.into_iter(),
&auth_chain_sets,
&fetch,
&exists,
)
.await
{
| Ok(state) => state,
| Err(_) => panic!("resolution failed during benchmarking"),
};
});
}
//*/////////////////////////////////////////////////////////////////////
//
// IMPLEMENTATION DETAILS AHEAD
//
/////////////////////////////////////////////////////////////////////*/
struct TestStore<E: Event>(HashMap<OwnedEventId, E>);
#[allow(unused)]
impl<E: Event + Clone> TestStore<E> {
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<E> {
self.0
.get(event_id)
.cloned()
.ok_or_else(|| Error::NotFound(format!("{} not found", event_id)))
}
/// Returns the events that correspond to the `event_ids` sorted in the same
/// order.
fn get_events(&self, room_id: &RoomId, event_ids: &[OwnedEventId]) -> Result<Vec<E>> {
let mut events = vec![];
for id in event_ids {
events.push(self.get_event(room_id, id)?);
}
Ok(events)
}
/// Returns a Vec of the related auth events to the given `event`.
fn auth_event_ids(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
) -> Result<HashSet<OwnedEventId>> {
let mut result = HashSet::new();
let mut stack = event_ids;
// DFS for auth event chain
while !stack.is_empty() {
let ev_id = stack.pop().unwrap();
if result.contains(&ev_id) {
continue;
}
result.insert(ev_id.clone());
let event = self.get_event(room_id, ev_id.borrow())?;
stack.extend(event.auth_events().map(ToOwned::to_owned));
}
Ok(result)
}
/// Returns a vector representing the difference in auth chains of the given
/// `events`.
fn auth_chain_diff(
&self,
room_id: &RoomId,
event_ids: Vec<Vec<OwnedEventId>>,
) -> Result<Vec<OwnedEventId>> {
let mut auth_chain_sets = vec![];
for ids in event_ids {
// TODO state store `auth_event_ids` returns self in the event ids list
// when an event returns `auth_event_ids` self is not contained
let chain = self
.auth_event_ids(room_id, ids)?
.into_iter()
.collect::<HashSet<_>>();
auth_chain_sets.push(chain);
}
if let Some(first) = auth_chain_sets.first().cloned() {
let common = auth_chain_sets
.iter()
.skip(1)
.fold(first, |a, b| a.intersection(b).cloned().collect::<HashSet<_>>());
Ok(auth_chain_sets
.into_iter()
.flatten()
.filter(|id| !common.contains(id))
.collect())
} else {
Ok(vec![])
}
}
}
impl TestStore<Pdu> {
#[allow(clippy::type_complexity)]
fn set_up(
&mut self,
) -> (StateMap<OwnedEventId>, StateMap<OwnedEventId>, StateMap<OwnedEventId>) {
let create_event = to_pdu_event::<&EventId>(
"CREATE",
alice(),
TimelineEventType::RoomCreate,
Some(""),
to_raw_json_value(&json!({ "creator": alice() })).unwrap(),
&[],
&[],
);
let cre = create_event.event_id().to_owned();
self.0.insert(cre.clone(), create_event.clone());
let alice_mem = to_pdu_event(
"IMA",
alice(),
TimelineEventType::RoomMember,
Some(alice().to_string().as_str()),
member_content_join(),
&[cre.clone()],
&[cre.clone()],
);
self.0
.insert(alice_mem.event_id().to_owned(), alice_mem.clone());
let join_rules = to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(),
&[cre.clone(), alice_mem.event_id().to_owned()],
&[alice_mem.event_id().to_owned()],
);
self.0
.insert(join_rules.event_id().to_owned(), join_rules.clone());
// Bob and Charlie join at the same time, so there is a fork
// this will be represented in the state_sets when we resolve
let bob_mem = to_pdu_event(
"IMB",
bob(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_join(),
&[cre.clone(), join_rules.event_id().to_owned()],
&[join_rules.event_id().to_owned()],
);
self.0
.insert(bob_mem.event_id().to_owned(), bob_mem.clone());
let charlie_mem = to_pdu_event(
"IMC",
charlie(),
TimelineEventType::RoomMember,
Some(charlie().to_string().as_str()),
member_content_join(),
&[cre, join_rules.event_id().to_owned()],
&[join_rules.event_id().to_owned()],
);
self.0
.insert(charlie_mem.event_id().to_owned(), charlie_mem.clone());
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem]
.iter()
.map(|ev| {
(
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
ev.event_id().to_owned(),
)
})
.collect::<StateMap<_>>();
(state_at_bob, state_at_charlie, expected)
}
}
fn event_id(id: &str) -> OwnedEventId {
if id.contains('$') {
return id.try_into().unwrap();
}
format!("${}:foo", id).try_into().unwrap()
}
fn alice() -> &'static UserId { user_id!("@alice:foo") }
fn bob() -> &'static UserId { user_id!("@bob:foo") }
fn charlie() -> &'static UserId { user_id!("@charlie:foo") }
fn ella() -> &'static UserId { user_id!("@ella:foo") }
fn room_id() -> &'static RoomId { room_id!("!test:foo") }
fn member_content_ban() -> Box<RawJsonValue> {
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Ban)).unwrap()
}
fn member_content_join() -> Box<RawJsonValue> {
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap()
}
fn to_pdu_event<S>(
id: &str,
sender: &UserId,
ev_type: TimelineEventType,
state_key: Option<&str>,
content: Box<RawJsonValue>,
auth_events: &[S],
prev_events: &[S],
) -> Pdu
where
S: AsRef<str>,
{
// We don't care if the addition happens in order just that it is atomic
// (each event has its own value)
let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst);
let id = if id.contains('$') {
id.to_owned()
} else {
format!("${}:foo", id)
};
let auth_events = auth_events
.iter()
.map(AsRef::as_ref)
.map(event_id)
.collect::<Vec<_>>();
let prev_events = prev_events
.iter()
.map(AsRef::as_ref)
.map(event_id)
.collect::<Vec<_>>();
Pdu {
event_id: id.try_into().unwrap(),
room_id: Some(room_id().to_owned()),
sender: sender.to_owned(),
origin_server_ts: ts.try_into().unwrap(),
state_key: state_key.map(Into::into),
kind: ev_type,
content,
origin: None,
redacts: None,
unsigned: None,
auth_events,
prev_events,
depth: uint!(0),
hashes: EventHash { sha256: String::new() },
signatures: None,
}
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn INITIAL_EVENTS() -> HashMap<OwnedEventId, Pdu> {
vec![
to_pdu_event::<&EventId>(
"CREATE",
alice(),
TimelineEventType::RoomCreate,
Some(""),
to_raw_json_value(&json!({ "creator": alice() })).unwrap(),
&[],
&[],
),
to_pdu_event(
"IMA",
alice(),
TimelineEventType::RoomMember,
Some(alice().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
),
to_pdu_event(
"IPOWER",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100 } })).unwrap(),
&["CREATE", "IMA"],
&["IMA"],
),
to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
),
to_pdu_event(
"IMB",
bob(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IJR"],
),
to_pdu_event(
"IMC",
charlie(),
TimelineEventType::RoomMember,
Some(charlie().to_string().as_str()),
member_content_join(),
&["CREATE", "IJR", "IPOWER"],
&["IMB"],
),
to_pdu_event::<&EventId>(
"START",
charlie(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
&[],
&[],
),
to_pdu_event::<&EventId>(
"END",
charlie(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
&[],
&[],
),
]
.into_iter()
.map(|ev| (ev.event_id().to_owned(), ev))
.collect()
}
// all graphs start with these input events
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> HashMap<OwnedEventId, Pdu> {
vec![
to_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
&["CREATE", "IMA", "IPOWER"], // auth_events
&["START"], // prev_events
),
to_pdu_event(
"PB",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["END"],
),
to_pdu_event(
"MB",
alice(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
member_content_ban(),
&["CREATE", "IMA", "PB"],
&["PA"],
),
to_pdu_event(
"IME",
ella(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
member_content_join(),
&["CREATE", "IJR", "PA"],
&["MB"],
),
]
.into_iter()
.map(|ev| (ev.event_id().to_owned(), ev))
.collect()
}
+1 -20
View File
@@ -1,4 +1,3 @@
use ruma::OwnedEventId;
use serde_json::Error as JsonError;
use thiserror::Error;
@@ -15,28 +14,10 @@ pub enum Error {
Unsupported(String),
/// The given event was not found.
#[error("Event not found: {0}")]
#[error("Not found error: {0}")]
NotFound(String),
/// A required event this event depended on could not be fetched,
/// either as it was missing, or because it was invalid
#[error("Failed to fetch required {0} event: {1}")]
DependencyFailed(OwnedEventId, String),
/// Invalid fields in the given PDU.
#[error("Invalid PDU: {0}")]
InvalidPdu(String),
/// This event failed an authorization condition.
#[error("Auth check failed: {0}")]
AuthConditionFailed(String),
/// This event contained multiple auth events of the same type and state
/// key.
#[error("Duplicate auth events: {0}")]
DuplicateAuthEvents(String),
/// This event contains unnecessary auth events.
#[error("Unknown or unnecessary auth events present: {0}")]
UnselectedAuthEvents(String),
}
File diff suppressed because it is too large Load Diff
@@ -1,238 +0,0 @@
//! Auth checks relevant to any event's `auth_events`.
//!
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
use std::collections::{HashMap, HashSet};
use ruma::{
EventId, OwnedEventId, RoomId, UserId,
events::{
StateEventType, TimelineEventType,
room::member::{MembershipState, RoomMemberEventContent, ThirdPartyInvite},
},
};
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error, warn};
/// For the given event `kind` what are the relevant auth events that are needed
/// to authenticate this `content`.
///
/// # Errors
///
/// This function will return an error if the supplied `content` is not a JSON
/// object.
pub fn auth_types_for_event(
room_version: &RoomVersion,
event_type: &TimelineEventType,
state_key: Option<&StateKey>,
sender: &UserId,
member_content: Option<RoomMemberEventContent>,
) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
if event_type == &TimelineEventType::RoomCreate {
// Create events never have auth events
return Ok(vec![]);
}
let mut auth_types = if room_version.room_ids_as_hashes {
vec![
StateEventType::RoomMember.with_state_key(sender.as_str()),
StateEventType::RoomPowerLevels.with_state_key(""),
]
} else {
// For room versions that do not use room IDs as hashes, include the
// RoomCreate event as an auth event.
vec![
StateEventType::RoomMember.with_state_key(sender.as_str()),
StateEventType::RoomPowerLevels.with_state_key(""),
StateEventType::RoomCreate.with_state_key(""),
]
};
if event_type == &TimelineEventType::RoomMember {
let member_content =
member_content.expect("member_content must be provided for RoomMember events");
// Include the target's membership (if available)
auth_types.push((
StateEventType::RoomMember,
state_key
.expect("state_key must be provided for RoomMember events")
.to_owned(),
));
if matches!(
member_content.membership,
MembershipState::Join | MembershipState::Invite | MembershipState::Knock
) {
// Include the join rules
auth_types.push(StateEventType::RoomJoinRules.with_state_key(""));
}
if matches!(member_content.membership, MembershipState::Invite) {
// If this is an invite, include the third party invite if it exists
if let Some(ThirdPartyInvite { signed, .. }) = member_content.third_party_invite {
auth_types
.push(StateEventType::RoomThirdPartyInvite.with_state_key(signed.token));
}
}
if matches!(member_content.membership, MembershipState::Join)
&& room_version.restricted_join_rules
{
// If this is a restricted join, include the authorizing user's membership
if let Some(authorizing_user) = member_content.join_authorized_via_users_server {
auth_types
.push(StateEventType::RoomMember.with_state_key(authorizing_user.as_str()));
}
}
}
Ok(auth_types)
}
/// Checks for duplicate auth events in the `auth_events` field of an event.
/// Note: the caller should already have all of the auth events fetched.
///
/// If there are multiple auth events of the same type and state key, this
/// returns an error. Otherwise, it returns a map of (type, state_key) to the
/// corresponding auth event.
pub async fn check_duplicate_auth_events<FE>(
auth_events: &[OwnedEventId],
fetch_event: FE,
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
{
let mut seen: HashMap<(StateEventType, StateKey), Pdu> = HashMap::new();
// Considering all of the event's auth events:
for auth_event_id in auth_events {
if let Ok(Some(auth_event)) = fetch_event(auth_event_id).await {
let event_type = auth_event.kind();
// If this is not a state event, reject it.
let Some(state_key) = &auth_event.state_key() else {
return Err(Error::InvalidPdu(format!(
"Auth event {:?} is not a state event",
auth_event_id
)));
};
let type_key_pair: (StateEventType, StateKey) =
event_type.clone().with_state_key(state_key.clone());
// If there are duplicate entries for a given type and state_key pair, reject.
if seen.contains_key(&type_key_pair) {
return Err(Error::DuplicateAuthEvents(format!(
"({:?},\"{:?}\")",
event_type, state_key
)));
}
seen.insert(type_key_pair, auth_event);
} else {
return Err(Error::NotFound(auth_event_id.as_str().to_owned()));
}
}
Ok(seen)
}
// Checks that the event does not refer to any auth events that it does not need
// to.
pub fn check_unnecessary_auth_events(
auth_events: &HashSet<(StateEventType, StateKey)>,
expected: &Vec<(StateEventType, StateKey)>,
) -> Result<(), Error> {
// If there are entries whose type and state_key don't match those specified by
// the auth events selection algorithm described in the server specification,
// reject.
let remaining = auth_events
.iter()
.filter(|key| !expected.contains(key))
.collect::<HashSet<_>>();
if !remaining.is_empty() {
return Err(Error::UnselectedAuthEvents(format!("{:?}", remaining)));
}
Ok(())
}
// Checks that all provided auth events were not rejected previously.
//
// TODO: this is currently a no-op and always returns Ok(()).
pub fn check_all_auth_events_accepted(
_auth_events: &HashMap<(StateEventType, StateKey), Pdu>,
) -> Result<(), Error> {
Ok(())
}
// Checks that all auth events are from the same room as the event being
// validated.
pub fn check_auth_same_room(auth_events: &Vec<Pdu>, room_id: &RoomId) -> bool {
for auth_event in auth_events {
if let Some(auth_room_id) = &auth_event.room_id() {
if auth_room_id.as_str() != room_id.as_str() {
warn!(
auth_event_id=%auth_event.event_id(),
"Auth event room id {} does not match expected room id {}",
auth_room_id,
room_id
);
return false;
}
} else {
warn!(auth_event_id=%auth_event.event_id(), "Auth event has no room_id");
return false;
}
}
true
}
/// Performs all auth event checks for the given event.
pub async fn check_auth_events<FE>(
event: &Pdu,
room_id: &RoomId,
room_version: &RoomVersion,
fetch_event: &FE,
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
{
// If there are duplicate entries for a given type and state_key pair, reject.
let auth_events_map = check_duplicate_auth_events(&event.auth_events, fetch_event).await?;
let auth_events_set: HashSet<(StateEventType, StateKey)> =
auth_events_map.keys().cloned().collect();
// If there are entries whose type and state_key dont match those specified by
// the auth events selection algorithm described in the server specification,
// reject.
let member_event_content = match event.kind() {
| TimelineEventType::RoomMember =>
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
})?),
| _ => None,
};
let expected_auth_events = auth_types_for_event(
room_version,
event.kind(),
event.state_key.as_ref(),
event.sender(),
member_event_content,
)?;
if let Err(e) = check_unnecessary_auth_events(&auth_events_set, &expected_auth_events) {
return Err(e);
}
// If there are entries which were themselves rejected under the checks
// performed on receipt of a PDU, reject.
if let Err(e) = check_all_auth_events_accepted(&auth_events_map) {
return Err(e);
}
// If any event in auth_events has a room_id which does not match that of the
// event being authorised, reject.
let auth_event_refs: Vec<Pdu> = auth_events_map.values().cloned().collect();
if !check_auth_same_room(&auth_event_refs, room_id) {
return Err(Error::InvalidPdu(
"One or more auth events are from a different room".to_owned(),
));
}
Ok(auth_events_map)
}
@@ -1,113 +0,0 @@
//! Context for event authorisation checks
use ruma::{
Int, OwnedUserId, UserId,
events::{
StateEventType,
room::{create::RoomCreateEventContent, power_levels::RoomPowerLevelsEventContent},
},
};
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error};
pub enum UserPower {
/// Creator indicates this user should be granted a power level above all.
Creator,
/// Standard indicates power levels should be used to determine rank.
Standard,
}
impl PartialEq for UserPower {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
| (UserPower::Creator, UserPower::Creator) => true,
| (UserPower::Standard, UserPower::Standard) => true,
| _ => false,
}
}
}
/// Get the creators of the room.
/// If this room only supports one creator, a vec of one will be returned.
/// If multiple creators are supported, all will be returned, with the
/// m.room.create sender first.
pub async fn calculate_creators<FS>(
room_version: &RoomVersion,
fetch_state: FS,
) -> Result<Vec<OwnedUserId>, Error>
where
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
let create_event = fetch_state(StateEventType::RoomCreate.with_state_key(""))
.await?
.ok_or_else(|| Error::InvalidPdu("Room create event not found".to_owned()))?;
let content = create_event
.get_content::<RoomCreateEventContent>()
.map_err(|e| {
Error::InvalidPdu(format!("Room create event has invalid content: {}", e))
})?;
if room_version.explicitly_privilege_room_creators {
let mut creators = vec![create_event.sender().to_owned()];
if let Some(additional) = content.additional_creators {
for user_id in additional {
if !creators.contains(&user_id) {
creators.push(user_id);
}
}
}
Ok(creators)
} else if room_version.use_room_create_sender {
Ok(vec![create_event.sender().to_owned()])
} else {
// Have to check the event content
#[allow(deprecated)]
if let Some(creator) = content.creator {
Ok(vec![creator])
} else {
Err(Error::InvalidPdu("Room create event missing creator field".to_owned()))
}
}
}
/// Rank fetches the creatorship and power level of the target user
///
/// Returns (UserPower, power_level, Option<RoomPowerLevelsEventContent>)
/// If UserPower::Creator is returned, the power_level and
/// RoomPowerLevelsEventContent will be meaningless and can be ignored.
pub async fn get_rank<FS>(
room_version: &RoomVersion,
fetch_state: &FS,
user_id: &UserId,
) -> Result<(UserPower, Int, Option<RoomPowerLevelsEventContent>), Error>
where
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
let creators = calculate_creators(room_version, &fetch_state).await?;
if creators.contains(&user_id.to_owned()) && room_version.explicitly_privilege_room_creators {
return Ok((UserPower::Creator, Int::MAX, None));
}
let power_levels = fetch_state(StateEventType::RoomPowerLevels.with_state_key("")).await?;
if let Some(power_levels) = power_levels {
let power_levels = power_levels
.get_content::<RoomPowerLevelsEventContent>()
.map_err(|e| {
Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
})?;
Ok((
UserPower::Standard,
*power_levels
.users
.get(user_id)
.unwrap_or(&power_levels.users_default),
Some(power_levels),
))
} else {
// No power levels event, use defaults
if creators[0] == user_id {
return Ok((UserPower::Creator, Int::MAX, None));
}
Ok((UserPower::Standard, Int::from(0), None))
}
}
@@ -1,97 +0,0 @@
//! Auth checks relevant to the `m.room.create` event specifically.
//!
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
use ruma::{OwnedUserId, RoomVersionId, events::room::create::RoomCreateEventContent};
use serde::Deserialize;
use serde_json::from_str;
use crate::{Event, Pdu, RoomVersion, state_res::Error, trace};
// A raw representation of the create event content, for initial parsing.
// This allows us to extract fields without fully validating the event first.
#[derive(Deserialize)]
struct RawCreateContent {
creator: Option<String>,
room_version: Option<String>,
additional_creators: Option<Vec<String>>,
}
// Check whether an `m.room.create` event is valid.
// This ensures that:
//
// 1. The event has no `prev_events`
// 2. If the version disallows it, the event has no `room_id` present.
// 3. If the room version is present and recognised, otherwise assume invalid.
// 4. If the room version supports it, `additional_creators` is populated with
// valid user IDs.
// 5. If the room version supports it, `creator` is populated AND is a valid
// user ID.
// 6. Otherwise, this event is valid.
//
// The fully deserialized `RoomCreateEventContent` is returned for further calls
// to other checks.
pub fn check_room_create(event: &Pdu) -> Result<RoomCreateEventContent, Error> {
// Check 1: The event has no `prev_events`
if !event.prev_events.is_empty() {
return Err(Error::InvalidPdu("m.room.create event has prev_events".to_owned()));
}
let create_content = from_str::<RawCreateContent>(event.content().get())?;
// Note: Here we attempt to both load the raw room version string and validate
// it, and then cast it to the room features. If either step fails, we return
// an unsupported error. If the room version is missing, it defaults to "1",
// which we also do not support.
//
// This performs check 3, which then allows us to perform check 2.
let room_version = if let Some(raw_room_version) = create_content.room_version {
trace!("Parsing and interpreting room version: {}", raw_room_version);
let room_version_id = RoomVersionId::try_from(raw_room_version.as_str())
.map_err(|_| Error::Unsupported(raw_room_version))?;
RoomVersion::new(&room_version_id)
.map_err(|_| Error::Unsupported(room_version_id.as_str().to_owned()))?
} else {
return Err(Error::Unsupported("1".to_owned()));
};
// Check 2: If the version disallows it, the event has no `room_id` present.
if room_version.room_ids_as_hashes && event.room_id.is_some() {
return Err(Error::InvalidPdu(
"m.room.create event has room_id but room version disallows it".to_owned(),
));
}
// Check 4: If the room version supports it, `additional_creators` is populated
// with valid user IDs.
if room_version.explicitly_privilege_room_creators {
if let Some(additional_creators) = create_content.additional_creators {
for creator in additional_creators {
trace!("Validating additional creator user ID: {}", creator);
if OwnedUserId::parse(&creator).is_err() {
return Err(Error::InvalidPdu(format!(
"Invalid user ID in additional_creators: {creator}"
)));
}
}
}
}
// Check 5: If the room version supports it, `creator` is populated AND is a
// valid user ID.
if !room_version.use_room_create_sender {
if let Some(creator) = create_content.creator {
trace!("Validating creator user ID: {}", creator);
if OwnedUserId::parse(&creator).is_err() {
return Err(Error::InvalidPdu(format!("Invalid user ID in creator: {creator}")));
}
} else {
return Err(Error::InvalidPdu(
"m.room.create event missing creator field".to_owned(),
));
}
}
// Deserialise into the full create event for future checks.
Ok(from_str::<RoomCreateEventContent>(event.content().get())?)
}
@@ -1,650 +0,0 @@
use ruma::{
EventId, OwnedUserId, RoomVersionId,
events::{
StateEventType, TimelineEventType,
room::{create::RoomCreateEventContent, member::MembershipState},
},
int,
serde::Raw,
};
use serde::{Deserialize, de::IgnoredAny};
use serde_json::from_str as from_json_str;
use crate::{
Event, EventTypeExt, Pdu, RoomVersion, debug, error,
matrix::StateKey,
state_res::{
error::Error,
event_auth::{
auth_events::check_auth_events,
context::{UserPower, calculate_creators, get_rank},
create_event::check_room_create,
member_event::check_member_event,
power_levels::check_power_levels,
},
},
trace, warn,
};
// FIXME: field extracting could be bundled for `content`
#[derive(Deserialize)]
struct GetMembership {
membership: MembershipState,
}
#[derive(Deserialize, Debug)]
struct RoomMemberContentFields {
membership: Option<Raw<MembershipState>>,
join_authorised_via_users_server: Option<Raw<OwnedUserId>>,
}
#[derive(Deserialize)]
struct RoomCreateContentFields {
room_version: Option<Raw<RoomVersionId>>,
creator: Option<Raw<IgnoredAny>>,
additional_creators: Option<Vec<Raw<OwnedUserId>>>,
#[serde(rename = "m.federate", default = "ruma::serde::default_true")]
federate: bool,
}
/// Authenticate the incoming `event`.
///
/// The steps of authentication are:
///
/// * check that the event is being authenticated for the correct room
/// * then there are checks for specific event types
///
/// The `fetch_state` closure should gather state from a state snapshot. We need
/// to know if the event passes auth against some state not a recursive
/// collection of auth_events fields.
#[tracing::instrument(
skip_all,
fields(
event_id = incoming_event.event_id().as_str(),
event_type = ?incoming_event.event_type().to_string()
)
)]
#[allow(clippy::suspicious_operation_groupings)]
pub async fn auth_check<FE, FS>(
room_version: &RoomVersion,
incoming_event: &Pdu,
fetch_event: &FE,
fetch_state: &FS,
create_event: Option<&Pdu>,
) -> Result<bool, Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
debug!("auth_check beginning");
let sender = incoming_event.sender();
// Since v1, If type is m.room.create:
if *incoming_event.event_type() == TimelineEventType::RoomCreate {
debug!("start m.room.create check");
if let Err(e) = check_room_create(incoming_event) {
warn!("m.room.create event has been rejected: {}", e);
return Ok(false);
}
debug!("m.room.create event was allowed");
return Ok(true);
}
let Some(create_event) = create_event else {
error!("no create event provided for auth check");
return Err(Error::InvalidPdu("missing create event".to_owned()));
};
// TODO: we need to know if events have previously been rejected or soft failed
// For now, we'll just assume the create_event is valid.
let create_content = from_json_str::<RoomCreateEventContent>(create_event.content().get())
.expect("provided create event must be valid");
// Since v12, If the events room_id is not an event ID for an accepted (not
// rejected) m.room.create event, with the sigil ! instead of $, reject.
if room_version.room_ids_as_hashes {
let calculated_room_id = create_event.event_id().as_str().replace('$', "!");
if let Some(claimed_room_id) = create_event.room_id() {
if claimed_room_id.as_str() != calculated_room_id {
warn!(
expected = %calculated_room_id,
received = %claimed_room_id,
"event's room ID does not match the hash of the m.room.create event ID"
);
return Ok(false);
}
} else {
warn!("event is missing a room ID");
return Ok(false);
}
}
let room_id = incoming_event.room_id().expect("event must have a room ID");
let auth_map =
match check_auth_events(incoming_event, room_id, &room_version, fetch_event).await {
| Ok(map) => map,
| Err(e) => {
warn!("event's auth events are invalid: {}", e);
return Ok(false);
},
};
// Considering the event's auth_events
// Since v1, If the content of the m.room.create event in the room state has the
// property m.federate set to false, and the sender domain of the event does
// not match the sender domain of the create event, reject.
if !create_content.federate {
if create_event.sender().server_name() != incoming_event.sender().server_name() {
warn!(
sender = %incoming_event.sender(),
create_sender = %create_event.sender(),
"room is not federated and event's sender domain does not match create event's sender domain"
);
return Ok(false);
}
}
// From v1 to v5, If type is m.room.aliases
if room_version.special_case_aliases_auth
&& *incoming_event.event_type() == TimelineEventType::RoomAliases
{
if let Some(state_key) = incoming_event.state_key() {
// If sender's domain doesn't matches state_key, reject
if state_key != sender.server_name().as_str() {
warn!("state_key does not match sender");
return Ok(false);
}
// Otherwise, allow
return Ok(true);
}
// If event has no state_key, reject.
warn!("m.room.alias event has no state key");
return Ok(false);
}
// From v1, If type is m.room.member
if *incoming_event.event_type() == TimelineEventType::RoomMember {
if let Err(e) =
check_member_event(&room_version, incoming_event, fetch_event, fetch_state).await
{
warn!("m.room.member event has been rejected: {}", e);
return Ok(false);
}
}
// From v1, If the sender's current membership state is not join, reject
let sender_member_event =
match auth_map.get(&StateEventType::RoomMember.with_state_key(sender.as_str())) {
| Some(ev) => ev,
| None => {
warn!(
%sender,
"sender is not joined - no membership event found for sender in auth events"
);
return Ok(false);
},
};
let sender_membership_event_content: RoomMemberContentFields =
from_json_str(sender_member_event.content().get())?;
let Some(membership_state) = sender_membership_event_content.membership else {
warn!(
?sender_membership_event_content,
"Sender membership event content missing membership field"
);
return Err(Error::InvalidPdu("Missing membership field".to_owned()));
};
let membership_state = membership_state.deserialize()?;
if membership_state != MembershipState::Join {
warn!(
%sender,
?membership_state,
"sender cannot send events without being joined to the room"
);
return Ok(false);
}
// From v1, If type is m.room.third_party_invite
let (rank, sender_pl, pl_evt) = get_rank(&room_version, fetch_state, sender).await?;
// Allow if and only if sender's current power level is greater than
// or equal to the invite level
if *incoming_event.event_type() == TimelineEventType::RoomThirdPartyInvite {
if rank == UserPower::Creator {
trace!("sender is room creator, allowing m.room.third_party_invite");
return Ok(true);
}
let invite_level = match &pl_evt {
| Some(power_levels) => power_levels.invite,
| None => int!(0),
};
if sender_pl < invite_level {
warn!(
%sender,
has=%sender_pl,
required=%invite_level,
"sender cannot send invites in this room"
);
return Ok(false);
}
debug!("m.room.third_party_invite event was allowed");
return Ok(true);
}
// Since v1, if the event types required power level is greater than the
// senders power level, reject.
let required_level = match &pl_evt {
| Some(power_levels) => power_levels
.events
.get(incoming_event.kind())
.unwrap_or_else(|| {
if incoming_event.state_key.is_some() {
&power_levels.state_default
} else {
&power_levels.events_default
}
}),
| None => &int!(0),
};
if rank != UserPower::Creator && sender_pl < *required_level {
warn!(
%sender,
has=%sender_pl,
required=%required_level,
"sender does not have enough power level to send this event"
);
return Ok(false);
}
// Since v1, If the event has a state_key that starts with an @ and does not
// match the sender, reject.
if let Some(state_key) = incoming_event.state_key() {
if state_key.starts_with('@') && state_key != sender.as_str() {
warn!(
%sender,
%state_key,
"event's state key starts with @ and does not match sender"
);
return Ok(false);
}
}
// Since v1, If type is m.room.power_levels
if *incoming_event.event_type() == TimelineEventType::RoomPowerLevels {
let creators = calculate_creators(&room_version, fetch_state).await?;
if let Err(e) =
check_power_levels(&room_version, incoming_event, pl_evt.as_ref(), creators).await
{
warn!(
%sender,
"m.room.power_levels event has been rejected: {}", e
);
return Ok(false);
}
}
// From v1 to v2: If type is m.room.redaction:
// If the senders power level is greater than or equal to the redact level,
// allow.
// If the domain of the event_id of the event being redacted is the same as the
// domain of the event_id of the m.room.redaction, allow.
// Otherwise, reject.
if room_version.extra_redaction_checks {
// We'll panic here, since while we don't theoretically support the room
// versions that require this, we don't want to incorrectly permit an event
// that should be rejected in this theoretically impossible scenario.
unreachable!(
"continuwuity does not support room versions that require extra redaction checks"
);
}
debug!("allowing event passed all checks");
Ok(true)
}
#[cfg(test)]
mod tests {
use ruma::events::{
StateEventType, TimelineEventType,
room::{
join_rules::{
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
},
member::{MembershipState, RoomMemberEventContent},
},
};
use serde_json::value::to_raw_value as to_raw_json_value;
use crate::{
matrix::{Event, EventTypeExt, Pdu as PduEvent},
state_res::{
RoomVersion, StateMap,
event_auth::{
iterative_auth_checks::valid_membership_change, valid_membership_change,
},
test_utils::{
INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, alice, charlie, ella, event_id,
member_content_ban, member_content_join, room_id, to_pdu_event,
},
},
};
#[test]
fn test_ban_pass() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
alice(),
TimelineEventType::RoomMember,
Some(charlie().as_str()),
member_content_ban(),
&[],
&["IMC"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = charlie();
let sender = alice();
assert!(
valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_join_non_creator() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS_CREATE_ROOM();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
charlie(),
TimelineEventType::RoomMember,
Some(charlie().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = charlie();
let sender = charlie();
assert!(
!valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_join_creator() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS_CREATE_ROOM();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
alice(),
TimelineEventType::RoomMember,
Some(alice().as_str()),
member_content_join(),
&["CREATE"],
&["CREATE"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = alice();
let sender = alice();
assert!(
valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_ban_fail() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let events = INITIAL_EVENTS();
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
charlie(),
TimelineEventType::RoomMember,
Some(alice().as_str()),
member_content_ban(),
&[],
&["IMC"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = alice();
let sender = charlie();
assert!(
!valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_restricted_join_rule() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let mut events = INITIAL_EVENTS();
*events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Restricted(
Restricted::new(vec![AllowRule::RoomMembership(RoomMembership::new(
room_id().to_owned(),
))]),
)))
.unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
);
let mut member = RoomMemberEventContent::new(MembershipState::Join);
member.join_authorized_via_users_server = Some(alice().to_owned());
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
ella(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap(),
&["CREATE", "IJR", "IPOWER", "new"],
&["new"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = ella();
let sender = ella();
assert!(
valid_membership_change(
&RoomVersion::V9,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(alice()),
&MembershipState::Join,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
assert!(
!valid_membership_change(
&RoomVersion::V9,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(ella()),
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
#[test]
fn test_knock() {
let _ = tracing::subscriber::set_default(
tracing_subscriber::fmt().with_test_writer().finish(),
);
let mut events = INITIAL_EVENTS();
*events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
"IJR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Knock)).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["IPOWER"],
);
let auth_events = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
let requester = to_pdu_event(
"HELLO",
ella(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Knock)).unwrap(),
&[],
&["IMC"],
);
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
let target_user = ella();
let sender = ella();
assert!(
valid_membership_change(
&RoomVersion::V7,
target_user,
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap()
);
}
}
@@ -1,422 +0,0 @@
//! Auth checks relevant to the `m.room.member` event specifically.
//!
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
use ruma::{
EventId, OwnedUserId, UserId,
events::{
StateEventType,
room::{
join_rules::{JoinRule, RoomJoinRulesEventContent},
third_party_invite::{PublicKey, RoomThirdPartyInviteEventContent},
},
},
serde::Base64,
signatures::{PublicKeyMap, PublicKeySet, verify_json},
};
use crate::{
Event, EventTypeExt, Pdu, RoomVersion,
matrix::StateKey,
state_res::{
Error,
event_auth::context::{UserPower, get_rank},
},
utils::to_canonical_object,
};
#[derive(serde::Deserialize, Default)]
struct PartialMembershipObject {
membership: Option<String>,
join_authorized_via_users_server: Option<OwnedUserId>,
third_party_invite: Option<serde_json::Value>,
}
/// Fetches the membership *content* of the target.
/// If there is not one, an empty leave membership is returned.
async fn fetch_membership<FS>(
fetch_state: &FS,
target: &UserId,
) -> Result<PartialMembershipObject, Error>
where
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
fetch_state(StateEventType::RoomMember.with_state_key(target.as_str()))
.await
.map(|pdu| {
if let Some(ev) = pdu {
ev.get_content::<PartialMembershipObject>().map_err(|e| {
Error::InvalidPdu(format!("m.room.member event has invalid content: {}", e))
})
} else {
Ok(PartialMembershipObject {
membership: Some("leave".to_owned()),
..Default::default()
})
}
})?
}
async fn check_join_event<FE, FS>(
room_version: &RoomVersion,
event: &Pdu,
membership: &PartialMembershipObject,
target: &UserId,
fetch_event: &FE,
fetch_state: &FS,
) -> Result<(), Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
// 3.1: If the only previous event is an m.room.create and the state_key is the
// sender of the m.room.create, allow.
if event.prev_events.len() == 1 {
let only_prev = fetch_event(&event.prev_events[0]).await?;
if let Some(prev_event) = only_prev {
let k = prev_event.event_type().with_state_key("");
if k.0 == StateEventType::RoomCreate && k.1.as_str() == event.sender().as_str() {
return Ok(());
}
} else {
return Err(Error::DependencyFailed(
event.prev_events[0].to_owned(),
"Previous event not found when checking join event".to_owned(),
));
}
}
// 3.2: If the sender does not match state_key, reject.
if event.sender() != target {
return Err(Error::AuthConditionFailed(
"m.room.member join event sender does not match state_key".to_owned(),
));
}
let prev_membership = if let Some(ev) =
fetch_state(StateEventType::RoomMember.with_state_key(target.as_str())).await?
{
Some(ev.get_content::<PartialMembershipObject>().map_err(|e| {
Error::InvalidPdu(format!("Previous m.room.member event has invalid content: {}", e))
})?)
} else {
None
};
let join_rule_content =
if let Some(jr) = fetch_state(StateEventType::RoomJoinRules.with_state_key("")).await? {
jr.get_content::<RoomJoinRulesEventContent>().map_err(|e| {
Error::InvalidPdu(format!("m.room.join_rules event has invalid content: {}", e))
})?
} else {
// Default to invite if no join rules event is present.
RoomJoinRulesEventContent { join_rule: JoinRule::Private }
};
// 3.3: If the sender is banned, reject.
let prev_member = if let Some(prev_content) = &prev_membership {
if let Some(membership) = &prev_content.membership {
if membership == "ban" {
return Err(Error::AuthConditionFailed(
"m.room.member join event sender is banned".to_owned(),
));
}
membership
} else {
"leave"
}
} else {
"leave"
};
// 3.4: If the join_rule is invite or knock then allow if membership
// state is invite or join.
// 3.5: If the join_rule is restricted or knock_restricted:
// 3.5.1: If membership state is join or invite, allow.
match join_rule_content.join_rule {
| JoinRule::Invite | JoinRule::Knock => {
if prev_member == "invite" || prev_member == "join" {
return Ok(());
}
Err(Error::AuthConditionFailed(
"m.room.member join event not invited under invite/knock join rule".to_owned(),
))
},
| JoinRule::Restricted(_) | JoinRule::KnockRestricted(_) => {
// 3.5.2: If the join_authorised_via_users_server key in content is not a user
// with sufficient permission to invite other users or is not a joined
// member of the room, reject.
if prev_member == "invite" || prev_member == "join" {
return Ok(());
}
let join_authed_by = membership.join_authorized_via_users_server.as_ref();
if let Some(user_id) = join_authed_by {
let rank = get_rank(&room_version, fetch_state, user_id).await?;
if rank.0 == UserPower::Standard {
// This user is not a creator, check that they have
// sufficient power level
if rank.1 < rank.2.unwrap().invite {
return Err(Error::InvalidPdu(
"m.room.member join event join_authorised_via_users_server does not \
have sufficient power level to invite"
.to_owned(),
));
}
}
// Check that the user is a joined member of the room
if let Some(state_event) =
fetch_state(StateEventType::RoomMember.with_state_key(user_id.as_str()))
.await?
{
let state_content = state_event
.get_content::<PartialMembershipObject>()
.map_err(|e| {
Error::InvalidPdu(format!(
"m.room.member event has invalid content: {}",
e
))
})?;
if let Some(state_membership) = &state_content.membership {
if state_membership == "join" {
return Ok(());
}
}
}
} else {
return Err(Error::AuthConditionFailed(
"m.room.member join event missing join_authorised_via_users_server"
.to_owned(),
));
}
// 3.5.3: Otherwise, allow
return Ok(());
},
| JoinRule::Public => return Ok(()),
| _ => Err(Error::AuthConditionFailed(format!(
"unknown join rule: {:?}",
join_rule_content.join_rule
)))?,
}
}
/// Checks a third-party invite is valid.
async fn check_third_party_invite(
target_current_membership: PartialMembershipObject,
raw_third_party_invite: &serde_json::Value,
target: &UserId,
event: &Pdu,
fetch_state: impl AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
) -> Result<(), Error> {
// 4.1.1: If target user is banned, reject.
if target_current_membership
.membership
.is_some_and(|m| m == "ban")
{
return Err(Error::AuthConditionFailed("invite target is banned".to_owned()));
}
// 4.1.2: If content.third_party_invite does not have a signed property, reject.
let signed = raw_third_party_invite.get("signed").ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite missing signed property".to_owned(),
)
})?;
// 4.2.3: If signed does not have mxid and token properties, reject.
let mxid = signed.get("mxid").and_then(|v| v.as_str()).ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing/invalid mxid property".to_owned(),
)
})?;
let token = signed
.get("token")
.and_then(|v| v.as_str())
.ok_or_else(|| {
Error::AuthConditionFailed(
"invite event third_party_invite signed missing token property".to_owned(),
)
})?;
// 4.2.4: If mxid does not match state_key, reject.
if mxid != target.as_str() {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite signed mxid does not match state_key".to_owned(),
));
}
// 4.2.5: If there is no m.room.third_party_invite event in the room
// state matching the token, reject.
let Some(third_party_invite_event) =
fetch_state(StateEventType::RoomThirdPartyInvite.with_state_key(token)).await?
else {
return Err(Error::AuthConditionFailed(
"invite event third_party_invite token has no matching m.room.third_party_invite"
.to_owned(),
));
};
// 4.2.6: If sender does not match sender of the m.room.third_party_invite,
// reject.
if third_party_invite_event.sender() != event.sender() {
return Err(Error::AuthConditionFailed(
"invite event sender does not match m.room.third_party_invite sender".to_owned(),
));
}
// 4.2.7: If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow. The public keys are in
// content of m.room.third_party_invite as:
// 1. A single public key in the public_key property.
// 2. A list of public keys in the public_keys property.
let tpi_content = third_party_invite_event
.get_content::<RoomThirdPartyInviteEventContent>()
.or_else(|_| {
Err(Error::InvalidPdu(
"m.room.third_party_invite event has invalid content".to_owned(),
))
})?;
let mut public_keys = tpi_content.public_keys.unwrap_or_default();
public_keys.push(PublicKey {
public_key: tpi_content.public_key,
key_validity_url: None,
});
let signatures = signed
.get("signatures")
.and_then(|v| v.as_object())
.ok_or_else(|| {
Error::InvalidPdu(
"invite event third_party_invite signed missing/invalid signatures".to_owned(),
)
})?;
let mut public_key_map = PublicKeyMap::new();
for (server_name, sig_map) in signatures {
let mut pk_set = PublicKeySet::new();
if let Some(sig_map) = sig_map.as_object() {
for (key_id, sig) in sig_map {
let sig_b64 = Base64::parse(sig.as_str().ok_or(Error::InvalidPdu(
"invite event third_party_invite signature is not a string".to_owned(),
))?)
.map_err(|_| {
Error::InvalidPdu(
"invite event third_party_invite signature is not valid Base64"
.to_owned(),
)
})?;
pk_set.insert(key_id.clone(), sig_b64);
}
}
public_key_map.insert(server_name.clone(), pk_set);
}
verify_json(
&public_key_map,
to_canonical_object(signed).expect("signed was already validated"),
)
.map_err(|e| {
Error::AuthConditionFailed(format!(
"invite event third_party_invite signature verification failed: {e}"
))
})?;
// If there was no error, there was a valid signature, so allow.
Ok(())
}
async fn check_invite_event<FS>(
room_version: &RoomVersion,
event: &Pdu,
membership: &PartialMembershipObject,
target: &UserId,
fetch_state: &FS,
) -> Result<(), Error>
where
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
let target_current_membership = fetch_membership(fetch_state, target).await?;
// 4.1: If content has a third_party_invite property:
if let Some(raw_third_party_invite) = &membership.third_party_invite {
return check_third_party_invite(
target_current_membership,
raw_third_party_invite,
target,
event,
fetch_state,
)
.await;
}
// 4.2: If the senders current membership state is not join, reject.
let sender_membership = fetch_membership(fetch_state, event.sender()).await?;
if sender_membership.membership.is_none_or(|m| m != "join") {
return Err(Error::AuthConditionFailed("invite sender is not joined".to_owned()));
}
// 4.3: If target users current membership state is join or ban, reject.
if target_current_membership
.membership
.is_some_and(|m| m == "join" || m == "ban")
{
return Err(Error::AuthConditionFailed(
"invite target is already joined or banned".to_owned(),
));
}
// 4.4: If the senders power level is greater than or equal to the invite
// level, allow.
let (rank, pl, pl_evt) = get_rank(&room_version, fetch_state, event.sender()).await?;
if rank == UserPower::Creator || pl >= pl_evt.unwrap_or_default().invite {
return Ok(());
}
// 4.5: Otherwise, reject.
Err(Error::AuthConditionFailed(
"invite sender does not have sufficient power level to invite".to_owned(),
))
}
pub async fn check_member_event<FE, FS>(
room_version: &RoomVersion,
event: &Pdu,
fetch_event: FE,
fetch_state: FS,
) -> Result<(), Error>
where
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
{
// 1. If there is no state_key property, or no membership property in content,
// reject.
if event.state_key.is_none() {
return Err(Error::InvalidPdu("m.room.member event missing state_key".to_owned()));
}
let target = UserId::parse(event.state_key().unwrap())
.map_err(|_| Error::InvalidPdu("m.room.member event has invalid state_key".to_owned()))?
.to_owned();
let content = event
.get_content::<PartialMembershipObject>()
.map_err(|e| {
Error::InvalidPdu(format!("m.room.member event has invalid content: {}", e))
})?;
if content.membership.is_none() {
return Err(Error::InvalidPdu(
"m.room.member event missing membership in content".to_owned(),
));
}
// 2: If content has a join_authorised_via_users_server key
//
// 2.1: If the event is not validly signed by the homeserver of the user ID
// denoted by the key, reject.
if let Some(_join_auth) = &content.join_authorized_via_users_server {
// We need to check the signature here, but don't have the means to do so yet.
todo!("Implement join_authorised_via_users_server check");
}
match content.membership.as_deref().unwrap() {
| "join" =>
check_join_event(room_version, event, &content, &target, &fetch_event, &fetch_state)
.await?,
| "invite" =>
check_invite_event(room_version, event, &content, &target, &fetch_state).await?,
| _ => {
todo!()
},
};
Ok(())
}
@@ -1,6 +0,0 @@
pub mod auth_events;
mod context;
pub mod create_event;
pub mod iterative_auth_checks;
pub mod member_event;
mod power_levels;
@@ -1,157 +0,0 @@
use ruma::{OwnedUserId, events::room::power_levels::RoomPowerLevelsEventContent};
use crate::{
Event, Pdu, RoomVersion,
state_res::{Error, event_auth::context::UserPower},
};
/// Verifies that a m.room.power_levels event is well-formed according to the
/// Matrix specification.
///
/// Creators must contain the m.room.create sender and any additional creators.
pub async fn check_power_levels(
room_version: &RoomVersion,
event: &Pdu,
current_power_levels: Option<&RoomPowerLevelsEventContent>,
creators: Vec<OwnedUserId>,
) -> Result<(), Error> {
let content = event
.get_content::<RoomPowerLevelsEventContent>()
.map_err(|e| {
Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
})?;
// If any of the properties users_default, events_default, state_default, ban,
// redact, kick, or invite in content are present and not an integer, reject.
//
// If either of the properties events or notifications in content are present
// and not an object with values that are integers, reject.
//
// NOTE: Deserialisation fails if this is not the case, so we don't need to
// check these here.
// If the users property in content is not an object with keys that are valid
// user IDs with values that are integers (or a string that is an integer),
// reject.
while let Some(user_id) = content.users.keys().next() {
// NOTE: Deserialisation fails if the power level is not an integer, so we don't
// need to check that here.
if let Err(e) = user_id.validate_historical() {
return Err(Error::InvalidPdu(format!(
"m.room.power_levels event has invalid user ID in users map: {}",
e
)));
}
// Since v12, If the users property in content contains the sender of the
// m.room.create event or any of the additional_creators array (if present)
// from the content of the m.room.create event, reject.
if room_version.explicitly_privilege_room_creators && creators.contains(user_id) {
return Err(Error::InvalidPdu(
"m.room.power_levels event users map contains a room creator".to_string(),
));
}
}
// If there is no previous m.room.power_levels event in the room, allow.
if current_power_levels.is_none() {
return Ok(());
}
let current_power_levels = current_power_levels.unwrap();
// For the properties users_default, events_default, state_default, ban, redact,
// kick, invite check if they were added, changed or removed. For each found
// alteration:
// If the current value is higher than the senders current power level, reject.
// If the new value is higher than the senders current power level, reject.
let sender = event.sender();
let rank = if room_version.explicitly_privilege_room_creators {
if creators.contains(&sender.to_owned()) {
UserPower::Creator
} else {
UserPower::Standard
}
} else {
UserPower::Standard
};
let sender_pl = current_power_levels
.users
.get(sender)
.unwrap_or(&current_power_levels.users_default);
if rank != UserPower::Creator {
let checks = [
("users_default", current_power_levels.users_default, content.users_default),
("events_default", current_power_levels.events_default, content.events_default),
("state_default", current_power_levels.state_default, content.state_default),
("ban", current_power_levels.ban, content.ban),
("redact", current_power_levels.redact, content.redact),
("kick", current_power_levels.kick, content.kick),
("invite", current_power_levels.invite, content.invite),
];
for (name, old_value, new_value) in checks.iter() {
if old_value != new_value {
if *old_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot change level for {}",
name
)));
}
if *new_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot raise level for {} to {}",
name, new_value
)));
}
}
}
// For each entry being changed in, or removed from, the events
// property:
// If the current value is greater than the senders current power level,
// reject.
for (event_type, new_value) in content.events.iter() {
let old_value = current_power_levels.events.get(event_type);
if old_value != Some(new_value) {
let old_pl = old_value.unwrap_or(&current_power_levels.events_default);
if *old_pl > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot change event level for {}",
event_type
)));
}
if *new_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot raise event level for {} to {}",
event_type, new_value
)));
}
}
}
// For each entry being changed in, or removed from, the events or
// notifications properties:
// If the current value is greater than the senders current power
// level, reject.
// If the new value is greater than the senders current power level,
// reject.
// TODO after making ruwuma's notifications value a BTreeMap
// For each entry being added to, or changed in, the users property:
// If the new value is greater than the senders current power level, reject.
for (user_id, new_value) in content.users.iter() {
let old_value = current_power_levels.users.get(user_id);
if old_value != Some(new_value) {
if *new_value > *sender_pl {
return Err(Error::AuthConditionFailed(format!(
"sender cannot raise user level for {} to {}",
user_id, new_value
)));
}
}
}
}
Ok(())
}
+149 -135
View File
@@ -8,6 +8,9 @@ mod room_version;
#[cfg(test)]
mod test_utils;
#[cfg(test)]
mod benches;
use std::{
borrow::Borrow,
cmp::{Ordering, Reverse},
@@ -15,31 +18,30 @@ use std::{
hash::{BuildHasher, Hash},
};
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use itertools::Itertools;
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future};
use ruma::{
EventId, Int, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
events::{
room::member::{MembershipState, RoomMemberEventContent}, StateEventType,
TimelineEventType,
}, int, EventId, Int, MilliSecondsSinceUnixEpoch,
OwnedEventId,
RoomVersionId,
StateEventType, TimelineEventType,
room::member::{MembershipState, RoomMemberEventContent},
},
int,
};
use serde_json::from_str as from_json_str;
pub(crate) use self::error::Error;
use self::power_levels::PowerLevelsContentFields;
pub use self::{event_auth::iterative_auth_checks::auth_check, room_version::RoomVersion};
use crate::utils::TryFutureExtExt;
pub use self::{
event_auth::{auth_check, auth_types_for_event},
room_version::RoomVersion,
};
use crate::{
debug, err, error as log_error, matrix::{Event, StateKey},
state_res::{
event_auth::auth_events::auth_types_for_event, room_version::StateResolutionVersion,
},
debug, debug_error, err,
matrix::{Event, StateKey},
state_res::room_version::StateResolutionVersion,
trace,
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, WidebandExt},
warn,
Pdu,
};
/// A mapping of event type and state_key to some value `T`, usually an
@@ -73,20 +75,23 @@ type Result<T, E = Error> = crate::Result<T, E>;
/// event is part of the same room.
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
//#[tracing::instrument(level event_fetch))]
pub async fn resolve<'a, Sets, SetIter, Hasher, FE, FR, Exists>(
pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
room_version: &RoomVersionId,
state_sets: Sets,
auth_chain_sets: &'a [HashSet<OwnedEventId, Hasher>],
event_fetch: &FE,
event_fetch: &Fetch,
event_exists: &Exists,
) -> Result<StateMap<OwnedEventId>>
where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
Exists: AsyncFn(OwnedEventId) -> bool + Sync,
Fetch: Fn(OwnedEventId) -> FetchFut + Sync,
FetchFut: Future<Output = Option<Pdu>> + Send,
Exists: Fn(OwnedEventId) -> ExistsFut + Sync,
ExistsFut: Future<Output = bool> + Send,
Sets: IntoIterator<IntoIter = SetIter> + Send,
SetIter: Iterator<Item = &'a StateMap<OwnedEventId>> + Clone + Send,
Hasher: BuildHasher + Send + Sync,
Pdu: Event + Clone + Send + Sync,
for<'b> &'b Pdu: Event + Send,
{
use RoomVersionId::*;
let stateres_version = match room_version {
@@ -164,7 +169,7 @@ where
// Sequentially auth check each control event.
let resolved_control = iterative_auth_check(
&room_version,
sorted_control_levels.iter().stream().map(ToOwned::to_owned),
sorted_control_levels.iter().stream().map(AsRef::as_ref),
initial_state,
&event_fetch,
)
@@ -204,7 +209,7 @@ where
let mut resolved_state = iterative_auth_check(
&room_version,
sorted_left_events.iter().stream(),
sorted_left_events.iter().stream().map(AsRef::as_ref),
resolved_control, // The control events are added to the final resolved state
&event_fetch,
)
@@ -268,12 +273,14 @@ where
}
/// Calculate the conflicted subgraph
async fn calculate_conflicted_subgraph<FE>(
async fn calculate_conflicted_subgraph<F, Fut, E>(
conflicted: &StateMap<Vec<OwnedEventId>>,
fetch_event: &FE,
fetch_event: &F,
) -> Option<HashSet<OwnedEventId>>
where
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
{
let conflicted_events: HashSet<_> = conflicted.values().flatten().cloned().collect();
let mut subgraph: HashSet<OwnedEventId> = HashSet::new();
@@ -305,17 +312,7 @@ where
continue;
}
trace!(event_id = event_id.as_str(), "fetching event for its auth events");
let evt = fetch_event(event_id.clone())
.await
.inspect_err(|e| {
log_error!(
"error fetching event {} for conflicted state subgraph: {}",
event_id,
e
)
})
.ok()
.flatten();
let evt = fetch_event(event_id.clone()).await;
if evt.is_none() {
err!("could not fetch event {} to calculate conflicted subgraph", event_id);
path.pop();
@@ -362,14 +359,15 @@ where
/// The power level is negative because a higher power level is equated to an
/// earlier (further back in time) origin server timestamp.
#[tracing::instrument(level = "debug", skip_all)]
async fn reverse_topological_power_sort<FE, FR>(
async fn reverse_topological_power_sort<E, F, Fut>(
events_to_sort: Vec<OwnedEventId>,
auth_diff: &HashSet<OwnedEventId>,
fetch_event: &FE,
fetch_event: &F,
) -> Result<Vec<OwnedEventId>>
where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
{
debug!("reverse topological sort of power events");
@@ -406,8 +404,8 @@ where
.get(&event_id)
.ok_or_else(|| Error::NotFound(String::new()))?;
let ev = fetch_event(&event_id)
.await?
let ev = fetch_event(event_id)
.await
.ok_or_else(|| Error::NotFound(String::new()))?;
Ok((pl, ev.origin_server_ts()))
@@ -546,14 +544,18 @@ where
/// Do NOT use this any where but topological sort, we find the power level for
/// the eventId at the eventId's generation (we walk backwards to `EventId`s
/// most recent previous power level event).
async fn get_power_level_for_sender<FE, FR>(event_id: &EventId, fetch_event: &FE) -> Result<Int>
async fn get_power_level_for_sender<E, F, Fut>(
event_id: &EventId,
fetch_event: &F,
) -> serde_json::Result<Int>
where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
{
debug!("fetch event ({event_id}) senders power level");
let event = fetch_event(event_id).await?;
let event = fetch_event(event_id.to_owned()).await;
let auth_events = event.as_ref().map(Event::auth_events);
@@ -561,7 +563,7 @@ where
.into_iter()
.flatten()
.stream()
.broad_filter_map(|aid| fetch_event(aid).unwrap_or_default())
.broadn_filter_map(5, |aid| fetch_event(aid.to_owned()))
.ready_find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, ""))
.await;
@@ -592,24 +594,27 @@ where
/// the the `fetch_event` closure and verify each event using the
/// `event_auth::auth_check` function.
#[tracing::instrument(level = "trace", skip_all)]
async fn iterative_auth_check<FE, FR, S>(
async fn iterative_auth_check<'a, E, F, Fut, S>(
room_version: &RoomVersion,
events_to_check: S,
unconflicted_state: StateMap<OwnedEventId>,
fetch_event: &FE,
fetch_event: &F,
) -> Result<StateMap<OwnedEventId>>
where
FE: Fn(&EventId) -> FR,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send + Sync,
S: Stream<Item = OwnedEventId> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
S: Stream<Item = &'a EventId> + Send + 'a,
E: Event + Clone + Send + Sync,
for<'b> &'b E: Event + Send,
{
debug!("starting iterative auth check");
let events_to_check: Vec<_> = events_to_check
.map(Ok::<OwnedEventId, Error>)
.broad_and_then(async |event_id| match fetch_event(&event_id).await {
| Ok(Some(e)) => Ok(e),
| _ => Err(Error::NotFound(format!("could not find {event_id}")))?,
.map(Result::Ok)
.broad_and_then(async |event_id| {
fetch_event(event_id.to_owned())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
})
.try_collect()
.boxed()
@@ -622,20 +627,16 @@ where
let auth_event_ids: HashSet<OwnedEventId> = events_to_check
.iter()
.flat_map(|event: &Pdu| event.auth_events().map(ToOwned::to_owned))
.flat_map(|event: &E| event.auth_events().map(ToOwned::to_owned))
.collect();
trace!(set = ?auth_event_ids, "auth event IDs to fetch");
let auth_events: HashMap<OwnedEventId, Pdu> = auth_event_ids
let auth_events: HashMap<OwnedEventId, E> = auth_event_ids
.into_iter()
.stream()
.broad_filter_map(async |event_id| {
fetch_event(&event_id)
.await
.map(|ev_opt| ev_opt.map(|ev| (event_id.clone(), ev)))
.unwrap_or_default()
})
.broad_filter_map(fetch_event)
.map(|auth_event| (auth_event.event_id().to_owned(), auth_event))
.collect()
.boxed()
.await;
@@ -654,23 +655,29 @@ where
.state_key()
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
let member_event_content = match event.kind() {
| TimelineEventType::RoomMember =>
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
})?),
| _ => None,
};
let auth_types = auth_types_for_event(
room_version,
event.kind(),
event.state_key().map(StateKey::from_str).as_ref(),
event.event_type(),
event.sender(),
member_event_content,
Some(state_key),
event.content(),
room_version,
)?;
trace!(list = ?auth_types, event_id = event.event_id().as_str(), "auth types for event");
let mut auth_state = StateMap::with_capacity(event.auth_events.len());
let mut auth_state = StateMap::new();
if room_version.room_ids_as_hashes {
trace!("room version uses hashed IDs, manually fetching create event");
let create_event_id_raw = event.room_id_or_hash().as_str().replace('!', "$");
let create_event_id = EventId::parse(&create_event_id_raw).map_err(|e| {
Error::InvalidPdu(format!(
"Failed to parse create event ID from room ID/hash: {e}"
))
})?;
let create_event = fetch_event(create_event_id.into())
.await
.ok_or_else(|| Error::NotFound("Failed to find create event".into()))?;
auth_state.insert(create_event.event_type().with_state_key(""), create_event);
}
for aid in event.auth_events() {
if let Some(ev) = auth_events.get(aid) {
//TODO: synapse checks "rejected_reason" which is most likely related to
@@ -696,13 +703,7 @@ where
if let Some(event) = auth_events.get(ev_id) {
Some((key, event.clone()))
} else {
match fetch_event(ev_id).await {
| Ok(Some(event)) => Some((key, event)),
| _ => {
warn!(event_id = ev_id.as_str(), "unable to fetch auth event");
None
},
}
Some((key, fetch_event(ev_id.clone()).await?))
}
})
.ready_for_each(|(key, event)| {
@@ -714,16 +715,30 @@ where
debug!(event_id = event.event_id().as_str(), "Running auth checks");
let fetch_state = async |t: (StateEventType, StateKey)| {
Ok(auth_state
.get(&t.0.with_state_key(t.1.as_str()))
.map(ToOwned::to_owned))
// The key for this is (eventType + a state_key of the signed token not sender)
// so search for it
let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
});
let fetch_state = |ty: &StateEventType, key: &str| {
future::ready(
auth_state
.get(&ty.with_state_key(key))
.map(ToOwned::to_owned),
)
};
let create_event = fetch_state((StateEventType::RoomCreate, StateKey::new())).await?;
let auth_result =
auth_check(room_version, &event, fetch_event, &fetch_state, create_event.as_ref())
.await;
let auth_result = auth_check(
room_version,
&event,
current_third_party,
fetch_state,
&fetch_state(&StateEventType::RoomCreate, "")
.await
.expect("create event must exist"),
)
.await;
match auth_result {
| Ok(true) => {
@@ -743,7 +758,7 @@ where
warn!("event {} failed the authentication check", event.event_id());
},
| Err(e) => {
log_error!("event {} failed the authentication check: {e}", event.event_id());
debug_error!("event {} failed the authentication check: {e}", event.event_id());
return Err(e);
},
}
@@ -762,14 +777,15 @@ where
/// after the most recent are depth 0, the events before (with the first power
/// level as a parent) will be marked as depth 1. depth 1 is "older" than depth
/// 0.
async fn mainline_sort<FE, FR>(
async fn mainline_sort<E, F, Fut>(
to_sort: &[OwnedEventId],
resolved_power_level: Option<OwnedEventId>,
fetch_event: &FE,
fetch_event: &F,
) -> Result<Vec<OwnedEventId>>
where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Clone + Send + Sync,
{
debug!("mainline sort of events");
@@ -783,14 +799,14 @@ where
while let Some(p) = pl {
mainline.push(p.clone());
let event = fetch_event(&p)
.await?
let event = fetch_event(p.clone())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
pl = None;
for aid in event.auth_events() {
let ev = fetch_event(aid)
.await?
let ev = fetch_event(aid.to_owned())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
@@ -811,11 +827,7 @@ where
.iter()
.stream()
.broad_filter_map(async |ev_id| {
fetch_event(ev_id)
.await
.ok()
.flatten()
.map(|event| (event, ev_id))
fetch_event(ev_id.clone()).await.map(|event| (event, ev_id))
})
.broad_filter_map(|(event, ev_id)| {
get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event)
@@ -837,14 +849,15 @@ where
/// Get the mainline depth from the `mainline_map` or finds a power_level event
/// that has an associated mainline depth.
async fn get_mainline_depth<FE, FR>(
mut event: Option<Pdu>,
async fn get_mainline_depth<E, F, Fut>(
mut event: Option<E>,
mainline_map: &HashMap<OwnedEventId, usize>,
fetch_event: &FE,
fetch_event: &F,
) -> Result<usize>
where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
{
while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().as_str(), "mainline");
@@ -856,8 +869,8 @@ where
event = None;
for aid in sort_ev.auth_events() {
let aev = fetch_event(aid)
.await?
let aev = fetch_event(aid.to_owned())
.await
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
@@ -870,19 +883,20 @@ where
Ok(0)
}
async fn add_event_and_auth_chain_to_graph<FE, FR>(
async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
graph: &mut HashMap<OwnedEventId, HashSet<OwnedEventId>>,
event_id: OwnedEventId,
auth_diff: &HashSet<OwnedEventId>,
fetch_event: &FE,
fetch_event: &F,
) where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send + Sync,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
graph.entry(eid.clone()).or_default();
let event = fetch_event(&eid).await.ok().flatten();
let event = fetch_event(eid.clone()).await;
let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten();
// Prefer the store to event as the store filters dedups the events
@@ -901,13 +915,14 @@ async fn add_event_and_auth_chain_to_graph<FE, FR>(
}
}
async fn is_power_event_id<FE, FR>(event_id: &EventId, fetch: &FE) -> bool
async fn is_power_event_id<E, F, Fut>(event_id: &EventId, fetch: &F) -> bool
where
FE: Fn(&EventId) -> FR + Sync,
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
F: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
{
match fetch(event_id).await.as_ref() {
| Ok(Some(state)) => is_power_event(state),
match fetch(event_id.to_owned()).await.as_ref() {
| Some(state) => is_power_event(state),
| _ => false,
}
}
@@ -964,27 +979,26 @@ where
mod tests {
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use maplit::{hashmap, hashset};
use rand::seq::SliceRandom;
use ruma::{
MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
events::{
room::join_rules::{JoinRule, RoomJoinRulesEventContent}, StateEventType,
TimelineEventType,
}, int, uint,
MilliSecondsSinceUnixEpoch,
OwnedEventId, RoomVersionId,
StateEventType, TimelineEventType,
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
},
int, uint,
};
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use super::{
is_power_event, room_version::RoomVersion,
StateMap, is_power_event,
room_version::RoomVersion,
test_utils::{
alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join,
room_id, to_init_pdu_event, to_pdu_event, zara, TestStore,
INITIAL_EVENTS,
INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id,
member_content_ban, member_content_join, room_id, to_init_pdu_event, to_pdu_event,
zara,
},
StateMap,
};
use crate::{
debug,
@@ -1014,13 +1028,13 @@ mod tests {
.map(|pdu| pdu.event_id.clone())
.collect::<Vec<_>>();
let fetcher = |id| ready(Ok(events.get(id).cloned()));
let fetcher = |id| ready(events.get(&id).cloned());
let sorted_power_events =
super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
.await
.unwrap();
let resolved_power = super::auth_check(
let resolved_power = super::iterative_auth_check(
&RoomVersion::V6,
sorted_power_events.iter().map(AsRef::as_ref).stream(),
HashMap::new(), // unconflicted events
@@ -1032,7 +1046,7 @@ mod tests {
// don't remove any events so we know it sorts them all correctly
let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
events_to_sort.shuffle(&mut rand::thread_rng());
events_to_sort.shuffle(&mut rand::rng());
let power_level = resolved_power
.get(&(StateEventType::RoomPowerLevels, "".into()))
+1 -1
View File
@@ -28,7 +28,7 @@ fn init_argon() -> Argon2<'static> {
}
pub(super) fn password(password: &str) -> Result<String> {
let salt = SaltString::generate(rand::thread_rng());
let salt = SaltString::generate(rand_core::OsRng);
ARGON
.get_or_init(init_argon)
.hash_password(password.as_bytes(), &salt)
+7 -10
View File
@@ -4,16 +4,16 @@ use std::{
};
use arrayvec::ArrayString;
use rand::{Rng, seq::SliceRandom, thread_rng};
use rand::{RngExt, seq::SliceRandom};
pub fn shuffle<T>(vec: &mut [T]) {
let mut rng = thread_rng();
let mut rng = rand::rng();
vec.shuffle(&mut rng);
}
pub fn string(length: usize) -> String {
thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
rand::rng()
.sample_iter(&rand::distr::Alphanumeric)
.take(length)
.map(char::from)
.collect()
@@ -22,8 +22,8 @@ pub fn string(length: usize) -> String {
#[inline]
pub fn string_array<const LENGTH: usize>() -> ArrayString<LENGTH> {
let mut ret = ArrayString::<LENGTH>::new();
thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
rand::rng()
.sample_iter(&rand::distr::Alphanumeric)
.take(LENGTH)
.map(char::from)
.for_each(|c| ret.push(c));
@@ -40,7 +40,4 @@ pub fn time_from_now_secs(range: Range<u64>) -> SystemTime {
}
#[must_use]
pub fn secs(range: Range<u64>) -> Duration {
let mut rng = thread_rng();
Duration::from_secs(rng.gen_range(range))
}
pub fn secs(range: Range<u64>) -> Duration { Duration::from_secs(rand::random_range(range)) }
+7 -9
View File
@@ -3,19 +3,17 @@ use futures::{
stream::{Stream, TryStream},
};
use crate::{Error, Result};
pub trait IterStream<I: IntoIterator + Send> {
/// Convert an Iterator into a Stream
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send;
/// Convert an Iterator into a TryStream
fn try_stream(
/// Convert an Iterator into a TryStream with a generic error type
fn try_stream<E>(
self,
) -> impl TryStream<
Ok = <I as IntoIterator>::Item,
Error = Error,
Item = Result<<I as IntoIterator>::Item, Error>,
Error = E,
Item = Result<<I as IntoIterator>::Item, E>,
> + Send;
}
@@ -28,12 +26,12 @@ where
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send { stream::iter(self) }
#[inline]
fn try_stream(
fn try_stream<E>(
self,
) -> impl TryStream<
Ok = <I as IntoIterator>::Item,
Error = Error,
Item = Result<<I as IntoIterator>::Item, Error>,
Error = E,
Item = Result<<I as IntoIterator>::Item, E>,
> + Send {
self.stream().map(Ok)
}
+2 -1
View File
@@ -1,9 +1,10 @@
//! Synchronous combinator extensions to futures::TryStream
use std::result::Result;
use futures::{TryFuture, TryStream, TryStreamExt};
use super::automatic_width;
use crate::Result;
/// Concurrency extensions to augment futures::TryStreamExt. broad_ combinators
/// produce out-of-order
+1 -3
View File
@@ -20,7 +20,6 @@ use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use conduwuit::{Result, Server, debug, error, warn};
use database::{Deserialized, Map};
use rand::Rng;
use ruma::events::{Mentions, room::message::RoomMessageEventContent};
use serde::Deserialize;
use tokio::{
@@ -100,8 +99,7 @@ impl crate::Service for Service {
}
let first_check_jitter = {
let mut rng = rand::thread_rng();
let jitter_percent = rng.gen_range(-50.0..=10.0);
let jitter_percent = rand::random_range(-50.0..=10.0);
self.interval.mul_f64(1.0 + jitter_percent / 100.0)
};
+1 -1
View File
@@ -39,7 +39,7 @@ impl crate::Service for Service {
let url_preview_user_agent = config
.url_preview_user_agent
.clone()
.unwrap_or_else(|| conduwuit::version::user_agent().to_owned());
.unwrap_or_else(|| conduwuit::version::user_agent_media().to_owned());
Ok(Arc::new(Self {
default: base(config)?
+2
View File
@@ -170,6 +170,8 @@ impl Data {
Ok(())
}
pub(super) async fn clear_url_previews(&self) { self.url_previews.clear().await; }
pub(super) fn set_url_preview(
&self,
url: &str,
+3
View File
@@ -37,6 +37,9 @@ pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
self.db.remove_url_preview(url)
}
#[implement(Service)]
pub async fn clear_url_previews(&self) { self.db.clear_url_previews().await; }
#[implement(Service)]
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
let now = SystemTime::now()
+1 -1
View File
@@ -31,7 +31,7 @@ pub mod rooms;
pub mod sending;
pub mod server_keys;
pub mod sync;
pub mod transaction_ids;
pub mod transactions;
pub mod uiaa;
pub mod users;
+1 -1
View File
@@ -142,7 +142,7 @@ async fn get_auth_chain_outer(
let chunk_cache: Vec<_> = chunk
.into_iter()
.try_stream()
.try_stream::<conduwuit::Error>()
.broad_and_then(|(shortid, event_id)| async move {
if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await {
return Ok(cached.to_vec());
+7 -5
View File
@@ -385,11 +385,13 @@ fn num_senders(args: &crate::Args<'_>) -> usize {
const MIN_SENDERS: usize = 1;
// Limit the number of senders to the number of workers threads or number of
// cores, conservatively.
let max_senders = args
.server
.metrics
.num_workers()
.min(available_parallelism());
let mut max_senders = args.server.metrics.num_workers();
// Work around some platforms not returning the number of cores.
let num_cores = available_parallelism();
if num_cores > 0 {
max_senders = max_senders.min(num_cores);
}
// If the user doesn't override the default 0, this is intended to then default
// to 1 for now as multiple senders is experimental.
+3 -3
View File
@@ -14,7 +14,7 @@ use crate::{
media, moderation, presence, pusher, registration_tokens, resolver, rooms, sending,
server_keys,
service::{self, Args, Map, Service},
sync, transaction_ids, uiaa, users,
sync, transactions, uiaa, users,
};
pub struct Services {
@@ -37,7 +37,7 @@ pub struct Services {
pub sending: Arc<sending::Service>,
pub server_keys: Arc<server_keys::Service>,
pub sync: Arc<sync::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub transactions: Arc<transactions::Service>,
pub uiaa: Arc<uiaa::Service>,
pub users: Arc<users::Service>,
pub moderation: Arc<moderation::Service>,
@@ -110,7 +110,7 @@ impl Services {
sending: build!(sending::Service),
server_keys: build!(server_keys::Service),
sync: build!(sync::Service),
transaction_ids: build!(transaction_ids::Service),
transactions: build!(transactions::Service),
uiaa: build!(uiaa::Service),
users: build!(users::Service),
moderation: build!(moderation::Service),
-54
View File
@@ -1,54 +0,0 @@
use std::sync::Arc;
use conduwuit::{Result, implement};
use database::{Handle, Map};
use ruma::{DeviceId, TransactionId, UserId};
pub struct Service {
db: Data,
}
struct Data {
userdevicetxnid_response: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
#[implement(Service)]
pub fn add_txnid(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
data: &[u8],
) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.db.userdevicetxnid_response.insert(&key, data);
}
// If there's no entry, this is a new transaction
#[implement(Service)]
pub async fn existing_txnid(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
) -> Result<Handle<'_>> {
let key = (user_id, device_id, txn_id);
self.db.userdevicetxnid_response.qry(&key).await
}
+326
View File
@@ -0,0 +1,326 @@
use std::{
collections::HashMap,
fmt,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::{Duration, SystemTime},
};
use async_trait::async_trait;
use conduwuit::{Error, Result, SyncRwLock, debug_warn, warn};
use database::{Handle, Map};
use ruma::{
DeviceId, OwnedServerName, OwnedTransactionId, TransactionId, UserId,
api::{
client::error::ErrorKind::LimitExceeded,
federation::transactions::send_transaction_message,
},
};
use tokio::sync::watch::{Receiver, Sender};
use crate::{Dep, config};
pub type TxnKey = (OwnedServerName, OwnedTransactionId);
pub type WrappedTransactionResponse =
Option<Result<send_transaction_message::v1::Response, TransactionError>>;
/// Errors that can occur during federation transaction processing.
#[derive(Debug, Clone)]
pub enum TransactionError {
/// Server is shutting down - the sender should retry the entire
/// transaction.
ShuttingDown,
}
impl fmt::Display for TransactionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
| Self::ShuttingDown => write!(f, "Server is shutting down"),
}
}
}
impl std::error::Error for TransactionError {}
/// Minimum interval between cache cleanup runs.
/// Exists to prevent thrashing when the cache is full of things that can't be
/// cleared
const CLEANUP_INTERVAL_SECS: u64 = 30;
#[derive(Clone, Debug)]
pub struct CachedTxnResponse {
pub response: send_transaction_message::v1::Response,
pub created: SystemTime,
}
/// Internal state for a federation transaction.
/// Either actively being processed or completed and cached.
#[derive(Clone)]
enum TxnState {
/// Transaction is currently being processed.
Active(Receiver<WrappedTransactionResponse>),
/// Transaction completed and response is cached.
Cached(CachedTxnResponse),
}
/// Result of atomically checking or starting a federation transaction.
pub enum FederationTxnState {
/// Transaction already completed and cached
Cached(send_transaction_message::v1::Response),
/// Transaction is currently being processed by another request.
/// Wait on this receiver for the result.
Active(Receiver<WrappedTransactionResponse>),
/// This caller should process the transaction (first to request it).
Started {
receiver: Receiver<WrappedTransactionResponse>,
sender: Sender<WrappedTransactionResponse>,
},
}
pub struct Service {
services: Services,
db: Data,
federation_txn_state: Arc<SyncRwLock<HashMap<TxnKey, TxnState>>>,
last_cleanup: AtomicU64,
}
struct Services {
config: Dep<config::Service>,
}
struct Data {
userdevicetxnid_response: Arc<Map>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
config: args.depend::<config::Service>("config"),
},
db: Data {
userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(),
},
federation_txn_state: Arc::new(SyncRwLock::new(HashMap::new())),
last_cleanup: AtomicU64::new(0),
}))
}
async fn clear_cache(&self) {
let mut state = self.federation_txn_state.write();
// Only clear cached entries, preserve active transactions
state.retain(|_, v| matches!(v, TxnState::Active(_)));
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Returns the count of currently active (in-progress) transactions.
#[must_use]
pub fn txn_active_handle_count(&self) -> usize {
let state = self.federation_txn_state.read();
state
.values()
.filter(|v| matches!(v, TxnState::Active(_)))
.count()
}
pub fn add_client_txnid(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
data: &[u8],
) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.db.userdevicetxnid_response.insert(&key, data);
}
pub async fn get_client_txn(
&self,
user_id: &UserId,
device_id: Option<&DeviceId>,
txn_id: &TransactionId,
) -> Result<Handle<'_>> {
let key = (user_id, device_id, txn_id);
self.db.userdevicetxnid_response.qry(&key).await
}
/// Atomically gets a cached response, joins an active transaction, or
/// starts a new one.
pub fn get_or_start_federation_txn(&self, key: TxnKey) -> Result<FederationTxnState> {
// Only one upgradable lock can be held at a time, and there aren't any
// read-only locks, so no point being upgradable
let mut state = self.federation_txn_state.write();
// Check existing state for this key
if let Some(txn_state) = state.get(&key) {
return Ok(match txn_state {
| TxnState::Cached(cached) => FederationTxnState::Cached(cached.response.clone()),
| TxnState::Active(receiver) => FederationTxnState::Active(receiver.clone()),
});
}
// Check if another transaction from this origin is already running
let has_active_from_origin = state
.iter()
.any(|(k, v)| k.0 == key.0 && matches!(v, TxnState::Active(_)));
if has_active_from_origin {
debug_warn!(
origin = ?key.0,
"Got concurrent transaction request from an origin with an active transaction"
);
return Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Still processing another transaction from this origin",
));
}
let max_active_txns = self.services.config.max_concurrent_inbound_transactions;
// Check if we're at capacity
if state.len() >= max_active_txns
&& let active_count = state
.values()
.filter(|v| matches!(v, TxnState::Active(_)))
.count() && active_count >= max_active_txns
{
warn!(
active = active_count,
max = max_active_txns,
"Server is overloaded, dropping incoming transaction"
);
return Err(Error::BadRequest(
LimitExceeded { retry_after: None },
"Server is overloaded, try again later",
));
}
// Start new transaction
let (sender, receiver) = tokio::sync::watch::channel(None);
state.insert(key, TxnState::Active(receiver.clone()));
Ok(FederationTxnState::Started { receiver, sender })
}
/// Finishes a transaction by transitioning it from active to cached state.
/// Additionally may trigger cleanup of old entries.
pub fn finish_federation_txn(
&self,
key: TxnKey,
sender: Sender<WrappedTransactionResponse>,
response: send_transaction_message::v1::Response,
) {
// Check if cleanup might be needed before acquiring the lock
let should_try_cleanup = self.should_try_cleanup();
let mut state = self.federation_txn_state.write();
// Explicitly set cached first so there is no gap where receivers get a closed
// channel
state.insert(
key,
TxnState::Cached(CachedTxnResponse {
response: response.clone(),
created: SystemTime::now(),
}),
);
if let Err(e) = sender.send(Some(Ok(response))) {
debug_warn!("Failed to send transaction response to waiting receivers: {e}");
}
// Explicitly close
drop(sender);
// This task is dangling, we can try clean caches now
if should_try_cleanup {
self.cleanup_entries_locked(&mut state);
}
}
pub fn remove_federation_txn(&self, key: &TxnKey) {
let mut state = self.federation_txn_state.write();
state.remove(key);
}
/// Checks if enough time has passed since the last cleanup to consider
/// running another. Updates the last cleanup time if returning true.
fn should_try_cleanup(&self) -> bool {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("SystemTime before UNIX_EPOCH")
.as_secs();
let last = self.last_cleanup.load(Ordering::Relaxed);
if now.saturating_sub(last) >= CLEANUP_INTERVAL_SECS {
// CAS: only update if no one else has updated it since we read
self.last_cleanup
.compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
} else {
false
}
}
/// Cleans up cached entries based on age and count limits.
///
/// First removes all cached entries older than the configured max age.
/// Then, if the cache still exceeds the max entry count, removes the oldest
/// cached entries until the count is within limits.
///
/// Must be called with write lock held on the state map.
fn cleanup_entries_locked(&self, state: &mut HashMap<TxnKey, TxnState>) {
let max_age_secs = self.services.config.transaction_id_cache_max_age_secs;
let max_entries = self.services.config.transaction_id_cache_max_entries;
// First pass: remove all cached entries older than max age
let cutoff = SystemTime::now()
.checked_sub(Duration::from_secs(max_age_secs))
.unwrap_or(SystemTime::UNIX_EPOCH);
state.retain(|_, v| match v {
| TxnState::Active(_) => true, // Never remove active transactions
| TxnState::Cached(cached) => cached.created > cutoff,
});
// Count cached entries
let cached_count = state
.values()
.filter(|v| matches!(v, TxnState::Cached(_)))
.count();
// Second pass: if still over max entries, remove oldest cached entries
if cached_count > max_entries {
let excess = cached_count.saturating_sub(max_entries);
// Collect cached entries sorted by age (oldest first)
let mut cached_entries: Vec<_> = state
.iter()
.filter_map(|(k, v)| match v {
| TxnState::Cached(cached) => Some((k.clone(), cached.created)),
| TxnState::Active(_) => None,
})
.collect();
cached_entries.sort_by(|a, b| a.1.cmp(&b.1));
// Remove the oldest cached entries to get under the limit
for (key, _) in cached_entries.into_iter().take(excess) {
state.remove(&key);
}
}
}
}
+6
View File
@@ -184,6 +184,12 @@ impl Service {
password: Option<&str>,
origin: Option<&str>,
) -> Result<()> {
if !self.services.globals.user_is_local(user_id)
&& (password.is_some() || origin.is_some())
{
return Err!("Cannot create a nonlocal user with a set password or origin");
}
self.db
.userid_origin
.insert(user_id, origin.unwrap_or("password"));