mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dfd89425a1 | |||
| c69e7c7d1b | |||
| bd404e808c | |||
| 0899985476 | |||
| b3cf649732 |
Generated
+479
-553
File diff suppressed because it is too large
Load Diff
+10
-15
@@ -68,7 +68,7 @@ default-features = false
|
||||
version = "0.1.3"
|
||||
|
||||
[workspace.dependencies.rand]
|
||||
version = "0.10.0"
|
||||
version = "0.8.5"
|
||||
|
||||
# Used for the http request / response body type for Ruma endpoints used with reqwest
|
||||
[workspace.dependencies.bytes]
|
||||
@@ -253,7 +253,7 @@ features = [
|
||||
version = "0.4.0"
|
||||
|
||||
[workspace.dependencies.libloading]
|
||||
version = "0.9.0"
|
||||
version = "0.8.6"
|
||||
|
||||
# Validating urls in config, was already a transitive dependency
|
||||
[workspace.dependencies.url]
|
||||
@@ -298,7 +298,7 @@ default-features = false
|
||||
features = ["env", "toml"]
|
||||
|
||||
[workspace.dependencies.hickory-resolver]
|
||||
version = "0.25.2"
|
||||
version = "0.25.1"
|
||||
default-features = false
|
||||
features = [
|
||||
"serde",
|
||||
@@ -307,14 +307,9 @@ features = [
|
||||
]
|
||||
|
||||
# Used for conduwuit::Error type
|
||||
[workspace.dependencies.snafu]
|
||||
version = "0.8"
|
||||
[workspace.dependencies.thiserror]
|
||||
version = "2.0.12"
|
||||
default-features = false
|
||||
features = ["std", "rust_1_81"]
|
||||
|
||||
# Used for macro name generation
|
||||
[workspace.dependencies.paste]
|
||||
version = "1.0"
|
||||
|
||||
# Used when hashing the state
|
||||
[workspace.dependencies.ring]
|
||||
@@ -348,7 +343,7 @@ version = "0.1.2"
|
||||
[workspace.dependencies.ruma]
|
||||
git = "https://forgejo.ellis.link/continuwuation/ruwuma"
|
||||
#branch = "conduwuit-changes"
|
||||
rev = "e087ff15888156942ca2ffe6097d1b4c3fd27628"
|
||||
rev = "3126cb5eea991ec40590e54d8c9d75637650641a"
|
||||
features = [
|
||||
"compat",
|
||||
"rand",
|
||||
@@ -430,7 +425,7 @@ features = ["http", "grpc-tonic", "trace", "logs", "metrics"]
|
||||
|
||||
# optional sentry metrics for crash/panic reporting
|
||||
[workspace.dependencies.sentry]
|
||||
version = "0.46.0"
|
||||
version = "0.45.0"
|
||||
default-features = false
|
||||
features = [
|
||||
"backtrace",
|
||||
@@ -446,9 +441,9 @@ features = [
|
||||
]
|
||||
|
||||
[workspace.dependencies.sentry-tracing]
|
||||
version = "0.46.0"
|
||||
version = "0.45.0"
|
||||
[workspace.dependencies.sentry-tower]
|
||||
version = "0.46.0"
|
||||
version = "0.45.0"
|
||||
|
||||
# jemalloc usage
|
||||
[workspace.dependencies.tikv-jemalloc-sys]
|
||||
@@ -477,7 +472,7 @@ features = ["use_std"]
|
||||
version = "0.5"
|
||||
|
||||
[workspace.dependencies.nix]
|
||||
version = "0.31.0"
|
||||
version = "0.30.1"
|
||||
default-features = false
|
||||
features = ["resource"]
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Fixed a startup crash in the sender service if we can't detect the number of CPU cores, even if the `sender_workers' config option is set correctly. Contributed by @katie.
|
||||
@@ -1325,7 +1325,7 @@
|
||||
# sender user's server name, inbound federation X-Matrix origin, and
|
||||
# outbound federation handler.
|
||||
#
|
||||
# You can set this to [".*"] to block all servers by default, and then
|
||||
# You can set this to ["*"] to block all servers by default, and then
|
||||
# use `allowed_remote_server_names` to allow only specific servers.
|
||||
#
|
||||
# example: ["badserver\\.tld$", "badphrase", "19dollarfortnitecards"]
|
||||
|
||||
+1
-1
@@ -52,7 +52,7 @@ ENV BINSTALL_VERSION=1.17.5
|
||||
# renovate: datasource=github-releases depName=psastras/sbom-rs
|
||||
ENV CARGO_SBOM_VERSION=0.9.1
|
||||
# renovate: datasource=crate depName=lddtree
|
||||
ENV LDDTREE_VERSION=0.5.0
|
||||
ENV LDDTREE_VERSION=0.4.0
|
||||
# renovate: datasource=crate depName=timelord-cli
|
||||
ENV TIMELORD_VERSION=3.0.1
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ ENV BINSTALL_VERSION=1.17.5
|
||||
# renovate: datasource=github-releases depName=psastras/sbom-rs
|
||||
ENV CARGO_SBOM_VERSION=0.9.1
|
||||
# renovate: datasource=crate depName=lddtree
|
||||
ENV LDDTREE_VERSION=0.5.0
|
||||
ENV LDDTREE_VERSION=0.4.0
|
||||
|
||||
# Install unpackaged tools
|
||||
RUN <<EOF
|
||||
|
||||
@@ -77,12 +77,7 @@ rec {
|
||||
craneLib.buildDepsOnly (
|
||||
(commonAttrs commonAttrsArgs)
|
||||
// {
|
||||
env = uwuenv.buildDepsOnlyEnv
|
||||
// (makeRocksDBEnv { inherit rocksdb; })
|
||||
// {
|
||||
# required since we started using unstable reqwest apparently ... otherwise the all-features build will fail
|
||||
RUSTFLAGS = "--cfg reqwest_unstable";
|
||||
};
|
||||
env = uwuenv.buildDepsOnlyEnv // (makeRocksDBEnv { inherit rocksdb; });
|
||||
inherit (features) cargoExtraArgs;
|
||||
}
|
||||
|
||||
@@ -107,13 +102,7 @@ rec {
|
||||
'';
|
||||
cargoArtifacts = deps;
|
||||
doCheck = true;
|
||||
env =
|
||||
uwuenv.buildPackageEnv
|
||||
// rocksdbEnv
|
||||
// {
|
||||
# required since we started using unstable reqwest apparently ... otherwise the all-features build will fail
|
||||
RUSTFLAGS = "--cfg reqwest_unstable";
|
||||
};
|
||||
env = uwuenv.buildPackageEnv // rocksdbEnv;
|
||||
passthru.env = uwuenv.buildPackageEnv // rocksdbEnv;
|
||||
meta.mainProgram = crateInfo.pname;
|
||||
inherit (features) cargoExtraArgs;
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::fmt::Write;
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{
|
||||
Err, Event, Result, debug_info, err, error, info,
|
||||
Err, Error, Event, Result, debug_info, err, error, info,
|
||||
matrix::pdu::PduBuilder,
|
||||
utils::{self, ReadyExt, stream::BroadbandExt},
|
||||
warn,
|
||||
@@ -252,13 +252,6 @@ pub(crate) async fn register_route(
|
||||
}
|
||||
}
|
||||
|
||||
// Don't allow registration with user IDs that aren't local
|
||||
if !services.globals.user_is_local(&user_id) {
|
||||
return Err!(Request(InvalidUsername(
|
||||
"Username {body_username} is not local to this server"
|
||||
)));
|
||||
}
|
||||
|
||||
user_id
|
||||
},
|
||||
| Err(e) => {
|
||||
@@ -387,7 +380,7 @@ pub(crate) async fn register_route(
|
||||
)
|
||||
.await?;
|
||||
if !worked {
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
},
|
||||
@@ -401,7 +394,7 @@ pub(crate) async fn register_route(
|
||||
&uiaainfo,
|
||||
json,
|
||||
);
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
},
|
||||
| _ => {
|
||||
return Err!(Request(NotJson("JSON body is not valid")));
|
||||
@@ -661,7 +654,7 @@ pub(crate) async fn change_password_route(
|
||||
.await?;
|
||||
|
||||
if !worked {
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
|
||||
// Success!
|
||||
@@ -673,7 +666,7 @@ pub(crate) async fn change_password_route(
|
||||
.uiaa
|
||||
.create(sender_user, body.sender_device(), &uiaainfo, json);
|
||||
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
},
|
||||
| _ => {
|
||||
return Err!(Request(NotJson("JSON body is not valid")));
|
||||
@@ -791,7 +784,7 @@ pub(crate) async fn deactivate_route(
|
||||
.await?;
|
||||
|
||||
if !worked {
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
},
|
||||
@@ -802,7 +795,7 @@ pub(crate) async fn deactivate_route(
|
||||
.uiaa
|
||||
.create(sender_user, body.sender_device(), &uiaainfo, json);
|
||||
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
},
|
||||
| _ => {
|
||||
return Err!(Request(NotJson("JSON body is not valid")));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{Err, Result, debug, err, utils};
|
||||
use conduwuit::{Err, Error, Result, debug, err, utils};
|
||||
use futures::StreamExt;
|
||||
use ruma::{
|
||||
MilliSecondsSinceUnixEpoch, OwnedDeviceId,
|
||||
@@ -232,7 +232,7 @@ pub(crate) async fn delete_devices_route(
|
||||
.await?;
|
||||
|
||||
if !worked {
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
},
|
||||
@@ -243,10 +243,10 @@ pub(crate) async fn delete_devices_route(
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, json);
|
||||
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
},
|
||||
| _ => {
|
||||
return Err!(BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::{
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{
|
||||
Err, Result, debug, debug_warn, err,
|
||||
Err, Error, Result, debug, debug_warn, err,
|
||||
result::NotFound,
|
||||
utils,
|
||||
utils::{IterStream, stream::WidebandExt},
|
||||
@@ -215,7 +215,7 @@ pub(crate) async fn upload_signing_keys_route(
|
||||
.await?;
|
||||
|
||||
if !worked {
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
},
|
||||
@@ -226,10 +226,10 @@ pub(crate) async fn upload_signing_keys_route(
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, json);
|
||||
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
},
|
||||
| _ => {
|
||||
return Err!(BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -396,12 +396,12 @@ pub(crate) async fn get_key_changes_route(
|
||||
let from = body
|
||||
.from
|
||||
.parse()
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::InvalidParam, "Invalid `from`.")))?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?;
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.parse()
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")))?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?;
|
||||
|
||||
device_list_updates.extend(
|
||||
services
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::time::Duration;
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{
|
||||
Err, Result, err, error,
|
||||
Err, Result, err,
|
||||
utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize},
|
||||
};
|
||||
use conduwuit_service::{
|
||||
@@ -69,7 +69,7 @@ pub(crate) async fn create_content_route(
|
||||
.create(mxc, Some(user), Some(&content_disposition), content_type, &body.file)
|
||||
.await
|
||||
{
|
||||
error!("Failed to save uploaded media: {e}");
|
||||
err!("Failed to save uploaded media: {e}");
|
||||
return Err!(Request(Unknown("Failed to save uploaded media")));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{
|
||||
Err, Result, at, debug_warn,
|
||||
Err, Error, Result, at, debug_warn,
|
||||
matrix::{
|
||||
event::{Event, Matches},
|
||||
pdu::PduCount,
|
||||
@@ -322,7 +322,7 @@ where
|
||||
|
||||
if server_ignored {
|
||||
// the sender's server is ignored, so ignore this event
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::SenderIgnored { sender: None },
|
||||
"The sender's server is ignored by this server.",
|
||||
));
|
||||
@@ -331,7 +331,7 @@ where
|
||||
if user_ignored && !services.config.send_messages_from_ignored_users_to_client {
|
||||
// the recipient of this PDU has the sender ignored, and we're not
|
||||
// configured to send ignored messages to clients
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::SenderIgnored { sender: Some(event.sender().to_owned()) },
|
||||
"You have ignored this sender.",
|
||||
));
|
||||
|
||||
+16
-16
@@ -1,5 +1,5 @@
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Result, err};
|
||||
use conduwuit::{Err, Error, Result, err};
|
||||
use conduwuit_service::Services;
|
||||
use ruma::{
|
||||
CanonicalJsonObject, CanonicalJsonValue,
|
||||
@@ -243,27 +243,27 @@ pub(crate) async fn set_pushrule_route(
|
||||
body.before.as_deref(),
|
||||
) {
|
||||
let err = match error {
|
||||
| InsertPushRuleError::ServerDefaultRuleId => err!(BadRequest(
|
||||
| InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule IDs starting with a dot are reserved for server-default rules.",
|
||||
)),
|
||||
| InsertPushRuleError::InvalidRuleId => err!(BadRequest(
|
||||
),
|
||||
| InsertPushRuleError::InvalidRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule ID containing invalid characters.",
|
||||
)),
|
||||
| InsertPushRuleError::RelativeToServerDefaultRule => err!(BadRequest(
|
||||
),
|
||||
| InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Can't place a push rule relatively to a server-default rule.",
|
||||
)),
|
||||
| InsertPushRuleError::UnknownRuleId => err!(BadRequest(
|
||||
),
|
||||
| InsertPushRuleError::UnknownRuleId => Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"The before or after rule could not be found.",
|
||||
)),
|
||||
| InsertPushRuleError::BeforeHigherThanAfter => err!(BadRequest(
|
||||
),
|
||||
| InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"The before rule has a higher priority than the after rule.",
|
||||
)),
|
||||
| _ => err!(BadRequest(ErrorKind::InvalidParam, "Invalid data.")),
|
||||
),
|
||||
| _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
|
||||
return Err(err);
|
||||
@@ -433,13 +433,13 @@ pub(crate) async fn delete_pushrule_route(
|
||||
.remove(body.kind.clone(), &body.rule_id)
|
||||
{
|
||||
let err = match error {
|
||||
| RemovePushRuleError::ServerDefault => err!(BadRequest(
|
||||
| RemovePushRuleError::ServerDefault => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Cannot delete a server-default pushrule.",
|
||||
)),
|
||||
),
|
||||
| RemovePushRuleError::NotFound =>
|
||||
err!(BadRequest(ErrorKind::NotFound, "Push rule not found.")),
|
||||
| _ => err!(BadRequest(ErrorKind::InvalidParam, "Invalid data.")),
|
||||
Error::BadRequest(ErrorKind::NotFound, "Push rule not found."),
|
||||
| _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
|
||||
return Err(err);
|
||||
|
||||
@@ -4,6 +4,7 @@ use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{Err, Event, Result, debug_info, info, matrix::pdu::PduEvent, utils::ReadyExt};
|
||||
use conduwuit_service::Services;
|
||||
use rand::Rng;
|
||||
use ruma::{
|
||||
EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||
api::client::{
|
||||
@@ -243,7 +244,7 @@ fn build_report(report: Report) -> RoomMessageEventContent {
|
||||
/// random delay sending a response per spec suggestion regarding
|
||||
/// enumerating for potential events existing in our server.
|
||||
async fn delay_response() {
|
||||
let time_to_wait = rand::random_range(2..5);
|
||||
let time_to_wait = rand::thread_rng().gen_range(2..5);
|
||||
debug_info!(
|
||||
"Got successful /report request, waiting {time_to_wait} seconds before sending \
|
||||
successful response."
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::cmp::max;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{
|
||||
Err, Event, Result, RoomVersion, debug, err, info,
|
||||
Err, Error, Event, Result, RoomVersion, debug, err, info,
|
||||
matrix::{StateKey, pdu::PduBuilder},
|
||||
};
|
||||
use futures::{FutureExt, StreamExt};
|
||||
@@ -58,7 +58,7 @@ pub(crate) async fn upgrade_room_route(
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services.server.supported_room_version(&body.new_version) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnsupportedRoomVersion,
|
||||
"This server does not support that room version.",
|
||||
));
|
||||
@@ -170,7 +170,7 @@ pub(crate) async fn upgrade_room_route(
|
||||
"creator".into(),
|
||||
json!(&sender_user).try_into().map_err(|e| {
|
||||
info!("Error forming creation event: {e}");
|
||||
err!(BadRequest(ErrorKind::BadJson, "Error forming creation event"))
|
||||
Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")
|
||||
})?,
|
||||
);
|
||||
},
|
||||
@@ -186,13 +186,13 @@ pub(crate) async fn upgrade_room_route(
|
||||
"room_version".into(),
|
||||
json!(&body.new_version)
|
||||
.try_into()
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::BadJson, "Error forming creation event")))?,
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?,
|
||||
);
|
||||
create_event_content.insert(
|
||||
"predecessor".into(),
|
||||
json!(predecessor)
|
||||
.try_into()
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::BadJson, "Error forming creation event")))?,
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?,
|
||||
);
|
||||
|
||||
// Validate creation event content
|
||||
@@ -203,7 +203,7 @@ pub(crate) async fn upgrade_room_route(
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
return Err!(BadRequest(ErrorKind::BadJson, "Error forming creation event"));
|
||||
return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"));
|
||||
}
|
||||
|
||||
let create_event_id = services
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::time::Duration;
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{
|
||||
Err, Result, debug, err, info,
|
||||
Err, Error, Result, debug, err, info,
|
||||
utils::{self, ReadyExt, hash},
|
||||
warn,
|
||||
};
|
||||
@@ -191,7 +191,7 @@ pub(crate) async fn handle_login(
|
||||
}
|
||||
|
||||
if services.users.is_locked(&user_id).await? {
|
||||
return Err!(BadRequest(ErrorKind::UserLocked, "This account has been locked."));
|
||||
return Err(Error::BadRequest(ErrorKind::UserLocked, "This account has been locked."));
|
||||
}
|
||||
|
||||
if services.users.is_login_disabled(&user_id).await {
|
||||
@@ -390,7 +390,7 @@ pub(crate) async fn login_token_route(
|
||||
.await?;
|
||||
|
||||
if !worked {
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
|
||||
// Success!
|
||||
@@ -402,7 +402,7 @@ pub(crate) async fn login_token_route(
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, json);
|
||||
|
||||
return Err!(Uiaa(uiaainfo));
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
},
|
||||
| _ => {
|
||||
return Err!(Request(NotJson("No JSON body was sent when required.")));
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Result, err};
|
||||
use conduwuit::{Error, Result};
|
||||
use conduwuit_service::sending::EduBuf;
|
||||
use futures::StreamExt;
|
||||
use ruma::{
|
||||
@@ -66,7 +66,7 @@ pub(crate) async fn send_event_to_device_route(
|
||||
|
||||
let event = event
|
||||
.deserialize_as()
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::InvalidParam, "Event is invalid")))?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?;
|
||||
|
||||
match target_device_id_maybe {
|
||||
| DeviceIdOrAllDevices::DeviceId(target_device_id) => {
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use axum::{Json, extract::State, response::IntoResponse};
|
||||
use conduwuit::{Err, Result};
|
||||
use ruma::api::client::discovery::{
|
||||
discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo},
|
||||
discover_support::{self, Contact},
|
||||
use conduwuit::{Error, Result};
|
||||
use ruma::api::client::{
|
||||
discovery::{
|
||||
discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo},
|
||||
discover_support::{self, Contact},
|
||||
},
|
||||
error::ErrorKind,
|
||||
};
|
||||
|
||||
use crate::Ruma;
|
||||
@@ -16,7 +19,7 @@ pub(crate) async fn well_known_client(
|
||||
) -> Result<discover_homeserver::Response> {
|
||||
let client_url = match services.config.well_known.client.as_ref() {
|
||||
| Some(url) => url.to_string(),
|
||||
| None => return Err!(BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
| None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
};
|
||||
|
||||
Ok(discover_homeserver::Response {
|
||||
@@ -85,7 +88,7 @@ pub(crate) async fn well_known_support(
|
||||
|
||||
if contacts.is_empty() && support_page.is_none() {
|
||||
// No admin room, no configured contacts, and no support page
|
||||
return Err!(BadRequest(ErrorKind::NotFound, "Not found."));
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Not found."));
|
||||
}
|
||||
|
||||
Ok(discover_support::Response { contacts, support_page })
|
||||
@@ -102,7 +105,7 @@ pub(crate) async fn syncv3_client_server_json(
|
||||
| Some(url) => url.to_string(),
|
||||
| None => match services.config.well_known.server.as_ref() {
|
||||
| Some(url) => url.to_string(),
|
||||
| None => return Err!(BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
| None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
+10
-10
@@ -4,7 +4,7 @@ use axum_extra::{
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
typed_header::TypedHeaderRejectionReason,
|
||||
};
|
||||
use conduwuit::{Err, Result, debug_error, err, warn};
|
||||
use conduwuit::{Err, Error, Result, debug_error, err, warn};
|
||||
use futures::{
|
||||
TryFutureExt,
|
||||
future::{
|
||||
@@ -77,7 +77,7 @@ pub(super) async fn auth(
|
||||
// already
|
||||
},
|
||||
| Token::None | Token::Invalid => {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Missing or invalid access token.",
|
||||
));
|
||||
@@ -96,7 +96,7 @@ pub(super) async fn auth(
|
||||
// already
|
||||
},
|
||||
| Token::None | Token::Invalid => {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Missing or invalid access token.",
|
||||
));
|
||||
@@ -130,10 +130,10 @@ pub(super) async fn auth(
|
||||
appservice_info: None,
|
||||
})
|
||||
} else {
|
||||
Err!(BadRequest(ErrorKind::MissingToken, "Missing access token."))
|
||||
Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))
|
||||
}
|
||||
},
|
||||
| _ => Err!(BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||
| _ => Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||
},
|
||||
| (
|
||||
AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None,
|
||||
@@ -149,7 +149,7 @@ pub(super) async fn auth(
|
||||
&ruma::api::client::session::logout::v3::Request::METADATA
|
||||
| &ruma::api::client::session::logout_all::v3::Request::METADATA
|
||||
) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserLocked,
|
||||
"This account has been locked.",
|
||||
));
|
||||
@@ -174,11 +174,11 @@ pub(super) async fn auth(
|
||||
appservice_info: None,
|
||||
}),
|
||||
| (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) =>
|
||||
Err!(BadRequest(
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::Unauthorized,
|
||||
"Only server signatures should be used on this endpoint.",
|
||||
)),
|
||||
| (AuthScheme::AppserviceToken, Token::User(_)) => Err!(BadRequest(
|
||||
| (AuthScheme::AppserviceToken, Token::User(_)) => Err(Error::BadRequest(
|
||||
ErrorKind::Unauthorized,
|
||||
"Only appservice access tokens should be used on this endpoint.",
|
||||
)),
|
||||
@@ -196,13 +196,13 @@ pub(super) async fn auth(
|
||||
appservice_info: None,
|
||||
})
|
||||
} else {
|
||||
Err!(BadRequest(
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
"Unknown access token.",
|
||||
))
|
||||
}
|
||||
},
|
||||
| (_, Token::Invalid) => Err!(BadRequest(
|
||||
| (_, Token::Invalid) => Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
"Unknown access token.",
|
||||
)),
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
use std::{borrow::Borrow, iter::once};
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Error, Result, err, info, utils::stream::ReadyExt};
|
||||
use conduwuit::{Err, Error, Result, info, utils::stream::ReadyExt};
|
||||
use futures::StreamExt;
|
||||
use ruma::{RoomId, api::federation::authorization::get_event_authorization};
|
||||
use ruma::{
|
||||
RoomId,
|
||||
api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
|
||||
};
|
||||
|
||||
use super::AccessCheck;
|
||||
use crate::Ruma;
|
||||
@@ -44,7 +47,7 @@ pub(crate) async fn get_event_authorization_route(
|
||||
.timeline
|
||||
.get_pdu_json(&body.event_id)
|
||||
.await
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::NotFound, "Event not found.")))?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?;
|
||||
|
||||
let room_id_str = event
|
||||
.get("room_id")
|
||||
|
||||
@@ -2,7 +2,7 @@ use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use conduwuit::{
|
||||
Err, PduEvent, Result, err, error,
|
||||
Err, Error, PduEvent, Result, err, error,
|
||||
matrix::{Event, event::gen_event_id},
|
||||
utils::{self, hash::sha256},
|
||||
warn,
|
||||
@@ -33,7 +33,7 @@ pub(crate) async fn create_invite_route(
|
||||
.await?;
|
||||
|
||||
if !services.server.supported_room_version(&body.room_version) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::IncompatibleRoomVersion { room_version: body.room_version.clone() },
|
||||
"Server does not support this room version.",
|
||||
));
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::borrow::ToOwned;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Result, debug, debug_info, info, matrix::pdu::PduBuilder, warn};
|
||||
use conduwuit::{Err, Error, Result, debug, debug_info, info, matrix::pdu::PduBuilder, warn};
|
||||
use conduwuit_service::Services;
|
||||
use futures::StreamExt;
|
||||
use ruma::{
|
||||
@@ -80,7 +80,7 @@ pub(crate) async fn create_join_event_template_route(
|
||||
|
||||
let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
|
||||
if !body.ver.contains(&room_version_id) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::IncompatibleRoomVersion { room_version: room_version_id },
|
||||
"Room version not supported.",
|
||||
));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use RoomVersionId::*;
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Result, debug_warn, info, matrix::pdu::PduBuilder, warn};
|
||||
use conduwuit::{Err, Error, Result, debug_warn, info, matrix::pdu::PduBuilder, warn};
|
||||
use ruma::{
|
||||
RoomVersionId,
|
||||
api::{client::error::ErrorKind, federation::knock::create_knock_event_template},
|
||||
@@ -67,14 +67,14 @@ pub(crate) async fn create_knock_event_template_route(
|
||||
let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
|
||||
|
||||
if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::IncompatibleRoomVersion { room_version: room_version_id },
|
||||
"Room version does not support knocking.",
|
||||
));
|
||||
}
|
||||
|
||||
if !body.ver.contains(&room_version_id) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::IncompatibleRoomVersion { room_version: room_version_id },
|
||||
"Your homeserver does not support the features required to knock on this room.",
|
||||
));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduwuit::{Err, Result, err};
|
||||
use conduwuit::{Error, Result};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::error::ErrorKind,
|
||||
@@ -25,7 +25,7 @@ pub(crate) async fn get_public_rooms_filtered_route(
|
||||
.config
|
||||
.allow_public_room_directory_over_federation
|
||||
{
|
||||
return Err!(BadRequest(ErrorKind::forbidden(), "Room directory is not public"));
|
||||
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room directory is not public"));
|
||||
}
|
||||
|
||||
let response = crate::client::get_public_rooms_filtered_helper(
|
||||
@@ -38,10 +38,7 @@ pub(crate) async fn get_public_rooms_filtered_route(
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
err!(BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to return this server's public room list."
|
||||
))
|
||||
Error::BadRequest(ErrorKind::Unknown, "Failed to return this server's public room list.")
|
||||
})?;
|
||||
|
||||
Ok(get_public_rooms_filtered::v1::Response {
|
||||
@@ -65,7 +62,7 @@ pub(crate) async fn get_public_rooms_route(
|
||||
.globals
|
||||
.allow_public_room_directory_over_federation()
|
||||
{
|
||||
return Err!(BadRequest(ErrorKind::forbidden(), "Room directory is not public"));
|
||||
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room directory is not public"));
|
||||
}
|
||||
|
||||
let response = crate::client::get_public_rooms_filtered_helper(
|
||||
@@ -78,10 +75,7 @@ pub(crate) async fn get_public_rooms_route(
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
err!(BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to return this server's public room list."
|
||||
))
|
||||
Error::BadRequest(ErrorKind::Unknown, "Failed to return this server's public room list.")
|
||||
})?;
|
||||
|
||||
Ok(get_public_rooms::v1::Response {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Result, err};
|
||||
use conduwuit::{Error, Result, err};
|
||||
use futures::StreamExt;
|
||||
use get_profile_information::v1::ProfileField;
|
||||
use rand::seq::SliceRandom;
|
||||
@@ -40,7 +40,7 @@ pub(crate) async fn get_room_information_route(
|
||||
servers.sort_unstable();
|
||||
servers.dedup();
|
||||
|
||||
servers.shuffle(&mut rand::rng());
|
||||
servers.shuffle(&mut rand::thread_rng());
|
||||
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) = servers
|
||||
@@ -67,16 +67,17 @@ pub(crate) async fn get_profile_information_route(
|
||||
.config
|
||||
.allow_inbound_profile_lookup_federation_requests
|
||||
{
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
"Profile lookup over federation is not allowed on this homeserver.",
|
||||
));
|
||||
}
|
||||
|
||||
if !services.globals.server_is_ours(body.user_id.server_name()) {
|
||||
return Err!(
|
||||
BadRequest(ErrorKind::InvalidParam, "User does not belong to this server.",)
|
||||
);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"User does not belong to this server.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut displayname = None;
|
||||
|
||||
@@ -114,7 +114,7 @@ pub(crate) async fn send_transaction_message_route(
|
||||
);
|
||||
for (id, result) in &results {
|
||||
if let Err(e) = result {
|
||||
if matches!(e, Error::BadRequest { kind: ErrorKind::NotFound, .. }) {
|
||||
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
|
||||
warn!("Incoming PDU failed {id}: {e:?}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Result};
|
||||
use conduwuit::{Error, Result};
|
||||
use futures::{FutureExt, StreamExt, TryFutureExt};
|
||||
use ruma::api::{
|
||||
client::error::ErrorKind,
|
||||
@@ -24,7 +24,7 @@ pub(crate) async fn get_devices_route(
|
||||
body: Ruma<get_devices::v1::Request>,
|
||||
) -> Result<get_devices::v1::Response> {
|
||||
if !services.globals.user_is_local(&body.user_id) {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to access user from other server.",
|
||||
));
|
||||
@@ -86,9 +86,10 @@ pub(crate) async fn get_keys_route(
|
||||
.iter()
|
||||
.any(|(u, _)| !services.globals.user_is_local(u))
|
||||
{
|
||||
return Err!(
|
||||
BadRequest(ErrorKind::InvalidParam, "User does not belong to this server.",)
|
||||
);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"User does not belong to this server.",
|
||||
));
|
||||
}
|
||||
|
||||
let result = get_keys_helper(
|
||||
@@ -120,7 +121,7 @@ pub(crate) async fn claim_keys_route(
|
||||
.iter()
|
||||
.any(|(u, _)| !services.globals.user_is_local(u))
|
||||
{
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to access user from other server.",
|
||||
));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use axum::extract::State;
|
||||
use conduwuit::{Err, Result};
|
||||
use ruma::api::federation::discovery::discover_homeserver;
|
||||
use conduwuit::{Error, Result};
|
||||
use ruma::api::{client::error::ErrorKind, federation::discovery::discover_homeserver};
|
||||
|
||||
use crate::Ruma;
|
||||
|
||||
@@ -14,7 +14,7 @@ pub(crate) async fn well_known_server(
|
||||
Ok(discover_homeserver::Response {
|
||||
server: match services.server.config.well_known.server.as_ref() {
|
||||
| Some(server_name) => server_name.to_owned(),
|
||||
| None => return Err!(BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
| None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
+1
-3
@@ -86,7 +86,6 @@ libloading.optional = true
|
||||
log.workspace = true
|
||||
num-traits.workspace = true
|
||||
rand.workspace = true
|
||||
rand_core = { version = "0.6.4", features = ["getrandom"] }
|
||||
regex.workspace = true
|
||||
reqwest.workspace = true
|
||||
ring.workspace = true
|
||||
@@ -98,8 +97,7 @@ serde-saphyr.workspace = true
|
||||
serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
smallstr.workspace = true
|
||||
snafu.workspace = true
|
||||
paste.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.optional = true
|
||||
tikv-jemallocator.workspace = true
|
||||
tikv-jemalloc-ctl.optional = true
|
||||
|
||||
@@ -1525,7 +1525,7 @@ pub struct Config {
|
||||
/// sender user's server name, inbound federation X-Matrix origin, and
|
||||
/// outbound federation handler.
|
||||
///
|
||||
/// You can set this to [".*"] to block all servers by default, and then
|
||||
/// You can set this to ["*"] to block all servers by default, and then
|
||||
/// use `allowed_remote_server_names` to allow only specific servers.
|
||||
///
|
||||
/// example: ["badserver\\.tld$", "badphrase", "19dollarfortnitecards"]
|
||||
|
||||
+30
-129
@@ -45,162 +45,63 @@ macro_rules! Err {
|
||||
macro_rules! err {
|
||||
(Request(Forbidden($level:ident!($($args:tt)+)))) => {{
|
||||
let mut buf = String::new();
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::forbidden(),
|
||||
message: $crate::err_log!(buf, $level, $($args)+),
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: Some($crate::snafu::Backtrace::capture()),
|
||||
}
|
||||
$crate::error::Error::Request(
|
||||
$crate::ruma::api::client::error::ErrorKind::forbidden(),
|
||||
$crate::err_log!(buf, $level, $($args)+),
|
||||
$crate::http::StatusCode::BAD_REQUEST
|
||||
)
|
||||
}};
|
||||
|
||||
(Request(Forbidden($($args:tt)+))) => {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::forbidden(),
|
||||
message,
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: Some($crate::snafu::Backtrace::capture()),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(Request(NotFound($level:ident!($($args:tt)+)))) => {{
|
||||
let mut buf = String::new();
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::NotFound,
|
||||
message: $crate::err_log!(buf, $level, $($args)+),
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: None,
|
||||
}
|
||||
}};
|
||||
|
||||
(Request(NotFound($($args:tt)+))) => {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::NotFound,
|
||||
message,
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: None,
|
||||
}
|
||||
}
|
||||
$crate::error::Error::Request(
|
||||
$crate::ruma::api::client::error::ErrorKind::forbidden(),
|
||||
$crate::format_maybe!($($args)+),
|
||||
$crate::http::StatusCode::BAD_REQUEST
|
||||
)
|
||||
};
|
||||
|
||||
(Request($variant:ident($level:ident!($($args:tt)+)))) => {{
|
||||
let mut buf = String::new();
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::$variant,
|
||||
message: $crate::err_log!(buf, $level, $($args)+),
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: Some($crate::snafu::Backtrace::capture()),
|
||||
}
|
||||
$crate::error::Error::Request(
|
||||
$crate::ruma::api::client::error::ErrorKind::$variant,
|
||||
$crate::err_log!(buf, $level, $($args)+),
|
||||
$crate::http::StatusCode::BAD_REQUEST
|
||||
)
|
||||
}};
|
||||
|
||||
(Request($variant:ident($($args:tt)+))) => {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::$variant,
|
||||
message,
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: Some($crate::snafu::Backtrace::capture()),
|
||||
}
|
||||
}
|
||||
$crate::error::Error::Request(
|
||||
$crate::ruma::api::client::error::ErrorKind::$variant,
|
||||
$crate::format_maybe!($($args)+),
|
||||
$crate::http::StatusCode::BAD_REQUEST
|
||||
)
|
||||
};
|
||||
|
||||
(Config($item:literal, $($args:tt)+)) => {{
|
||||
let mut buf = String::new();
|
||||
$crate::error::ConfigSnafu {
|
||||
directive: $item,
|
||||
message: $crate::err_log!(buf, error, config = %$item, $($args)+),
|
||||
}.build()
|
||||
$crate::error::Error::Config($item, $crate::err_log!(buf, error, config = %$item, $($args)+))
|
||||
}};
|
||||
|
||||
(BadRequest(ErrorKind::NotFound, $($args:tt)+)) => {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::Error::Request {
|
||||
kind: $crate::ruma::api::client::error::ErrorKind::NotFound,
|
||||
message,
|
||||
code: $crate::http::StatusCode::BAD_REQUEST,
|
||||
backtrace: None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(BadRequest($kind:expr, $($args:tt)+)) => {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::BadRequestSnafu {
|
||||
kind: $kind,
|
||||
message,
|
||||
}.build()
|
||||
}
|
||||
};
|
||||
|
||||
(FeatureDisabled($($args:tt)+)) => {
|
||||
{
|
||||
let feature: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::FeatureDisabledSnafu { feature }.build()
|
||||
}
|
||||
};
|
||||
|
||||
(Federation($server:expr, $error:expr $(,)?)) => {
|
||||
{
|
||||
$crate::error::FederationSnafu {
|
||||
server: $server,
|
||||
error: $error,
|
||||
}.build()
|
||||
}
|
||||
};
|
||||
|
||||
(InconsistentRoomState($message:expr, $room_id:expr $(,)?)) => {
|
||||
{
|
||||
$crate::error::InconsistentRoomStateSnafu {
|
||||
message: $message,
|
||||
room_id: $room_id,
|
||||
}.build()
|
||||
}
|
||||
};
|
||||
|
||||
(Uiaa($info:expr $(,)?)) => {
|
||||
{
|
||||
$crate::error::UiaaSnafu {
|
||||
info: $info,
|
||||
}.build()
|
||||
}
|
||||
};
|
||||
|
||||
($variant:ident($level:ident!($($args:tt)+))) => {{
|
||||
let mut buf = String::new();
|
||||
$crate::paste::paste! {
|
||||
$crate::error::[<$variant Snafu>] {
|
||||
message: $crate::err_log!(buf, $level, $($args)+),
|
||||
}.build()
|
||||
}
|
||||
$crate::error::Error::$variant($crate::err_log!(buf, $level, $($args)+))
|
||||
}};
|
||||
|
||||
($variant:ident($($args:ident),+)) => {
|
||||
$crate::error::Error::$variant($($args),+)
|
||||
};
|
||||
|
||||
($variant:ident($($args:tt)+)) => {
|
||||
$crate::paste::paste! {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::[<$variant Snafu>] { message }.build()
|
||||
}
|
||||
}
|
||||
$crate::error::Error::$variant($crate::format_maybe!($($args)+))
|
||||
};
|
||||
|
||||
($level:ident!($($args:tt)+)) => {{
|
||||
let mut buf = String::new();
|
||||
let message: std::borrow::Cow<'static, str> = $crate::err_log!(buf, $level, $($args)+);
|
||||
$crate::error::ErrSnafu { message }.build()
|
||||
$crate::error::Error::Err($crate::err_log!(buf, $level, $($args)+))
|
||||
}};
|
||||
|
||||
($($args:tt)+) => {
|
||||
{
|
||||
let message: std::borrow::Cow<'static, str> = $crate::format_maybe!($($args)+);
|
||||
$crate::error::ErrSnafu { message }.build()
|
||||
}
|
||||
$crate::error::Error::Err($crate::format_maybe!($($args)+))
|
||||
};
|
||||
}
|
||||
|
||||
@@ -233,7 +134,7 @@ macro_rules! err_log {
|
||||
};
|
||||
|
||||
($crate::error::visit)(&mut $out, LEVEL, &__CALLSITE, &mut valueset_all!(__CALLSITE.metadata().fields(), $($fields)+));
|
||||
std::borrow::Cow::<'static, str>::from($out)
|
||||
($out).into()
|
||||
}}
|
||||
}
|
||||
|
||||
|
||||
+139
-448
@@ -6,391 +6,151 @@ mod serde;
|
||||
|
||||
use std::{any::Any, borrow::Cow, convert::Infallible, sync::PoisonError};
|
||||
|
||||
use snafu::{IntoError, prelude::*};
|
||||
|
||||
pub use self::{err::visit, log::*};
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
#[derive(thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[snafu(display("PANIC!"))]
|
||||
PanicAny {
|
||||
panic: Box<dyn Any + Send>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("PANIC! {message}"))]
|
||||
Panic {
|
||||
message: &'static str,
|
||||
panic: Box<dyn Any + Send + 'static>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error("PANIC!")]
|
||||
PanicAny(Box<dyn Any + Send>),
|
||||
#[error("PANIC! {0}")]
|
||||
Panic(&'static str, Box<dyn Any + Send + 'static>),
|
||||
|
||||
// std
|
||||
#[snafu(display("Format error: {source}"))]
|
||||
Fmt {
|
||||
source: std::fmt::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("UTF-8 conversion error: {source}"))]
|
||||
FromUtf8 {
|
||||
source: std::string::FromUtf8Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("I/O error: {source}"))]
|
||||
Io {
|
||||
source: std::io::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Parse float error: {source}"))]
|
||||
ParseFloat {
|
||||
source: std::num::ParseFloatError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Parse int error: {source}"))]
|
||||
ParseInt {
|
||||
source: std::num::ParseIntError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Error: {source}"))]
|
||||
Std {
|
||||
source: Box<dyn std::error::Error + Send>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Thread access error: {source}"))]
|
||||
ThreadAccessError {
|
||||
source: std::thread::AccessError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Integer conversion error: {source}"))]
|
||||
TryFromInt {
|
||||
source: std::num::TryFromIntError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Slice conversion error: {source}"))]
|
||||
TryFromSlice {
|
||||
source: std::array::TryFromSliceError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("UTF-8 error: {source}"))]
|
||||
Utf8 {
|
||||
source: std::str::Utf8Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error(transparent)]
|
||||
Fmt(#[from] std::fmt::Error),
|
||||
#[error(transparent)]
|
||||
FromUtf8(#[from] std::string::FromUtf8Error),
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error(transparent)]
|
||||
ParseFloat(#[from] std::num::ParseFloatError),
|
||||
#[error(transparent)]
|
||||
ParseInt(#[from] std::num::ParseIntError),
|
||||
#[error(transparent)]
|
||||
Std(#[from] Box<dyn std::error::Error + Send>),
|
||||
#[error(transparent)]
|
||||
ThreadAccessError(#[from] std::thread::AccessError),
|
||||
#[error(transparent)]
|
||||
TryFromInt(#[from] std::num::TryFromIntError),
|
||||
#[error(transparent)]
|
||||
TryFromSlice(#[from] std::array::TryFromSliceError),
|
||||
#[error(transparent)]
|
||||
Utf8(#[from] std::str::Utf8Error),
|
||||
|
||||
// third-party
|
||||
#[snafu(display("Capacity error: {source}"))]
|
||||
CapacityError {
|
||||
source: arrayvec::CapacityError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Cargo.toml error: {source}"))]
|
||||
CargoToml {
|
||||
source: cargo_toml::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Clap error: {source}"))]
|
||||
Clap {
|
||||
source: clap::error::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Extension rejection: {source}"))]
|
||||
Extension {
|
||||
source: axum::extract::rejection::ExtensionRejection,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Figment error: {source}"))]
|
||||
Figment {
|
||||
source: figment::error::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("HTTP error: {source}"))]
|
||||
Http {
|
||||
source: http::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Invalid HTTP header value: {source}"))]
|
||||
HttpHeader {
|
||||
source: http::header::InvalidHeaderValue,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Join error: {source}"))]
|
||||
JoinError {
|
||||
source: tokio::task::JoinError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("JSON error: {source}"))]
|
||||
Json {
|
||||
source: serde_json::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("JS parse int error: {source}"))]
|
||||
JsParseInt {
|
||||
source: ruma::JsParseIntError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("JS try from int error: {source}"))]
|
||||
JsTryFromInt {
|
||||
source: ruma::JsTryFromIntError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Path rejection: {source}"))]
|
||||
Path {
|
||||
source: axum::extract::rejection::PathRejection,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Mutex poisoned: {message}"))]
|
||||
Poison {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Regex error: {source}"))]
|
||||
Regex {
|
||||
source: regex::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Request error: {source}"))]
|
||||
Reqwest {
|
||||
source: reqwest::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message}"))]
|
||||
SerdeDe {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message}"))]
|
||||
SerdeSer {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("TOML deserialization error: {source}"))]
|
||||
TomlDe {
|
||||
source: toml::de::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("TOML serialization error: {source}"))]
|
||||
TomlSer {
|
||||
source: toml::ser::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Tracing filter error: {source}"))]
|
||||
TracingFilter {
|
||||
source: tracing_subscriber::filter::ParseError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Tracing reload error: {source}"))]
|
||||
TracingReload {
|
||||
source: tracing_subscriber::reload::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Typed header rejection: {source}"))]
|
||||
TypedHeader {
|
||||
source: axum_extra::typed_header::TypedHeaderRejection,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("YAML deserialization error: {source}"))]
|
||||
YamlDe {
|
||||
source: serde_saphyr::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("YAML serialization error: {source}"))]
|
||||
YamlSer {
|
||||
source: serde_saphyr::ser_error::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error(transparent)]
|
||||
CapacityError(#[from] arrayvec::CapacityError),
|
||||
#[error(transparent)]
|
||||
CargoToml(#[from] cargo_toml::Error),
|
||||
#[error(transparent)]
|
||||
Clap(#[from] clap::error::Error),
|
||||
#[error(transparent)]
|
||||
Extension(#[from] axum::extract::rejection::ExtensionRejection),
|
||||
#[error(transparent)]
|
||||
Figment(#[from] figment::error::Error),
|
||||
#[error(transparent)]
|
||||
Http(#[from] http::Error),
|
||||
#[error(transparent)]
|
||||
HttpHeader(#[from] http::header::InvalidHeaderValue),
|
||||
#[error("Join error: {0}")]
|
||||
JoinError(#[from] tokio::task::JoinError),
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error(transparent)]
|
||||
JsParseInt(#[from] ruma::JsParseIntError), // js_int re-export
|
||||
#[error(transparent)]
|
||||
JsTryFromInt(#[from] ruma::JsTryFromIntError), // js_int re-export
|
||||
#[error(transparent)]
|
||||
Path(#[from] axum::extract::rejection::PathRejection),
|
||||
#[error("Mutex poisoned: {0}")]
|
||||
Poison(Cow<'static, str>),
|
||||
#[error("Regex error: {0}")]
|
||||
Regex(#[from] regex::Error),
|
||||
#[error("Request error: {0}")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
#[error("{0}")]
|
||||
SerdeDe(Cow<'static, str>),
|
||||
#[error("{0}")]
|
||||
SerdeSer(Cow<'static, str>),
|
||||
#[error(transparent)]
|
||||
TomlDe(#[from] toml::de::Error),
|
||||
#[error(transparent)]
|
||||
TomlSer(#[from] toml::ser::Error),
|
||||
#[error("Tracing filter error: {0}")]
|
||||
TracingFilter(#[from] tracing_subscriber::filter::ParseError),
|
||||
#[error("Tracing reload error: {0}")]
|
||||
TracingReload(#[from] tracing_subscriber::reload::Error),
|
||||
#[error(transparent)]
|
||||
TypedHeader(#[from] axum_extra::typed_header::TypedHeaderRejection),
|
||||
#[error(transparent)]
|
||||
YamlDe(#[from] serde_saphyr::Error),
|
||||
#[error(transparent)]
|
||||
YamlSer(#[from] serde_saphyr::ser_error::Error),
|
||||
|
||||
// ruma/conduwuit
|
||||
#[snafu(display("Arithmetic operation failed: {message}"))]
|
||||
Arithmetic {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{kind}: {message}"))]
|
||||
BadRequest {
|
||||
kind: ruma::api::client::error::ErrorKind,
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message}"))]
|
||||
BadServerResponse {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Canonical JSON error: {source}"))]
|
||||
CanonicalJson {
|
||||
source: ruma::CanonicalJsonError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display(
|
||||
"There was a problem with the '{directive}' directive in your configuration: {message}"
|
||||
))]
|
||||
Config {
|
||||
directive: &'static str,
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message}"))]
|
||||
Conflict {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Content disposition error: {source}"))]
|
||||
ContentDisposition {
|
||||
source: ruma::http_headers::ContentDispositionParseError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message}"))]
|
||||
Database {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Feature '{feature}' is not available on this server."))]
|
||||
FeatureDisabled {
|
||||
feature: Cow<'static, str>,
|
||||
},
|
||||
|
||||
#[snafu(display("Remote server {server} responded with: {error}"))]
|
||||
Federation {
|
||||
server: ruma::OwnedServerName,
|
||||
error: ruma::api::client::error::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message} in {room_id}"))]
|
||||
InconsistentRoomState {
|
||||
message: &'static str,
|
||||
room_id: ruma::OwnedRoomId,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("HTTP conversion error: {source}"))]
|
||||
IntoHttp {
|
||||
source: ruma::api::error::IntoHttpError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{message}"))]
|
||||
Ldap {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("MXC URI error: {source}"))]
|
||||
Mxc {
|
||||
source: ruma::MxcUriError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Matrix ID parse error: {source}"))]
|
||||
Mxid {
|
||||
source: ruma::IdParseError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("from {server}: {error}"))]
|
||||
Redaction {
|
||||
server: ruma::OwnedServerName,
|
||||
error: ruma::canonical_json::RedactionError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("{kind}: {message}"))]
|
||||
Request {
|
||||
kind: ruma::api::client::error::ErrorKind,
|
||||
message: Cow<'static, str>,
|
||||
code: http::StatusCode,
|
||||
backtrace: Option<snafu::Backtrace>,
|
||||
},
|
||||
|
||||
#[snafu(display("Ruma error: {source}"))]
|
||||
Ruma {
|
||||
source: ruma::api::client::error::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Signature error: {source}"))]
|
||||
Signatures {
|
||||
source: ruma::signatures::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("State resolution error: {source}"))]
|
||||
#[snafu(context(false))]
|
||||
StateRes {
|
||||
source: crate::state_res::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("uiaa"))]
|
||||
Uiaa {
|
||||
info: ruma::api::client::uiaa::UiaaInfo,
|
||||
},
|
||||
#[error("Arithmetic operation failed: {0}")]
|
||||
Arithmetic(Cow<'static, str>),
|
||||
#[error("{0}: {1}")]
|
||||
BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove
|
||||
#[error("{0}")]
|
||||
BadServerResponse(Cow<'static, str>),
|
||||
#[error(transparent)]
|
||||
CanonicalJson(#[from] ruma::CanonicalJsonError),
|
||||
#[error("There was a problem with the '{0}' directive in your configuration: {1}")]
|
||||
Config(&'static str, Cow<'static, str>),
|
||||
#[error("{0}")]
|
||||
Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists
|
||||
#[error(transparent)]
|
||||
ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError),
|
||||
#[error("{0}")]
|
||||
Database(Cow<'static, str>),
|
||||
#[error("Feature '{0}' is not available on this server.")]
|
||||
FeatureDisabled(Cow<'static, str>),
|
||||
#[error("Remote server {0} responded with: {1}")]
|
||||
Federation(ruma::OwnedServerName, ruma::api::client::error::Error),
|
||||
#[error("{0} in {1}")]
|
||||
InconsistentRoomState(&'static str, ruma::OwnedRoomId),
|
||||
#[error(transparent)]
|
||||
IntoHttp(#[from] ruma::api::error::IntoHttpError),
|
||||
#[error("{0}")]
|
||||
Ldap(Cow<'static, str>),
|
||||
#[error(transparent)]
|
||||
Mxc(#[from] ruma::MxcUriError),
|
||||
#[error(transparent)]
|
||||
Mxid(#[from] ruma::IdParseError),
|
||||
#[error("from {0}: {1}")]
|
||||
Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError),
|
||||
#[error("{0}: {1}")]
|
||||
Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode),
|
||||
#[error(transparent)]
|
||||
Ruma(#[from] ruma::api::client::error::Error),
|
||||
#[error(transparent)]
|
||||
Signatures(#[from] ruma::signatures::Error),
|
||||
#[error(transparent)]
|
||||
StateRes(#[from] crate::state_res::Error),
|
||||
#[error("uiaa")]
|
||||
Uiaa(ruma::api::client::uiaa::UiaaInfo),
|
||||
|
||||
// unique / untyped
|
||||
#[snafu(display("{message}"))]
|
||||
Err {
|
||||
message: Cow<'static, str>,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error("{0}")]
|
||||
Err(Cow<'static, str>),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn from_errno() -> Self { IoSnafu {}.into_error(std::io::Error::last_os_error()) }
|
||||
pub fn from_errno() -> Self { Self::Io(std::io::Error::last_os_error()) }
|
||||
|
||||
//#[deprecated]
|
||||
#[must_use]
|
||||
pub fn bad_database(message: &'static str) -> Self {
|
||||
let message: Cow<'static, str> = message.into();
|
||||
DatabaseSnafu { message }.build()
|
||||
crate::err!(Database(error!("{message}")))
|
||||
}
|
||||
|
||||
/// Sanitizes public-facing errors that can leak sensitive information.
|
||||
pub fn sanitized_message(&self) -> String {
|
||||
match self {
|
||||
| Self::Database { .. } => String::from("Database error occurred."),
|
||||
| Self::Io { .. } => String::from("I/O error occurred."),
|
||||
| Self::Database(..) => String::from("Database error occurred."),
|
||||
| Self::Io(..) => String::from("I/O error occurred."),
|
||||
| _ => self.message(),
|
||||
}
|
||||
}
|
||||
@@ -398,8 +158,8 @@ impl Error {
|
||||
/// Generate the error message string.
|
||||
pub fn message(&self) -> String {
|
||||
match self {
|
||||
| Self::Federation { server, error, .. } => format!("Answer from {server}: {error}"),
|
||||
| Self::Ruma { source, .. } => response::ruma_error_message(source),
|
||||
| Self::Federation(origin, error) => format!("Answer from {origin}: {error}"),
|
||||
| Self::Ruma(error) => response::ruma_error_message(error),
|
||||
| _ => format!("{self}"),
|
||||
}
|
||||
}
|
||||
@@ -410,10 +170,10 @@ impl Error {
|
||||
use ruma::api::client::error::ErrorKind::{FeatureDisabled, Unknown};
|
||||
|
||||
match self {
|
||||
| Self::Federation { error, .. } => response::ruma_error_kind(error).clone(),
|
||||
| Self::Ruma { source, .. } => response::ruma_error_kind(source).clone(),
|
||||
| Self::BadRequest { kind, .. } | Self::Request { kind, .. } => kind.clone(),
|
||||
| Self::FeatureDisabled { .. } => FeatureDisabled,
|
||||
| Self::Federation(_, error) | Self::Ruma(error) =>
|
||||
response::ruma_error_kind(error).clone(),
|
||||
| Self::BadRequest(kind, ..) | Self::Request(kind, ..) => kind.clone(),
|
||||
| Self::FeatureDisabled(..) => FeatureDisabled,
|
||||
| _ => Unknown,
|
||||
}
|
||||
}
|
||||
@@ -424,15 +184,13 @@ impl Error {
|
||||
use http::StatusCode;
|
||||
|
||||
match self {
|
||||
| Self::Federation { error, .. } => error.status_code,
|
||||
| Self::Ruma { source, .. } => source.status_code,
|
||||
| Self::Request { kind, code, .. } => response::status_code(kind, *code),
|
||||
| Self::BadRequest { kind, .. } => response::bad_request_code(kind),
|
||||
| Self::FeatureDisabled { .. } => response::bad_request_code(&self.kind()),
|
||||
| Self::Reqwest { source, .. } =>
|
||||
source.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
| Self::Conflict { .. } => StatusCode::CONFLICT,
|
||||
| Self::Io { source, .. } => response::io_error_code(source.kind()),
|
||||
| Self::Federation(_, error) | Self::Ruma(error) => error.status_code,
|
||||
| Self::Request(kind, _, code) => response::status_code(kind, *code),
|
||||
| Self::BadRequest(kind, ..) => response::bad_request_code(kind),
|
||||
| Self::FeatureDisabled(..) => response::bad_request_code(&self.kind()),
|
||||
| Self::Reqwest(error) => error.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
| Self::Conflict(_) => StatusCode::CONFLICT,
|
||||
| Self::Io(error) => response::io_error_code(error.kind()),
|
||||
| _ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
@@ -445,46 +203,16 @@ impl Error {
|
||||
pub fn is_not_found(&self) -> bool { self.status_code() == http::StatusCode::NOT_FOUND }
|
||||
}
|
||||
|
||||
// Debug is already derived by Snafu
|
||||
|
||||
/// Macro to reduce boilerplate for From implementations using Snafu context
|
||||
macro_rules! impl_from_snafu {
|
||||
($source_ty:ty => $context:ident) => {
|
||||
impl From<$source_ty> for Error {
|
||||
fn from(source: $source_ty) -> Self { $context.into_error(source) }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for From impls that format messages into ErrSnafu or other
|
||||
/// message-based contexts
|
||||
macro_rules! impl_from_message {
|
||||
($source_ty:ty => $context:ident, $msg:expr) => {
|
||||
impl From<$source_ty> for Error {
|
||||
fn from(source: $source_ty) -> Self {
|
||||
let message: Cow<'static, str> = format!($msg, source).into();
|
||||
$context { message }.build()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for From impls with constant messages (no formatting)
|
||||
macro_rules! impl_from_const_message {
|
||||
($source_ty:ty => $context:ident, $msg:expr) => {
|
||||
impl From<$source_ty> for Error {
|
||||
fn from(_source: $source_ty) -> Self {
|
||||
let message: Cow<'static, str> = $msg.into();
|
||||
$context { message }.build()
|
||||
}
|
||||
}
|
||||
};
|
||||
impl std::fmt::Debug for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<PoisonError<T>> for Error {
|
||||
#[cold]
|
||||
#[inline(never)]
|
||||
fn from(e: PoisonError<T>) -> Self { PoisonSnafu { message: e.to_string() }.build() }
|
||||
fn from(e: PoisonError<T>) -> Self { Self::Poison(e.to_string().into()) }
|
||||
}
|
||||
|
||||
#[allow(clippy::fallible_impl_from)]
|
||||
@@ -496,43 +224,6 @@ impl From<Infallible> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
// Implementations using the macro
|
||||
impl_from_snafu!(std::io::Error => IoSnafu);
|
||||
impl_from_snafu!(std::string::FromUtf8Error => FromUtf8Snafu);
|
||||
impl_from_snafu!(regex::Error => RegexSnafu);
|
||||
impl_from_snafu!(ruma::http_headers::ContentDispositionParseError => ContentDispositionSnafu);
|
||||
impl_from_snafu!(ruma::api::error::IntoHttpError => IntoHttpSnafu);
|
||||
impl_from_snafu!(ruma::JsTryFromIntError => JsTryFromIntSnafu);
|
||||
impl_from_snafu!(ruma::CanonicalJsonError => CanonicalJsonSnafu);
|
||||
impl_from_snafu!(axum::extract::rejection::PathRejection => PathSnafu);
|
||||
impl_from_snafu!(clap::error::Error => ClapSnafu);
|
||||
impl_from_snafu!(ruma::MxcUriError => MxcSnafu);
|
||||
impl_from_snafu!(serde_saphyr::ser_error::Error => YamlSerSnafu);
|
||||
impl_from_snafu!(toml::de::Error => TomlDeSnafu);
|
||||
impl_from_snafu!(http::header::InvalidHeaderValue => HttpHeaderSnafu);
|
||||
impl_from_snafu!(serde_json::Error => JsonSnafu);
|
||||
|
||||
// Custom implementations using message formatting
|
||||
impl_from_const_message!(std::fmt::Error => ErrSnafu, "formatting error");
|
||||
impl_from_message!(std::str::Utf8Error => ErrSnafu, "UTF-8 error: {}");
|
||||
impl_from_message!(std::num::TryFromIntError => ArithmeticSnafu, "integer conversion error: {}");
|
||||
impl_from_message!(tracing_subscriber::reload::Error => ErrSnafu, "tracing reload error: {}");
|
||||
impl_from_message!(reqwest::Error => ErrSnafu, "HTTP client error: {}");
|
||||
impl_from_message!(ruma::signatures::Error => ErrSnafu, "Signature error: {}");
|
||||
impl_from_message!(ruma::IdParseError => ErrSnafu, "ID parse error: {}");
|
||||
impl_from_message!(std::num::ParseIntError => ErrSnafu, "Integer parse error: {}");
|
||||
impl_from_message!(std::array::TryFromSliceError => ErrSnafu, "Slice conversion error: {}");
|
||||
impl_from_message!(tokio::task::JoinError => ErrSnafu, "Task join error: {}");
|
||||
impl_from_message!(serde_saphyr::Error => ErrSnafu, "YAML error: {}");
|
||||
|
||||
// Generic implementation for CapacityError
|
||||
impl<T> From<arrayvec::CapacityError<T>> for Error {
|
||||
fn from(_source: arrayvec::CapacityError<T>) -> Self {
|
||||
let message: Cow<'static, str> = "capacity error: buffer is full".into();
|
||||
ErrSnafu { message }.build()
|
||||
}
|
||||
}
|
||||
|
||||
#[cold]
|
||||
#[inline(never)]
|
||||
pub fn infallible(_e: &Infallible) {
|
||||
|
||||
@@ -15,16 +15,13 @@ impl Error {
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn from_panic(e: Box<dyn Any + Send>) -> Self {
|
||||
use super::PanicSnafu;
|
||||
PanicSnafu { message: debug::panic_str(&e), panic: e }.build()
|
||||
}
|
||||
pub fn from_panic(e: Box<dyn Any + Send>) -> Self { Self::Panic(debug::panic_str(&e), e) }
|
||||
|
||||
#[inline]
|
||||
pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
|
||||
match self {
|
||||
| Self::Panic { panic, .. } | Self::PanicAny { panic, .. } => panic,
|
||||
| Self::JoinError { source, .. } => source.into_panic(),
|
||||
| Self::Panic(_, e) | Self::PanicAny(e) => e,
|
||||
| Self::JoinError(e) => e.into_panic(),
|
||||
| _ => Box::new(self),
|
||||
}
|
||||
}
|
||||
@@ -40,8 +37,8 @@ impl Error {
|
||||
#[inline]
|
||||
pub fn is_panic(&self) -> bool {
|
||||
match &self {
|
||||
| Self::Panic { .. } | Self::PanicAny { .. } => true,
|
||||
| Self::JoinError { source, .. } => source.is_panic(),
|
||||
| Self::Panic(..) | Self::PanicAny(..) => true,
|
||||
| Self::JoinError(e) => e.is_panic(),
|
||||
| _ => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,8 +47,8 @@ impl axum::response::IntoResponse for Error {
|
||||
impl From<Error> for UiaaResponse {
|
||||
#[inline]
|
||||
fn from(error: Error) -> Self {
|
||||
if let Error::Uiaa { info, .. } = error {
|
||||
return Self::AuthResponse(info);
|
||||
if let Error::Uiaa(uiaainfo) = error {
|
||||
return Self::AuthResponse(uiaainfo);
|
||||
}
|
||||
|
||||
let body = ErrorBody::Standard {
|
||||
|
||||
@@ -5,15 +5,9 @@ use serde::{de, ser};
|
||||
use crate::Error;
|
||||
|
||||
impl de::Error for Error {
|
||||
fn custom<T: Display + ToString>(msg: T) -> Self {
|
||||
let message: std::borrow::Cow<'static, str> = msg.to_string().into();
|
||||
super::SerdeDeSnafu { message }.build()
|
||||
}
|
||||
fn custom<T: Display + ToString>(msg: T) -> Self { Self::SerdeDe(msg.to_string().into()) }
|
||||
}
|
||||
|
||||
impl ser::Error for Error {
|
||||
fn custom<T: Display + ToString>(msg: T) -> Self {
|
||||
let message: std::borrow::Cow<'static, str> = msg.to_string().into();
|
||||
super::SerdeSerSnafu { message }.build()
|
||||
}
|
||||
fn custom<T: Display + ToString>(msg: T) -> Self { Self::SerdeSer(msg.to_string().into()) }
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use ruma::{RoomVersionId, canonical_json::redact_content_in_place};
|
||||
use serde_json::{Value as JsonValue, json, value::to_raw_value};
|
||||
|
||||
use crate::{Result, err, implement};
|
||||
use crate::{Error, Result, err, implement};
|
||||
|
||||
#[implement(super::Pdu)]
|
||||
pub fn redact(&mut self, room_version_id: &RoomVersionId, reason: JsonValue) -> Result {
|
||||
@@ -10,15 +10,8 @@ pub fn redact(&mut self, room_version_id: &RoomVersionId, reason: JsonValue) ->
|
||||
let mut content = serde_json::from_str(self.content.get())
|
||||
.map_err(|e| err!(Request(BadJson("Failed to deserialize content into type: {e}"))))?;
|
||||
|
||||
redact_content_in_place(&mut content, room_version_id, self.kind.to_string()).map_err(
|
||||
|error| {
|
||||
crate::error::RedactionSnafu {
|
||||
server: self.sender.server_name().to_owned(),
|
||||
error,
|
||||
}
|
||||
.build()
|
||||
},
|
||||
)?;
|
||||
redact_content_in_place(&mut content, room_version_id, self.kind.to_string())
|
||||
.map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?;
|
||||
|
||||
let reason = serde_json::to_value(reason).expect("Failed to preserialize reason");
|
||||
|
||||
|
||||
@@ -1,554 +0,0 @@
|
||||
#[cfg(conduwuit_bench)]
|
||||
extern crate test;
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::{HashMap, HashSet},
|
||||
sync::atomic::{AtomicU64, Ordering::SeqCst},
|
||||
};
|
||||
|
||||
use futures::{future, future::ready};
|
||||
use maplit::{btreemap, hashmap, hashset};
|
||||
use ruma::{
|
||||
EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, RoomVersionId, Signatures, UserId,
|
||||
events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::{
|
||||
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
member::{MembershipState, RoomMemberEventContent},
|
||||
},
|
||||
},
|
||||
int, room_id, uint, user_id,
|
||||
};
|
||||
use serde_json::{
|
||||
json,
|
||||
value::{RawValue as RawJsonValue, to_raw_value as to_raw_json_value},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
matrix::{Event, Pdu, pdu::EventHash},
|
||||
state_res::{self as state_res, Error, Result, StateMap, error::NotFoundSnafu},
|
||||
};
|
||||
|
||||
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[cfg(conduwuit_bench)]
|
||||
#[cfg_attr(conduwuit_bench, bench)]
|
||||
fn lexico_topo_sort(c: &mut test::Bencher) {
|
||||
let graph = hashmap! {
|
||||
event_id("l") => hashset![event_id("o")],
|
||||
event_id("m") => hashset![event_id("n"), event_id("o")],
|
||||
event_id("n") => hashset![event_id("o")],
|
||||
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
|
||||
event_id("p") => hashset![event_id("o")],
|
||||
};
|
||||
|
||||
c.iter(|| {
|
||||
let _ = state_res::lexicographical_topological_sort(&graph, &|_| {
|
||||
future::ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(conduwuit_bench)]
|
||||
#[cfg_attr(conduwuit_bench, bench)]
|
||||
fn resolution_shallow_auth_chain(c: &mut test::Bencher) {
|
||||
let mut store = TestStore(hashmap! {});
|
||||
|
||||
// build up the DAG
|
||||
let (state_at_bob, state_at_charlie, _) = store.set_up();
|
||||
|
||||
c.iter(|| async {
|
||||
let ev_map = store.0.clone();
|
||||
let state_sets = [&state_at_bob, &state_at_charlie];
|
||||
let fetch = |id: OwnedEventId| ready(ev_map.get(&id).map(ToOwned::to_owned));
|
||||
let exists = |id: OwnedEventId| ready(ev_map.get(&id).is_some());
|
||||
let auth_chain_sets: Vec<HashSet<_>> = state_sets
|
||||
.iter()
|
||||
.map(|map| {
|
||||
store
|
||||
.auth_event_ids(room_id(), map.values().cloned().collect())
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let _ = match state_res::resolve(
|
||||
&RoomVersionId::V6,
|
||||
state_sets.into_iter(),
|
||||
&auth_chain_sets,
|
||||
&fetch,
|
||||
&exists,
|
||||
)
|
||||
.await
|
||||
{
|
||||
| Ok(state) => state,
|
||||
| Err(e) => panic!("{e}"),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(conduwuit_bench)]
|
||||
#[cfg_attr(conduwuit_bench, bench)]
|
||||
fn resolve_deeper_event_set(c: &mut test::Bencher) {
|
||||
let mut inner = INITIAL_EVENTS();
|
||||
let ban = BAN_STATE_SET();
|
||||
|
||||
inner.extend(ban);
|
||||
let store = TestStore(inner.clone());
|
||||
|
||||
let state_set_a = [
|
||||
inner.get(&event_id("CREATE")).unwrap(),
|
||||
inner.get(&event_id("IJR")).unwrap(),
|
||||
inner.get(&event_id("IMA")).unwrap(),
|
||||
inner.get(&event_id("IMB")).unwrap(),
|
||||
inner.get(&event_id("IMC")).unwrap(),
|
||||
inner.get(&event_id("MB")).unwrap(),
|
||||
inner.get(&event_id("PA")).unwrap(),
|
||||
]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let state_set_b = [
|
||||
inner.get(&event_id("CREATE")).unwrap(),
|
||||
inner.get(&event_id("IJR")).unwrap(),
|
||||
inner.get(&event_id("IMA")).unwrap(),
|
||||
inner.get(&event_id("IMB")).unwrap(),
|
||||
inner.get(&event_id("IMC")).unwrap(),
|
||||
inner.get(&event_id("IME")).unwrap(),
|
||||
inner.get(&event_id("PA")).unwrap(),
|
||||
]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
c.iter(|| async {
|
||||
let state_sets = [&state_set_a, &state_set_b];
|
||||
let auth_chain_sets: Vec<HashSet<_>> = state_sets
|
||||
.iter()
|
||||
.map(|map| {
|
||||
store
|
||||
.auth_event_ids(room_id(), map.values().cloned().collect())
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let fetch = |id: OwnedEventId| ready(inner.get(&id).map(ToOwned::to_owned));
|
||||
let exists = |id: OwnedEventId| ready(inner.get(&id).is_some());
|
||||
let _ = match state_res::resolve(
|
||||
&RoomVersionId::V6,
|
||||
state_sets.into_iter(),
|
||||
&auth_chain_sets,
|
||||
&fetch,
|
||||
&exists,
|
||||
)
|
||||
.await
|
||||
{
|
||||
| Ok(state) => state,
|
||||
| Err(_) => panic!("resolution failed during benchmarking"),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
//*/////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPLEMENTATION DETAILS AHEAD
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////*/
|
||||
struct TestStore<E: Event>(HashMap<OwnedEventId, E>);
|
||||
|
||||
#[allow(unused)]
|
||||
impl<E: Event + Clone> TestStore<E> {
|
||||
fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result<E> {
|
||||
self.0.get(event_id).cloned().ok_or_else(|| {
|
||||
NotFoundSnafu {
|
||||
message: format!("{} not found", event_id),
|
||||
}
|
||||
.build()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the events that correspond to the `event_ids` sorted in the same
|
||||
/// order.
|
||||
fn get_events(&self, room_id: &RoomId, event_ids: &[OwnedEventId]) -> Result<Vec<E>> {
|
||||
let mut events = vec![];
|
||||
for id in event_ids {
|
||||
events.push(self.get_event(room_id, id)?);
|
||||
}
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Returns a Vec of the related auth events to the given `event`.
|
||||
fn auth_event_ids(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<OwnedEventId>,
|
||||
) -> Result<HashSet<OwnedEventId>> {
|
||||
let mut result = HashSet::new();
|
||||
let mut stack = event_ids;
|
||||
|
||||
// DFS for auth event chain
|
||||
while !stack.is_empty() {
|
||||
let ev_id = stack.pop().unwrap();
|
||||
if result.contains(&ev_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
result.insert(ev_id.clone());
|
||||
|
||||
let event = self.get_event(room_id, ev_id.borrow())?;
|
||||
|
||||
stack.extend(event.auth_events().map(ToOwned::to_owned));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a vector representing the difference in auth chains of the given
|
||||
/// `events`.
|
||||
fn auth_chain_diff(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<Vec<OwnedEventId>>,
|
||||
) -> Result<Vec<OwnedEventId>> {
|
||||
let mut auth_chain_sets = vec![];
|
||||
for ids in event_ids {
|
||||
// TODO state store `auth_event_ids` returns self in the event ids list
|
||||
// when an event returns `auth_event_ids` self is not contained
|
||||
let chain = self
|
||||
.auth_event_ids(room_id, ids)?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
auth_chain_sets.push(chain);
|
||||
}
|
||||
|
||||
if let Some(first) = auth_chain_sets.first().cloned() {
|
||||
let common = auth_chain_sets
|
||||
.iter()
|
||||
.skip(1)
|
||||
.fold(first, |a, b| a.intersection(b).cloned().collect::<HashSet<_>>());
|
||||
|
||||
Ok(auth_chain_sets
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter(|id| !common.contains(id))
|
||||
.collect())
|
||||
} else {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TestStore<Pdu> {
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn set_up(
|
||||
&mut self,
|
||||
) -> (StateMap<OwnedEventId>, StateMap<OwnedEventId>, StateMap<OwnedEventId>) {
|
||||
let create_event = to_pdu_event::<&EventId>(
|
||||
"CREATE",
|
||||
alice(),
|
||||
TimelineEventType::RoomCreate,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "creator": alice() })).unwrap(),
|
||||
&[],
|
||||
&[],
|
||||
);
|
||||
let cre = create_event.event_id().to_owned();
|
||||
self.0.insert(cre.clone(), create_event.clone());
|
||||
|
||||
let alice_mem = to_pdu_event(
|
||||
"IMA",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(alice().to_string().as_str()),
|
||||
member_content_join(),
|
||||
&[cre.clone()],
|
||||
&[cre.clone()],
|
||||
);
|
||||
self.0
|
||||
.insert(alice_mem.event_id().to_owned(), alice_mem.clone());
|
||||
|
||||
let join_rules = to_pdu_event(
|
||||
"IJR",
|
||||
alice(),
|
||||
TimelineEventType::RoomJoinRules,
|
||||
Some(""),
|
||||
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(),
|
||||
&[cre.clone(), alice_mem.event_id().to_owned()],
|
||||
&[alice_mem.event_id().to_owned()],
|
||||
);
|
||||
self.0
|
||||
.insert(join_rules.event_id().to_owned(), join_rules.clone());
|
||||
|
||||
// Bob and Charlie join at the same time, so there is a fork
|
||||
// this will be represented in the state_sets when we resolve
|
||||
let bob_mem = to_pdu_event(
|
||||
"IMB",
|
||||
bob(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(bob().to_string().as_str()),
|
||||
member_content_join(),
|
||||
&[cre.clone(), join_rules.event_id().to_owned()],
|
||||
&[join_rules.event_id().to_owned()],
|
||||
);
|
||||
self.0
|
||||
.insert(bob_mem.event_id().to_owned(), bob_mem.clone());
|
||||
|
||||
let charlie_mem = to_pdu_event(
|
||||
"IMC",
|
||||
charlie(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(charlie().to_string().as_str()),
|
||||
member_content_join(),
|
||||
&[cre, join_rules.event_id().to_owned()],
|
||||
&[join_rules.event_id().to_owned()],
|
||||
);
|
||||
self.0
|
||||
.insert(charlie_mem.event_id().to_owned(), charlie_mem.clone());
|
||||
|
||||
let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
(ev.event_type().clone().into(), ev.state_key().unwrap().into()),
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
(state_at_bob, state_at_charlie, expected)
|
||||
}
|
||||
}
|
||||
|
||||
fn event_id(id: &str) -> OwnedEventId {
|
||||
if id.contains('$') {
|
||||
return id.try_into().unwrap();
|
||||
}
|
||||
format!("${}:foo", id).try_into().unwrap()
|
||||
}
|
||||
|
||||
fn alice() -> &'static UserId { user_id!("@alice:foo") }
|
||||
|
||||
fn bob() -> &'static UserId { user_id!("@bob:foo") }
|
||||
|
||||
fn charlie() -> &'static UserId { user_id!("@charlie:foo") }
|
||||
|
||||
fn ella() -> &'static UserId { user_id!("@ella:foo") }
|
||||
|
||||
fn room_id() -> &'static RoomId { room_id!("!test:foo") }
|
||||
|
||||
fn member_content_ban() -> Box<RawJsonValue> {
|
||||
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Ban)).unwrap()
|
||||
}
|
||||
|
||||
fn member_content_join() -> Box<RawJsonValue> {
|
||||
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap()
|
||||
}
|
||||
|
||||
fn to_pdu_event<S>(
|
||||
id: &str,
|
||||
sender: &UserId,
|
||||
ev_type: TimelineEventType,
|
||||
state_key: Option<&str>,
|
||||
content: Box<RawJsonValue>,
|
||||
auth_events: &[S],
|
||||
prev_events: &[S],
|
||||
) -> Pdu
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
// We don't care if the addition happens in order just that it is atomic
|
||||
// (each event has its own value)
|
||||
let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst);
|
||||
let id = if id.contains('$') {
|
||||
id.to_owned()
|
||||
} else {
|
||||
format!("${}:foo", id)
|
||||
};
|
||||
let auth_events = auth_events
|
||||
.iter()
|
||||
.map(AsRef::as_ref)
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
let prev_events = prev_events
|
||||
.iter()
|
||||
.map(AsRef::as_ref)
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Pdu {
|
||||
event_id: id.try_into().unwrap(),
|
||||
room_id: Some(room_id().to_owned()),
|
||||
sender: sender.to_owned(),
|
||||
origin_server_ts: ts.try_into().unwrap(),
|
||||
state_key: state_key.map(Into::into),
|
||||
kind: ev_type,
|
||||
content,
|
||||
origin: None,
|
||||
redacts: None,
|
||||
unsigned: None,
|
||||
auth_events,
|
||||
prev_events,
|
||||
depth: uint!(0),
|
||||
hashes: EventHash { sha256: String::new() },
|
||||
signatures: None,
|
||||
}
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn INITIAL_EVENTS() -> HashMap<OwnedEventId, Pdu> {
|
||||
vec![
|
||||
to_pdu_event::<&EventId>(
|
||||
"CREATE",
|
||||
alice(),
|
||||
TimelineEventType::RoomCreate,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "creator": alice() })).unwrap(),
|
||||
&[],
|
||||
&[],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IMA",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(alice().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE"],
|
||||
&["CREATE"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IPOWER",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100 } })).unwrap(),
|
||||
&["CREATE", "IMA"],
|
||||
&["IMA"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IJR",
|
||||
alice(),
|
||||
TimelineEventType::RoomJoinRules,
|
||||
Some(""),
|
||||
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"],
|
||||
&["IPOWER"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IMB",
|
||||
bob(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(bob().to_string().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE", "IJR", "IPOWER"],
|
||||
&["IJR"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IMC",
|
||||
charlie(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(charlie().to_string().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE", "IJR", "IPOWER"],
|
||||
&["IMB"],
|
||||
),
|
||||
to_pdu_event::<&EventId>(
|
||||
"START",
|
||||
charlie(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
&[],
|
||||
&[],
|
||||
),
|
||||
to_pdu_event::<&EventId>(
|
||||
"END",
|
||||
charlie(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
&[],
|
||||
&[],
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|ev| (ev.event_id().to_owned(), ev))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// all graphs start with these input events
|
||||
#[allow(non_snake_case)]
|
||||
fn BAN_STATE_SET() -> HashMap<OwnedEventId, Pdu> {
|
||||
vec![
|
||||
to_pdu_event(
|
||||
"PA",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"], // auth_events
|
||||
&["START"], // prev_events
|
||||
),
|
||||
to_pdu_event(
|
||||
"PB",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"],
|
||||
&["END"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"MB",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().as_str()),
|
||||
member_content_ban(),
|
||||
&["CREATE", "IMA", "PB"],
|
||||
&["PA"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IME",
|
||||
ella(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE", "IJR", "PA"],
|
||||
&["MB"],
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|ev| (ev.event_id().to_owned(), ev))
|
||||
.collect()
|
||||
}
|
||||
@@ -1,40 +1,42 @@
|
||||
use ruma::OwnedEventId;
|
||||
use serde_json::Error as JsonError;
|
||||
use snafu::{IntoError, prelude::*};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Represents the various errors that arise when resolving state.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum Error {
|
||||
/// A deserialization error.
|
||||
#[snafu(display("JSON error: {source}"))]
|
||||
SerdeJson {
|
||||
source: JsonError,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error(transparent)]
|
||||
SerdeJson(#[from] JsonError),
|
||||
|
||||
/// The given option or version is unsupported.
|
||||
#[snafu(display("Unsupported room version: {version}"))]
|
||||
Unsupported {
|
||||
version: String,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error("Unsupported room version: {0}")]
|
||||
Unsupported(String),
|
||||
|
||||
/// The given event was not found.
|
||||
#[snafu(display("Not found error: {message}"))]
|
||||
NotFound {
|
||||
message: String,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
#[error("Event not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
/// A required event this event depended on could not be fetched,
|
||||
/// either as it was missing, or because it was invalid
|
||||
#[error("Failed to fetch required {0} event: {1}")]
|
||||
DependencyFailed(OwnedEventId, String),
|
||||
|
||||
/// Invalid fields in the given PDU.
|
||||
#[snafu(display("Invalid PDU: {message}"))]
|
||||
InvalidPdu {
|
||||
message: String,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
}
|
||||
#[error("Invalid PDU: {0}")]
|
||||
InvalidPdu(String),
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(source: serde_json::Error) -> Self { SerdeJsonSnafu.into_error(source) }
|
||||
/// This event failed an authorization condition.
|
||||
#[error("Auth check failed: {0}")]
|
||||
AuthConditionFailed(String),
|
||||
|
||||
/// This event contained multiple auth events of the same type and state
|
||||
/// key.
|
||||
#[error("Duplicate auth events: {0}")]
|
||||
DuplicateAuthEvents(String),
|
||||
|
||||
/// This event contains unnecessary auth events.
|
||||
#[error("Unknown or unnecessary auth events present: {0}")]
|
||||
UnselectedAuthEvents(String),
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,238 @@
|
||||
//! Auth checks relevant to any event's `auth_events`.
|
||||
//!
|
||||
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use ruma::{
|
||||
EventId, OwnedEventId, RoomId, UserId,
|
||||
events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::member::{MembershipState, RoomMemberEventContent, ThirdPartyInvite},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error, warn};
|
||||
|
||||
/// For the given event `kind` what are the relevant auth events that are needed
|
||||
/// to authenticate this `content`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function will return an error if the supplied `content` is not a JSON
|
||||
/// object.
|
||||
pub fn auth_types_for_event(
|
||||
room_version: &RoomVersion,
|
||||
event_type: &TimelineEventType,
|
||||
state_key: Option<&StateKey>,
|
||||
sender: &UserId,
|
||||
member_content: Option<RoomMemberEventContent>,
|
||||
) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
|
||||
if event_type == &TimelineEventType::RoomCreate {
|
||||
// Create events never have auth events
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let mut auth_types = if room_version.room_ids_as_hashes {
|
||||
vec![
|
||||
StateEventType::RoomMember.with_state_key(sender.as_str()),
|
||||
StateEventType::RoomPowerLevels.with_state_key(""),
|
||||
]
|
||||
} else {
|
||||
// For room versions that do not use room IDs as hashes, include the
|
||||
// RoomCreate event as an auth event.
|
||||
vec![
|
||||
StateEventType::RoomMember.with_state_key(sender.as_str()),
|
||||
StateEventType::RoomPowerLevels.with_state_key(""),
|
||||
StateEventType::RoomCreate.with_state_key(""),
|
||||
]
|
||||
};
|
||||
|
||||
if event_type == &TimelineEventType::RoomMember {
|
||||
let member_content =
|
||||
member_content.expect("member_content must be provided for RoomMember events");
|
||||
|
||||
// Include the target's membership (if available)
|
||||
auth_types.push((
|
||||
StateEventType::RoomMember,
|
||||
state_key
|
||||
.expect("state_key must be provided for RoomMember events")
|
||||
.to_owned(),
|
||||
));
|
||||
|
||||
if matches!(
|
||||
member_content.membership,
|
||||
MembershipState::Join | MembershipState::Invite | MembershipState::Knock
|
||||
) {
|
||||
// Include the join rules
|
||||
auth_types.push(StateEventType::RoomJoinRules.with_state_key(""));
|
||||
}
|
||||
|
||||
if matches!(member_content.membership, MembershipState::Invite) {
|
||||
// If this is an invite, include the third party invite if it exists
|
||||
if let Some(ThirdPartyInvite { signed, .. }) = member_content.third_party_invite {
|
||||
auth_types
|
||||
.push(StateEventType::RoomThirdPartyInvite.with_state_key(signed.token));
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(member_content.membership, MembershipState::Join)
|
||||
&& room_version.restricted_join_rules
|
||||
{
|
||||
// If this is a restricted join, include the authorizing user's membership
|
||||
if let Some(authorizing_user) = member_content.join_authorized_via_users_server {
|
||||
auth_types
|
||||
.push(StateEventType::RoomMember.with_state_key(authorizing_user.as_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(auth_types)
|
||||
}
|
||||
|
||||
/// Checks for duplicate auth events in the `auth_events` field of an event.
|
||||
/// Note: the caller should already have all of the auth events fetched.
|
||||
///
|
||||
/// If there are multiple auth events of the same type and state key, this
|
||||
/// returns an error. Otherwise, it returns a map of (type, state_key) to the
|
||||
/// corresponding auth event.
|
||||
pub async fn check_duplicate_auth_events<FE>(
|
||||
auth_events: &[OwnedEventId],
|
||||
fetch_event: FE,
|
||||
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
|
||||
where
|
||||
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
let mut seen: HashMap<(StateEventType, StateKey), Pdu> = HashMap::new();
|
||||
|
||||
// Considering all of the event's auth events:
|
||||
for auth_event_id in auth_events {
|
||||
if let Ok(Some(auth_event)) = fetch_event(auth_event_id).await {
|
||||
let event_type = auth_event.kind();
|
||||
// If this is not a state event, reject it.
|
||||
let Some(state_key) = &auth_event.state_key() else {
|
||||
return Err(Error::InvalidPdu(format!(
|
||||
"Auth event {:?} is not a state event",
|
||||
auth_event_id
|
||||
)));
|
||||
};
|
||||
let type_key_pair: (StateEventType, StateKey) =
|
||||
event_type.clone().with_state_key(state_key.clone());
|
||||
|
||||
// If there are duplicate entries for a given type and state_key pair, reject.
|
||||
if seen.contains_key(&type_key_pair) {
|
||||
return Err(Error::DuplicateAuthEvents(format!(
|
||||
"({:?},\"{:?}\")",
|
||||
event_type, state_key
|
||||
)));
|
||||
}
|
||||
seen.insert(type_key_pair, auth_event);
|
||||
} else {
|
||||
return Err(Error::NotFound(auth_event_id.as_str().to_owned()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(seen)
|
||||
}
|
||||
|
||||
// Checks that the event does not refer to any auth events that it does not need
|
||||
// to.
|
||||
pub fn check_unnecessary_auth_events(
|
||||
auth_events: &HashSet<(StateEventType, StateKey)>,
|
||||
expected: &Vec<(StateEventType, StateKey)>,
|
||||
) -> Result<(), Error> {
|
||||
// If there are entries whose type and state_key don't match those specified by
|
||||
// the auth events selection algorithm described in the server specification,
|
||||
// reject.
|
||||
let remaining = auth_events
|
||||
.iter()
|
||||
.filter(|key| !expected.contains(key))
|
||||
.collect::<HashSet<_>>();
|
||||
if !remaining.is_empty() {
|
||||
return Err(Error::UnselectedAuthEvents(format!("{:?}", remaining)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that all provided auth events were not rejected previously.
|
||||
//
|
||||
// TODO: this is currently a no-op and always returns Ok(()).
|
||||
pub fn check_all_auth_events_accepted(
|
||||
_auth_events: &HashMap<(StateEventType, StateKey), Pdu>,
|
||||
) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that all auth events are from the same room as the event being
|
||||
// validated.
|
||||
pub fn check_auth_same_room(auth_events: &Vec<Pdu>, room_id: &RoomId) -> bool {
|
||||
for auth_event in auth_events {
|
||||
if let Some(auth_room_id) = &auth_event.room_id() {
|
||||
if auth_room_id.as_str() != room_id.as_str() {
|
||||
warn!(
|
||||
auth_event_id=%auth_event.event_id(),
|
||||
"Auth event room id {} does not match expected room id {}",
|
||||
auth_room_id,
|
||||
room_id
|
||||
);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
warn!(auth_event_id=%auth_event.event_id(), "Auth event has no room_id");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Performs all auth event checks for the given event.
|
||||
pub async fn check_auth_events<FE>(
|
||||
event: &Pdu,
|
||||
room_id: &RoomId,
|
||||
room_version: &RoomVersion,
|
||||
fetch_event: &FE,
|
||||
) -> Result<HashMap<(StateEventType, StateKey), Pdu>, Error>
|
||||
where
|
||||
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
// If there are duplicate entries for a given type and state_key pair, reject.
|
||||
let auth_events_map = check_duplicate_auth_events(&event.auth_events, fetch_event).await?;
|
||||
let auth_events_set: HashSet<(StateEventType, StateKey)> =
|
||||
auth_events_map.keys().cloned().collect();
|
||||
|
||||
// If there are entries whose type and state_key don’t match those specified by
|
||||
// the auth events selection algorithm described in the server specification,
|
||||
// reject.
|
||||
let member_event_content = match event.kind() {
|
||||
| TimelineEventType::RoomMember =>
|
||||
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
|
||||
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
|
||||
})?),
|
||||
| _ => None,
|
||||
};
|
||||
let expected_auth_events = auth_types_for_event(
|
||||
room_version,
|
||||
event.kind(),
|
||||
event.state_key.as_ref(),
|
||||
event.sender(),
|
||||
member_event_content,
|
||||
)?;
|
||||
if let Err(e) = check_unnecessary_auth_events(&auth_events_set, &expected_auth_events) {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// If there are entries which were themselves rejected under the checks
|
||||
// performed on receipt of a PDU, reject.
|
||||
if let Err(e) = check_all_auth_events_accepted(&auth_events_map) {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// If any event in auth_events has a room_id which does not match that of the
|
||||
// event being authorised, reject.
|
||||
let auth_event_refs: Vec<Pdu> = auth_events_map.values().cloned().collect();
|
||||
if !check_auth_same_room(&auth_event_refs, room_id) {
|
||||
return Err(Error::InvalidPdu(
|
||||
"One or more auth events are from a different room".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(auth_events_map)
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
//! Context for event authorisation checks
|
||||
|
||||
use ruma::{
|
||||
Int, OwnedUserId, UserId,
|
||||
events::{
|
||||
StateEventType,
|
||||
room::{create::RoomCreateEventContent, power_levels::RoomPowerLevelsEventContent},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{Event, EventTypeExt, Pdu, RoomVersion, matrix::StateKey, state_res::Error};
|
||||
|
||||
pub enum UserPower {
|
||||
/// Creator indicates this user should be granted a power level above all.
|
||||
Creator,
|
||||
/// Standard indicates power levels should be used to determine rank.
|
||||
Standard,
|
||||
}
|
||||
|
||||
impl PartialEq for UserPower {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
| (UserPower::Creator, UserPower::Creator) => true,
|
||||
| (UserPower::Standard, UserPower::Standard) => true,
|
||||
| _ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the creators of the room.
|
||||
/// If this room only supports one creator, a vec of one will be returned.
|
||||
/// If multiple creators are supported, all will be returned, with the
|
||||
/// m.room.create sender first.
|
||||
pub async fn calculate_creators<FS>(
|
||||
room_version: &RoomVersion,
|
||||
fetch_state: FS,
|
||||
) -> Result<Vec<OwnedUserId>, Error>
|
||||
where
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
let create_event = fetch_state(StateEventType::RoomCreate.with_state_key(""))
|
||||
.await?
|
||||
.ok_or_else(|| Error::InvalidPdu("Room create event not found".to_owned()))?;
|
||||
let content = create_event
|
||||
.get_content::<RoomCreateEventContent>()
|
||||
.map_err(|e| {
|
||||
Error::InvalidPdu(format!("Room create event has invalid content: {}", e))
|
||||
})?;
|
||||
|
||||
if room_version.explicitly_privilege_room_creators {
|
||||
let mut creators = vec![create_event.sender().to_owned()];
|
||||
if let Some(additional) = content.additional_creators {
|
||||
for user_id in additional {
|
||||
if !creators.contains(&user_id) {
|
||||
creators.push(user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(creators)
|
||||
} else if room_version.use_room_create_sender {
|
||||
Ok(vec![create_event.sender().to_owned()])
|
||||
} else {
|
||||
// Have to check the event content
|
||||
#[allow(deprecated)]
|
||||
if let Some(creator) = content.creator {
|
||||
Ok(vec![creator])
|
||||
} else {
|
||||
Err(Error::InvalidPdu("Room create event missing creator field".to_owned()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rank fetches the creatorship and power level of the target user
|
||||
///
|
||||
/// Returns (UserPower, power_level, Option<RoomPowerLevelsEventContent>)
|
||||
/// If UserPower::Creator is returned, the power_level and
|
||||
/// RoomPowerLevelsEventContent will be meaningless and can be ignored.
|
||||
pub async fn get_rank<FS>(
|
||||
room_version: &RoomVersion,
|
||||
fetch_state: &FS,
|
||||
user_id: &UserId,
|
||||
) -> Result<(UserPower, Int, Option<RoomPowerLevelsEventContent>), Error>
|
||||
where
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
let creators = calculate_creators(room_version, &fetch_state).await?;
|
||||
if creators.contains(&user_id.to_owned()) && room_version.explicitly_privilege_room_creators {
|
||||
return Ok((UserPower::Creator, Int::MAX, None));
|
||||
}
|
||||
|
||||
let power_levels = fetch_state(StateEventType::RoomPowerLevels.with_state_key("")).await?;
|
||||
if let Some(power_levels) = power_levels {
|
||||
let power_levels = power_levels
|
||||
.get_content::<RoomPowerLevelsEventContent>()
|
||||
.map_err(|e| {
|
||||
Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
|
||||
})?;
|
||||
Ok((
|
||||
UserPower::Standard,
|
||||
*power_levels
|
||||
.users
|
||||
.get(user_id)
|
||||
.unwrap_or(&power_levels.users_default),
|
||||
Some(power_levels),
|
||||
))
|
||||
} else {
|
||||
// No power levels event, use defaults
|
||||
if creators[0] == user_id {
|
||||
return Ok((UserPower::Creator, Int::MAX, None));
|
||||
}
|
||||
Ok((UserPower::Standard, Int::from(0), None))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
//! Auth checks relevant to the `m.room.create` event specifically.
|
||||
//!
|
||||
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
|
||||
|
||||
use ruma::{OwnedUserId, RoomVersionId, events::room::create::RoomCreateEventContent};
|
||||
use serde::Deserialize;
|
||||
use serde_json::from_str;
|
||||
|
||||
use crate::{Event, Pdu, RoomVersion, state_res::Error, trace};
|
||||
|
||||
// A raw representation of the create event content, for initial parsing.
|
||||
// This allows us to extract fields without fully validating the event first.
|
||||
#[derive(Deserialize)]
|
||||
struct RawCreateContent {
|
||||
creator: Option<String>,
|
||||
room_version: Option<String>,
|
||||
additional_creators: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// Check whether an `m.room.create` event is valid.
|
||||
// This ensures that:
|
||||
//
|
||||
// 1. The event has no `prev_events`
|
||||
// 2. If the version disallows it, the event has no `room_id` present.
|
||||
// 3. If the room version is present and recognised, otherwise assume invalid.
|
||||
// 4. If the room version supports it, `additional_creators` is populated with
|
||||
// valid user IDs.
|
||||
// 5. If the room version supports it, `creator` is populated AND is a valid
|
||||
// user ID.
|
||||
// 6. Otherwise, this event is valid.
|
||||
//
|
||||
// The fully deserialized `RoomCreateEventContent` is returned for further calls
|
||||
// to other checks.
|
||||
pub fn check_room_create(event: &Pdu) -> Result<RoomCreateEventContent, Error> {
|
||||
// Check 1: The event has no `prev_events`
|
||||
if !event.prev_events.is_empty() {
|
||||
return Err(Error::InvalidPdu("m.room.create event has prev_events".to_owned()));
|
||||
}
|
||||
|
||||
let create_content = from_str::<RawCreateContent>(event.content().get())?;
|
||||
|
||||
// Note: Here we attempt to both load the raw room version string and validate
|
||||
// it, and then cast it to the room features. If either step fails, we return
|
||||
// an unsupported error. If the room version is missing, it defaults to "1",
|
||||
// which we also do not support.
|
||||
//
|
||||
// This performs check 3, which then allows us to perform check 2.
|
||||
let room_version = if let Some(raw_room_version) = create_content.room_version {
|
||||
trace!("Parsing and interpreting room version: {}", raw_room_version);
|
||||
let room_version_id = RoomVersionId::try_from(raw_room_version.as_str())
|
||||
.map_err(|_| Error::Unsupported(raw_room_version))?;
|
||||
RoomVersion::new(&room_version_id)
|
||||
.map_err(|_| Error::Unsupported(room_version_id.as_str().to_owned()))?
|
||||
} else {
|
||||
return Err(Error::Unsupported("1".to_owned()));
|
||||
};
|
||||
|
||||
// Check 2: If the version disallows it, the event has no `room_id` present.
|
||||
if room_version.room_ids_as_hashes && event.room_id.is_some() {
|
||||
return Err(Error::InvalidPdu(
|
||||
"m.room.create event has room_id but room version disallows it".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check 4: If the room version supports it, `additional_creators` is populated
|
||||
// with valid user IDs.
|
||||
if room_version.explicitly_privilege_room_creators {
|
||||
if let Some(additional_creators) = create_content.additional_creators {
|
||||
for creator in additional_creators {
|
||||
trace!("Validating additional creator user ID: {}", creator);
|
||||
if OwnedUserId::parse(&creator).is_err() {
|
||||
return Err(Error::InvalidPdu(format!(
|
||||
"Invalid user ID in additional_creators: {creator}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check 5: If the room version supports it, `creator` is populated AND is a
|
||||
// valid user ID.
|
||||
if !room_version.use_room_create_sender {
|
||||
if let Some(creator) = create_content.creator {
|
||||
trace!("Validating creator user ID: {}", creator);
|
||||
if OwnedUserId::parse(&creator).is_err() {
|
||||
return Err(Error::InvalidPdu(format!("Invalid user ID in creator: {creator}")));
|
||||
}
|
||||
} else {
|
||||
return Err(Error::InvalidPdu(
|
||||
"m.room.create event missing creator field".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Deserialise into the full create event for future checks.
|
||||
Ok(from_str::<RoomCreateEventContent>(event.content().get())?)
|
||||
}
|
||||
@@ -0,0 +1,650 @@
|
||||
use ruma::{
|
||||
EventId, OwnedUserId, RoomVersionId,
|
||||
events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::{create::RoomCreateEventContent, member::MembershipState},
|
||||
},
|
||||
int,
|
||||
serde::Raw,
|
||||
};
|
||||
use serde::{Deserialize, de::IgnoredAny};
|
||||
use serde_json::from_str as from_json_str;
|
||||
|
||||
use crate::{
|
||||
Event, EventTypeExt, Pdu, RoomVersion, debug, error,
|
||||
matrix::StateKey,
|
||||
state_res::{
|
||||
error::Error,
|
||||
event_auth::{
|
||||
auth_events::check_auth_events,
|
||||
context::{UserPower, calculate_creators, get_rank},
|
||||
create_event::check_room_create,
|
||||
member_event::check_member_event,
|
||||
power_levels::check_power_levels,
|
||||
},
|
||||
},
|
||||
trace, warn,
|
||||
};
|
||||
|
||||
// FIXME: field extracting could be bundled for `content`
|
||||
#[derive(Deserialize)]
|
||||
struct GetMembership {
|
||||
membership: MembershipState,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct RoomMemberContentFields {
|
||||
membership: Option<Raw<MembershipState>>,
|
||||
join_authorised_via_users_server: Option<Raw<OwnedUserId>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RoomCreateContentFields {
|
||||
room_version: Option<Raw<RoomVersionId>>,
|
||||
creator: Option<Raw<IgnoredAny>>,
|
||||
additional_creators: Option<Vec<Raw<OwnedUserId>>>,
|
||||
#[serde(rename = "m.federate", default = "ruma::serde::default_true")]
|
||||
federate: bool,
|
||||
}
|
||||
|
||||
/// Authenticate the incoming `event`.
|
||||
///
|
||||
/// The steps of authentication are:
|
||||
///
|
||||
/// * check that the event is being authenticated for the correct room
|
||||
/// * then there are checks for specific event types
|
||||
///
|
||||
/// The `fetch_state` closure should gather state from a state snapshot. We need
|
||||
/// to know if the event passes auth against some state not a recursive
|
||||
/// collection of auth_events fields.
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
event_id = incoming_event.event_id().as_str(),
|
||||
event_type = ?incoming_event.event_type().to_string()
|
||||
)
|
||||
)]
|
||||
#[allow(clippy::suspicious_operation_groupings)]
|
||||
pub async fn auth_check<FE, FS>(
|
||||
room_version: &RoomVersion,
|
||||
incoming_event: &Pdu,
|
||||
fetch_event: &FE,
|
||||
fetch_state: &FS,
|
||||
create_event: Option<&Pdu>,
|
||||
) -> Result<bool, Error>
|
||||
where
|
||||
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
debug!("auth_check beginning");
|
||||
let sender = incoming_event.sender();
|
||||
|
||||
// Since v1, If type is m.room.create:
|
||||
if *incoming_event.event_type() == TimelineEventType::RoomCreate {
|
||||
debug!("start m.room.create check");
|
||||
if let Err(e) = check_room_create(incoming_event) {
|
||||
warn!("m.room.create event has been rejected: {}", e);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
debug!("m.room.create event was allowed");
|
||||
return Ok(true);
|
||||
}
|
||||
let Some(create_event) = create_event else {
|
||||
error!("no create event provided for auth check");
|
||||
return Err(Error::InvalidPdu("missing create event".to_owned()));
|
||||
};
|
||||
|
||||
// TODO: we need to know if events have previously been rejected or soft failed
|
||||
// For now, we'll just assume the create_event is valid.
|
||||
let create_content = from_json_str::<RoomCreateEventContent>(create_event.content().get())
|
||||
.expect("provided create event must be valid");
|
||||
|
||||
// Since v12, If the event’s room_id is not an event ID for an accepted (not
|
||||
// rejected) m.room.create event, with the sigil ! instead of $, reject.
|
||||
if room_version.room_ids_as_hashes {
|
||||
let calculated_room_id = create_event.event_id().as_str().replace('$', "!");
|
||||
if let Some(claimed_room_id) = create_event.room_id() {
|
||||
if claimed_room_id.as_str() != calculated_room_id {
|
||||
warn!(
|
||||
expected = %calculated_room_id,
|
||||
received = %claimed_room_id,
|
||||
"event's room ID does not match the hash of the m.room.create event ID"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
} else {
|
||||
warn!("event is missing a room ID");
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
let room_id = incoming_event.room_id().expect("event must have a room ID");
|
||||
|
||||
let auth_map =
|
||||
match check_auth_events(incoming_event, room_id, &room_version, fetch_event).await {
|
||||
| Ok(map) => map,
|
||||
| Err(e) => {
|
||||
warn!("event's auth events are invalid: {}", e);
|
||||
return Ok(false);
|
||||
},
|
||||
};
|
||||
|
||||
// Considering the event's auth_events
|
||||
|
||||
// Since v1, If the content of the m.room.create event in the room state has the
|
||||
// property m.federate set to false, and the sender domain of the event does
|
||||
// not match the sender domain of the create event, reject.
|
||||
if !create_content.federate {
|
||||
if create_event.sender().server_name() != incoming_event.sender().server_name() {
|
||||
warn!(
|
||||
sender = %incoming_event.sender(),
|
||||
create_sender = %create_event.sender(),
|
||||
"room is not federated and event's sender domain does not match create event's sender domain"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// From v1 to v5, If type is m.room.aliases
|
||||
if room_version.special_case_aliases_auth
|
||||
&& *incoming_event.event_type() == TimelineEventType::RoomAliases
|
||||
{
|
||||
if let Some(state_key) = incoming_event.state_key() {
|
||||
// If sender's domain doesn't matches state_key, reject
|
||||
if state_key != sender.server_name().as_str() {
|
||||
warn!("state_key does not match sender");
|
||||
return Ok(false);
|
||||
}
|
||||
// Otherwise, allow
|
||||
return Ok(true);
|
||||
}
|
||||
// If event has no state_key, reject.
|
||||
warn!("m.room.alias event has no state key");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// From v1, If type is m.room.member
|
||||
if *incoming_event.event_type() == TimelineEventType::RoomMember {
|
||||
if let Err(e) =
|
||||
check_member_event(&room_version, incoming_event, fetch_event, fetch_state).await
|
||||
{
|
||||
warn!("m.room.member event has been rejected: {}", e);
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// From v1, If the sender's current membership state is not join, reject
|
||||
let sender_member_event =
|
||||
match auth_map.get(&StateEventType::RoomMember.with_state_key(sender.as_str())) {
|
||||
| Some(ev) => ev,
|
||||
| None => {
|
||||
warn!(
|
||||
%sender,
|
||||
"sender is not joined - no membership event found for sender in auth events"
|
||||
);
|
||||
return Ok(false);
|
||||
},
|
||||
};
|
||||
|
||||
let sender_membership_event_content: RoomMemberContentFields =
|
||||
from_json_str(sender_member_event.content().get())?;
|
||||
let Some(membership_state) = sender_membership_event_content.membership else {
|
||||
warn!(
|
||||
?sender_membership_event_content,
|
||||
"Sender membership event content missing membership field"
|
||||
);
|
||||
return Err(Error::InvalidPdu("Missing membership field".to_owned()));
|
||||
};
|
||||
let membership_state = membership_state.deserialize()?;
|
||||
|
||||
if membership_state != MembershipState::Join {
|
||||
warn!(
|
||||
%sender,
|
||||
?membership_state,
|
||||
"sender cannot send events without being joined to the room"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// From v1, If type is m.room.third_party_invite
|
||||
let (rank, sender_pl, pl_evt) = get_rank(&room_version, fetch_state, sender).await?;
|
||||
|
||||
// Allow if and only if sender's current power level is greater than
|
||||
// or equal to the invite level
|
||||
if *incoming_event.event_type() == TimelineEventType::RoomThirdPartyInvite {
|
||||
if rank == UserPower::Creator {
|
||||
trace!("sender is room creator, allowing m.room.third_party_invite");
|
||||
return Ok(true);
|
||||
}
|
||||
let invite_level = match &pl_evt {
|
||||
| Some(power_levels) => power_levels.invite,
|
||||
| None => int!(0),
|
||||
};
|
||||
|
||||
if sender_pl < invite_level {
|
||||
warn!(
|
||||
%sender,
|
||||
has=%sender_pl,
|
||||
required=%invite_level,
|
||||
"sender cannot send invites in this room"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
debug!("m.room.third_party_invite event was allowed");
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Since v1, if the event type’s required power level is greater than the
|
||||
// sender’s power level, reject.
|
||||
let required_level = match &pl_evt {
|
||||
| Some(power_levels) => power_levels
|
||||
.events
|
||||
.get(incoming_event.kind())
|
||||
.unwrap_or_else(|| {
|
||||
if incoming_event.state_key.is_some() {
|
||||
&power_levels.state_default
|
||||
} else {
|
||||
&power_levels.events_default
|
||||
}
|
||||
}),
|
||||
| None => &int!(0),
|
||||
};
|
||||
if rank != UserPower::Creator && sender_pl < *required_level {
|
||||
warn!(
|
||||
%sender,
|
||||
has=%sender_pl,
|
||||
required=%required_level,
|
||||
"sender does not have enough power level to send this event"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Since v1, If the event has a state_key that starts with an @ and does not
|
||||
// match the sender, reject.
|
||||
if let Some(state_key) = incoming_event.state_key() {
|
||||
if state_key.starts_with('@') && state_key != sender.as_str() {
|
||||
warn!(
|
||||
%sender,
|
||||
%state_key,
|
||||
"event's state key starts with @ and does not match sender"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Since v1, If type is m.room.power_levels
|
||||
if *incoming_event.event_type() == TimelineEventType::RoomPowerLevels {
|
||||
let creators = calculate_creators(&room_version, fetch_state).await?;
|
||||
if let Err(e) =
|
||||
check_power_levels(&room_version, incoming_event, pl_evt.as_ref(), creators).await
|
||||
{
|
||||
warn!(
|
||||
%sender,
|
||||
"m.room.power_levels event has been rejected: {}", e
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// From v1 to v2: If type is m.room.redaction:
|
||||
// If the sender’s power level is greater than or equal to the redact level,
|
||||
// allow.
|
||||
// If the domain of the event_id of the event being redacted is the same as the
|
||||
// domain of the event_id of the m.room.redaction, allow.
|
||||
// Otherwise, reject.
|
||||
if room_version.extra_redaction_checks {
|
||||
// We'll panic here, since while we don't theoretically support the room
|
||||
// versions that require this, we don't want to incorrectly permit an event
|
||||
// that should be rejected in this theoretically impossible scenario.
|
||||
unreachable!(
|
||||
"continuwuity does not support room versions that require extra redaction checks"
|
||||
);
|
||||
}
|
||||
|
||||
debug!("allowing event passed all checks");
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruma::events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::{
|
||||
join_rules::{
|
||||
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
|
||||
},
|
||||
member::{MembershipState, RoomMemberEventContent},
|
||||
},
|
||||
};
|
||||
use serde_json::value::to_raw_value as to_raw_json_value;
|
||||
|
||||
use crate::{
|
||||
matrix::{Event, EventTypeExt, Pdu as PduEvent},
|
||||
state_res::{
|
||||
RoomVersion, StateMap,
|
||||
event_auth::{
|
||||
iterative_auth_checks::valid_membership_change, valid_membership_change,
|
||||
},
|
||||
test_utils::{
|
||||
INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, alice, charlie, ella, event_id,
|
||||
member_content_ban, member_content_join, room_id, to_pdu_event,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_ban_pass() {
|
||||
let _ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt().with_test_writer().finish(),
|
||||
);
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(charlie().as_str()),
|
||||
member_content_ban(),
|
||||
&[],
|
||||
&["IMC"],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = charlie();
|
||||
let sender = alice();
|
||||
|
||||
assert!(
|
||||
valid_membership_change(
|
||||
&RoomVersion::V6,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
None,
|
||||
&MembershipState::Leave,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_join_non_creator() {
|
||||
let _ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt().with_test_writer().finish(),
|
||||
);
|
||||
let events = INITIAL_EVENTS_CREATE_ROOM();
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
charlie(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(charlie().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE"],
|
||||
&["CREATE"],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = charlie();
|
||||
let sender = charlie();
|
||||
|
||||
assert!(
|
||||
!valid_membership_change(
|
||||
&RoomVersion::V6,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
None,
|
||||
&MembershipState::Leave,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_join_creator() {
|
||||
let _ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt().with_test_writer().finish(),
|
||||
);
|
||||
let events = INITIAL_EVENTS_CREATE_ROOM();
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(alice().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE"],
|
||||
&["CREATE"],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = alice();
|
||||
let sender = alice();
|
||||
|
||||
assert!(
|
||||
valid_membership_change(
|
||||
&RoomVersion::V6,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
None,
|
||||
&MembershipState::Leave,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ban_fail() {
|
||||
let _ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt().with_test_writer().finish(),
|
||||
);
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
charlie(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(alice().as_str()),
|
||||
member_content_ban(),
|
||||
&[],
|
||||
&["IMC"],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = alice();
|
||||
let sender = charlie();
|
||||
|
||||
assert!(
|
||||
!valid_membership_change(
|
||||
&RoomVersion::V6,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
None,
|
||||
&MembershipState::Leave,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_restricted_join_rule() {
|
||||
let _ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt().with_test_writer().finish(),
|
||||
);
|
||||
let mut events = INITIAL_EVENTS();
|
||||
*events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
|
||||
"IJR",
|
||||
alice(),
|
||||
TimelineEventType::RoomJoinRules,
|
||||
Some(""),
|
||||
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Restricted(
|
||||
Restricted::new(vec![AllowRule::RoomMembership(RoomMembership::new(
|
||||
room_id().to_owned(),
|
||||
))]),
|
||||
)))
|
||||
.unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"],
|
||||
&["IPOWER"],
|
||||
);
|
||||
|
||||
let mut member = RoomMemberEventContent::new(MembershipState::Join);
|
||||
member.join_authorized_via_users_server = Some(alice().to_owned());
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
ella(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().as_str()),
|
||||
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap(),
|
||||
&["CREATE", "IJR", "IPOWER", "new"],
|
||||
&["new"],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = ella();
|
||||
let sender = ella();
|
||||
|
||||
assert!(
|
||||
valid_membership_change(
|
||||
&RoomVersion::V9,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
Some(alice()),
|
||||
&MembershipState::Join,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
assert!(
|
||||
!valid_membership_change(
|
||||
&RoomVersion::V9,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
Some(ella()),
|
||||
&MembershipState::Leave,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_knock() {
|
||||
let _ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt().with_test_writer().finish(),
|
||||
);
|
||||
let mut events = INITIAL_EVENTS();
|
||||
*events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event(
|
||||
"IJR",
|
||||
alice(),
|
||||
TimelineEventType::RoomJoinRules,
|
||||
Some(""),
|
||||
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Knock)).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"],
|
||||
&["IPOWER"],
|
||||
);
|
||||
|
||||
let auth_events = events
|
||||
.values()
|
||||
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let requester = to_pdu_event(
|
||||
"HELLO",
|
||||
ella(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().as_str()),
|
||||
to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Knock)).unwrap(),
|
||||
&[],
|
||||
&["IMC"],
|
||||
);
|
||||
|
||||
let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned();
|
||||
let target_user = ella();
|
||||
let sender = ella();
|
||||
|
||||
assert!(
|
||||
valid_membership_change(
|
||||
&RoomVersion::V7,
|
||||
target_user,
|
||||
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||
sender,
|
||||
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||
&requester,
|
||||
None::<&PduEvent>,
|
||||
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||
None,
|
||||
&MembershipState::Leave,
|
||||
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
//! Auth checks relevant to the `m.room.member` event specifically.
|
||||
//!
|
||||
//! See: https://spec.matrix.org/v1.16/rooms/v12/#authorization-rules
|
||||
|
||||
use ruma::{
|
||||
EventId, OwnedUserId, UserId,
|
||||
events::{
|
||||
StateEventType,
|
||||
room::{
|
||||
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
third_party_invite::{PublicKey, RoomThirdPartyInviteEventContent},
|
||||
},
|
||||
},
|
||||
serde::Base64,
|
||||
signatures::{PublicKeyMap, PublicKeySet, verify_json},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Event, EventTypeExt, Pdu, RoomVersion,
|
||||
matrix::StateKey,
|
||||
state_res::{
|
||||
Error,
|
||||
event_auth::context::{UserPower, get_rank},
|
||||
},
|
||||
utils::to_canonical_object,
|
||||
};
|
||||
|
||||
#[derive(serde::Deserialize, Default)]
|
||||
struct PartialMembershipObject {
|
||||
membership: Option<String>,
|
||||
join_authorized_via_users_server: Option<OwnedUserId>,
|
||||
third_party_invite: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Fetches the membership *content* of the target.
|
||||
/// If there is not one, an empty leave membership is returned.
|
||||
async fn fetch_membership<FS>(
|
||||
fetch_state: &FS,
|
||||
target: &UserId,
|
||||
) -> Result<PartialMembershipObject, Error>
|
||||
where
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
fetch_state(StateEventType::RoomMember.with_state_key(target.as_str()))
|
||||
.await
|
||||
.map(|pdu| {
|
||||
if let Some(ev) = pdu {
|
||||
ev.get_content::<PartialMembershipObject>().map_err(|e| {
|
||||
Error::InvalidPdu(format!("m.room.member event has invalid content: {}", e))
|
||||
})
|
||||
} else {
|
||||
Ok(PartialMembershipObject {
|
||||
membership: Some("leave".to_owned()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
})?
|
||||
}
|
||||
|
||||
async fn check_join_event<FE, FS>(
|
||||
room_version: &RoomVersion,
|
||||
event: &Pdu,
|
||||
membership: &PartialMembershipObject,
|
||||
target: &UserId,
|
||||
fetch_event: &FE,
|
||||
fetch_state: &FS,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
// 3.1: If the only previous event is an m.room.create and the state_key is the
|
||||
// sender of the m.room.create, allow.
|
||||
if event.prev_events.len() == 1 {
|
||||
let only_prev = fetch_event(&event.prev_events[0]).await?;
|
||||
if let Some(prev_event) = only_prev {
|
||||
let k = prev_event.event_type().with_state_key("");
|
||||
if k.0 == StateEventType::RoomCreate && k.1.as_str() == event.sender().as_str() {
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
return Err(Error::DependencyFailed(
|
||||
event.prev_events[0].to_owned(),
|
||||
"Previous event not found when checking join event".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 3.2: If the sender does not match state_key, reject.
|
||||
if event.sender() != target {
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"m.room.member join event sender does not match state_key".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
let prev_membership = if let Some(ev) =
|
||||
fetch_state(StateEventType::RoomMember.with_state_key(target.as_str())).await?
|
||||
{
|
||||
Some(ev.get_content::<PartialMembershipObject>().map_err(|e| {
|
||||
Error::InvalidPdu(format!("Previous m.room.member event has invalid content: {}", e))
|
||||
})?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let join_rule_content =
|
||||
if let Some(jr) = fetch_state(StateEventType::RoomJoinRules.with_state_key("")).await? {
|
||||
jr.get_content::<RoomJoinRulesEventContent>().map_err(|e| {
|
||||
Error::InvalidPdu(format!("m.room.join_rules event has invalid content: {}", e))
|
||||
})?
|
||||
} else {
|
||||
// Default to invite if no join rules event is present.
|
||||
RoomJoinRulesEventContent { join_rule: JoinRule::Private }
|
||||
};
|
||||
|
||||
// 3.3: If the sender is banned, reject.
|
||||
let prev_member = if let Some(prev_content) = &prev_membership {
|
||||
if let Some(membership) = &prev_content.membership {
|
||||
if membership == "ban" {
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"m.room.member join event sender is banned".to_owned(),
|
||||
));
|
||||
}
|
||||
membership
|
||||
} else {
|
||||
"leave"
|
||||
}
|
||||
} else {
|
||||
"leave"
|
||||
};
|
||||
|
||||
// 3.4: If the join_rule is invite or knock then allow if membership
|
||||
// state is invite or join.
|
||||
// 3.5: If the join_rule is restricted or knock_restricted:
|
||||
// 3.5.1: If membership state is join or invite, allow.
|
||||
match join_rule_content.join_rule {
|
||||
| JoinRule::Invite | JoinRule::Knock => {
|
||||
if prev_member == "invite" || prev_member == "join" {
|
||||
return Ok(());
|
||||
}
|
||||
Err(Error::AuthConditionFailed(
|
||||
"m.room.member join event not invited under invite/knock join rule".to_owned(),
|
||||
))
|
||||
},
|
||||
| JoinRule::Restricted(_) | JoinRule::KnockRestricted(_) => {
|
||||
// 3.5.2: If the join_authorised_via_users_server key in content is not a user
|
||||
// with sufficient permission to invite other users or is not a joined
|
||||
// member of the room, reject.
|
||||
if prev_member == "invite" || prev_member == "join" {
|
||||
return Ok(());
|
||||
}
|
||||
let join_authed_by = membership.join_authorized_via_users_server.as_ref();
|
||||
if let Some(user_id) = join_authed_by {
|
||||
let rank = get_rank(&room_version, fetch_state, user_id).await?;
|
||||
if rank.0 == UserPower::Standard {
|
||||
// This user is not a creator, check that they have
|
||||
// sufficient power level
|
||||
if rank.1 < rank.2.unwrap().invite {
|
||||
return Err(Error::InvalidPdu(
|
||||
"m.room.member join event join_authorised_via_users_server does not \
|
||||
have sufficient power level to invite"
|
||||
.to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
// Check that the user is a joined member of the room
|
||||
if let Some(state_event) =
|
||||
fetch_state(StateEventType::RoomMember.with_state_key(user_id.as_str()))
|
||||
.await?
|
||||
{
|
||||
let state_content = state_event
|
||||
.get_content::<PartialMembershipObject>()
|
||||
.map_err(|e| {
|
||||
Error::InvalidPdu(format!(
|
||||
"m.room.member event has invalid content: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
if let Some(state_membership) = &state_content.membership {
|
||||
if state_membership == "join" {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"m.room.member join event missing join_authorised_via_users_server"
|
||||
.to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
// 3.5.3: Otherwise, allow
|
||||
return Ok(());
|
||||
},
|
||||
| JoinRule::Public => return Ok(()),
|
||||
| _ => Err(Error::AuthConditionFailed(format!(
|
||||
"unknown join rule: {:?}",
|
||||
join_rule_content.join_rule
|
||||
)))?,
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks a third-party invite is valid.
|
||||
async fn check_third_party_invite(
|
||||
target_current_membership: PartialMembershipObject,
|
||||
raw_third_party_invite: &serde_json::Value,
|
||||
target: &UserId,
|
||||
event: &Pdu,
|
||||
fetch_state: impl AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
) -> Result<(), Error> {
|
||||
// 4.1.1: If target user is banned, reject.
|
||||
if target_current_membership
|
||||
.membership
|
||||
.is_some_and(|m| m == "ban")
|
||||
{
|
||||
return Err(Error::AuthConditionFailed("invite target is banned".to_owned()));
|
||||
}
|
||||
// 4.1.2: If content.third_party_invite does not have a signed property, reject.
|
||||
let signed = raw_third_party_invite.get("signed").ok_or_else(|| {
|
||||
Error::AuthConditionFailed(
|
||||
"invite event third_party_invite missing signed property".to_owned(),
|
||||
)
|
||||
})?;
|
||||
// 4.2.3: If signed does not have mxid and token properties, reject.
|
||||
let mxid = signed.get("mxid").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
Error::AuthConditionFailed(
|
||||
"invite event third_party_invite signed missing/invalid mxid property".to_owned(),
|
||||
)
|
||||
})?;
|
||||
let token = signed
|
||||
.get("token")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
Error::AuthConditionFailed(
|
||||
"invite event third_party_invite signed missing token property".to_owned(),
|
||||
)
|
||||
})?;
|
||||
// 4.2.4: If mxid does not match state_key, reject.
|
||||
if mxid != target.as_str() {
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"invite event third_party_invite signed mxid does not match state_key".to_owned(),
|
||||
));
|
||||
}
|
||||
// 4.2.5: If there is no m.room.third_party_invite event in the room
|
||||
// state matching the token, reject.
|
||||
let Some(third_party_invite_event) =
|
||||
fetch_state(StateEventType::RoomThirdPartyInvite.with_state_key(token)).await?
|
||||
else {
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"invite event third_party_invite token has no matching m.room.third_party_invite"
|
||||
.to_owned(),
|
||||
));
|
||||
};
|
||||
// 4.2.6: If sender does not match sender of the m.room.third_party_invite,
|
||||
// reject.
|
||||
if third_party_invite_event.sender() != event.sender() {
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"invite event sender does not match m.room.third_party_invite sender".to_owned(),
|
||||
));
|
||||
}
|
||||
// 4.2.7: If any signature in signed matches any public key in the
|
||||
// m.room.third_party_invite event, allow. The public keys are in
|
||||
// content of m.room.third_party_invite as:
|
||||
// 1. A single public key in the public_key property.
|
||||
// 2. A list of public keys in the public_keys property.
|
||||
let tpi_content = third_party_invite_event
|
||||
.get_content::<RoomThirdPartyInviteEventContent>()
|
||||
.or_else(|_| {
|
||||
Err(Error::InvalidPdu(
|
||||
"m.room.third_party_invite event has invalid content".to_owned(),
|
||||
))
|
||||
})?;
|
||||
let mut public_keys = tpi_content.public_keys.unwrap_or_default();
|
||||
public_keys.push(PublicKey {
|
||||
public_key: tpi_content.public_key,
|
||||
key_validity_url: None,
|
||||
});
|
||||
|
||||
let signatures = signed
|
||||
.get("signatures")
|
||||
.and_then(|v| v.as_object())
|
||||
.ok_or_else(|| {
|
||||
Error::InvalidPdu(
|
||||
"invite event third_party_invite signed missing/invalid signatures".to_owned(),
|
||||
)
|
||||
})?;
|
||||
let mut public_key_map = PublicKeyMap::new();
|
||||
for (server_name, sig_map) in signatures {
|
||||
let mut pk_set = PublicKeySet::new();
|
||||
if let Some(sig_map) = sig_map.as_object() {
|
||||
for (key_id, sig) in sig_map {
|
||||
let sig_b64 = Base64::parse(sig.as_str().ok_or(Error::InvalidPdu(
|
||||
"invite event third_party_invite signature is not a string".to_owned(),
|
||||
))?)
|
||||
.map_err(|_| {
|
||||
Error::InvalidPdu(
|
||||
"invite event third_party_invite signature is not valid Base64"
|
||||
.to_owned(),
|
||||
)
|
||||
})?;
|
||||
pk_set.insert(key_id.clone(), sig_b64);
|
||||
}
|
||||
}
|
||||
public_key_map.insert(server_name.clone(), pk_set);
|
||||
}
|
||||
verify_json(
|
||||
&public_key_map,
|
||||
to_canonical_object(signed).expect("signed was already validated"),
|
||||
)
|
||||
.map_err(|e| {
|
||||
Error::AuthConditionFailed(format!(
|
||||
"invite event third_party_invite signature verification failed: {e}"
|
||||
))
|
||||
})?;
|
||||
// If there was no error, there was a valid signature, so allow.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn check_invite_event<FS>(
|
||||
room_version: &RoomVersion,
|
||||
event: &Pdu,
|
||||
membership: &PartialMembershipObject,
|
||||
target: &UserId,
|
||||
fetch_state: &FS,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
let target_current_membership = fetch_membership(fetch_state, target).await?;
|
||||
|
||||
// 4.1: If content has a third_party_invite property:
|
||||
if let Some(raw_third_party_invite) = &membership.third_party_invite {
|
||||
return check_third_party_invite(
|
||||
target_current_membership,
|
||||
raw_third_party_invite,
|
||||
target,
|
||||
event,
|
||||
fetch_state,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// 4.2: If the sender’s current membership state is not join, reject.
|
||||
let sender_membership = fetch_membership(fetch_state, event.sender()).await?;
|
||||
if sender_membership.membership.is_none_or(|m| m != "join") {
|
||||
return Err(Error::AuthConditionFailed("invite sender is not joined".to_owned()));
|
||||
}
|
||||
|
||||
// 4.3: If target user’s current membership state is join or ban, reject.
|
||||
if target_current_membership
|
||||
.membership
|
||||
.is_some_and(|m| m == "join" || m == "ban")
|
||||
{
|
||||
return Err(Error::AuthConditionFailed(
|
||||
"invite target is already joined or banned".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
// 4.4: If the sender’s power level is greater than or equal to the invite
|
||||
// level, allow.
|
||||
let (rank, pl, pl_evt) = get_rank(&room_version, fetch_state, event.sender()).await?;
|
||||
if rank == UserPower::Creator || pl >= pl_evt.unwrap_or_default().invite {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 4.5: Otherwise, reject.
|
||||
Err(Error::AuthConditionFailed(
|
||||
"invite sender does not have sufficient power level to invite".to_owned(),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn check_member_event<FE, FS>(
|
||||
room_version: &RoomVersion,
|
||||
event: &Pdu,
|
||||
fetch_event: FE,
|
||||
fetch_state: FS,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
FE: AsyncFn(&EventId) -> Result<Option<Pdu>, Error>,
|
||||
FS: AsyncFn((StateEventType, StateKey)) -> Result<Option<Pdu>, Error>,
|
||||
{
|
||||
// 1. If there is no state_key property, or no membership property in content,
|
||||
// reject.
|
||||
if event.state_key.is_none() {
|
||||
return Err(Error::InvalidPdu("m.room.member event missing state_key".to_owned()));
|
||||
}
|
||||
|
||||
let target = UserId::parse(event.state_key().unwrap())
|
||||
.map_err(|_| Error::InvalidPdu("m.room.member event has invalid state_key".to_owned()))?
|
||||
.to_owned();
|
||||
let content = event
|
||||
.get_content::<PartialMembershipObject>()
|
||||
.map_err(|e| {
|
||||
Error::InvalidPdu(format!("m.room.member event has invalid content: {}", e))
|
||||
})?;
|
||||
|
||||
if content.membership.is_none() {
|
||||
return Err(Error::InvalidPdu(
|
||||
"m.room.member event missing membership in content".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
// 2: If content has a join_authorised_via_users_server key
|
||||
//
|
||||
// 2.1: If the event is not validly signed by the homeserver of the user ID
|
||||
// denoted by the key, reject.
|
||||
if let Some(_join_auth) = &content.join_authorized_via_users_server {
|
||||
// We need to check the signature here, but don't have the means to do so yet.
|
||||
todo!("Implement join_authorised_via_users_server check");
|
||||
}
|
||||
|
||||
match content.membership.as_deref().unwrap() {
|
||||
| "join" =>
|
||||
check_join_event(room_version, event, &content, &target, &fetch_event, &fetch_state)
|
||||
.await?,
|
||||
| "invite" =>
|
||||
check_invite_event(room_version, event, &content, &target, &fetch_state).await?,
|
||||
| _ => {
|
||||
todo!()
|
||||
},
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
pub mod auth_events;
|
||||
mod context;
|
||||
pub mod create_event;
|
||||
pub mod iterative_auth_checks;
|
||||
pub mod member_event;
|
||||
mod power_levels;
|
||||
@@ -0,0 +1,157 @@
|
||||
use ruma::{OwnedUserId, events::room::power_levels::RoomPowerLevelsEventContent};
|
||||
|
||||
use crate::{
|
||||
Event, Pdu, RoomVersion,
|
||||
state_res::{Error, event_auth::context::UserPower},
|
||||
};
|
||||
|
||||
/// Verifies that a m.room.power_levels event is well-formed according to the
|
||||
/// Matrix specification.
|
||||
///
|
||||
/// Creators must contain the m.room.create sender and any additional creators.
|
||||
pub async fn check_power_levels(
|
||||
room_version: &RoomVersion,
|
||||
event: &Pdu,
|
||||
current_power_levels: Option<&RoomPowerLevelsEventContent>,
|
||||
creators: Vec<OwnedUserId>,
|
||||
) -> Result<(), Error> {
|
||||
let content = event
|
||||
.get_content::<RoomPowerLevelsEventContent>()
|
||||
.map_err(|e| {
|
||||
Error::InvalidPdu(format!("m.room.power_levels event has invalid content: {}", e))
|
||||
})?;
|
||||
|
||||
// If any of the properties users_default, events_default, state_default, ban,
|
||||
// redact, kick, or invite in content are present and not an integer, reject.
|
||||
//
|
||||
// If either of the properties events or notifications in content are present
|
||||
// and not an object with values that are integers, reject.
|
||||
//
|
||||
// NOTE: Deserialisation fails if this is not the case, so we don't need to
|
||||
// check these here.
|
||||
|
||||
// If the users property in content is not an object with keys that are valid
|
||||
// user IDs with values that are integers (or a string that is an integer),
|
||||
// reject.
|
||||
while let Some(user_id) = content.users.keys().next() {
|
||||
// NOTE: Deserialisation fails if the power level is not an integer, so we don't
|
||||
// need to check that here.
|
||||
|
||||
if let Err(e) = user_id.validate_historical() {
|
||||
return Err(Error::InvalidPdu(format!(
|
||||
"m.room.power_levels event has invalid user ID in users map: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
// Since v12, If the users property in content contains the sender of the
|
||||
// m.room.create event or any of the additional_creators array (if present)
|
||||
// from the content of the m.room.create event, reject.
|
||||
if room_version.explicitly_privilege_room_creators && creators.contains(user_id) {
|
||||
return Err(Error::InvalidPdu(
|
||||
"m.room.power_levels event users map contains a room creator".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// If there is no previous m.room.power_levels event in the room, allow.
|
||||
if current_power_levels.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
let current_power_levels = current_power_levels.unwrap();
|
||||
|
||||
// For the properties users_default, events_default, state_default, ban, redact,
|
||||
// kick, invite check if they were added, changed or removed. For each found
|
||||
// alteration:
|
||||
// If the current value is higher than the sender’s current power level, reject.
|
||||
// If the new value is higher than the sender’s current power level, reject.
|
||||
let sender = event.sender();
|
||||
let rank = if room_version.explicitly_privilege_room_creators {
|
||||
if creators.contains(&sender.to_owned()) {
|
||||
UserPower::Creator
|
||||
} else {
|
||||
UserPower::Standard
|
||||
}
|
||||
} else {
|
||||
UserPower::Standard
|
||||
};
|
||||
let sender_pl = current_power_levels
|
||||
.users
|
||||
.get(sender)
|
||||
.unwrap_or(¤t_power_levels.users_default);
|
||||
|
||||
if rank != UserPower::Creator {
|
||||
let checks = [
|
||||
("users_default", current_power_levels.users_default, content.users_default),
|
||||
("events_default", current_power_levels.events_default, content.events_default),
|
||||
("state_default", current_power_levels.state_default, content.state_default),
|
||||
("ban", current_power_levels.ban, content.ban),
|
||||
("redact", current_power_levels.redact, content.redact),
|
||||
("kick", current_power_levels.kick, content.kick),
|
||||
("invite", current_power_levels.invite, content.invite),
|
||||
];
|
||||
|
||||
for (name, old_value, new_value) in checks.iter() {
|
||||
if old_value != new_value {
|
||||
if *old_value > *sender_pl {
|
||||
return Err(Error::AuthConditionFailed(format!(
|
||||
"sender cannot change level for {}",
|
||||
name
|
||||
)));
|
||||
}
|
||||
if *new_value > *sender_pl {
|
||||
return Err(Error::AuthConditionFailed(format!(
|
||||
"sender cannot raise level for {} to {}",
|
||||
name, new_value
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each entry being changed in, or removed from, the events
|
||||
// property:
|
||||
// If the current value is greater than the sender’s current power level,
|
||||
// reject.
|
||||
for (event_type, new_value) in content.events.iter() {
|
||||
let old_value = current_power_levels.events.get(event_type);
|
||||
if old_value != Some(new_value) {
|
||||
let old_pl = old_value.unwrap_or(¤t_power_levels.events_default);
|
||||
if *old_pl > *sender_pl {
|
||||
return Err(Error::AuthConditionFailed(format!(
|
||||
"sender cannot change event level for {}",
|
||||
event_type
|
||||
)));
|
||||
}
|
||||
if *new_value > *sender_pl {
|
||||
return Err(Error::AuthConditionFailed(format!(
|
||||
"sender cannot raise event level for {} to {}",
|
||||
event_type, new_value
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each entry being changed in, or removed from, the events or
|
||||
// notifications properties:
|
||||
// If the current value is greater than the sender’s current power
|
||||
// level, reject.
|
||||
// If the new value is greater than the sender’s current power level,
|
||||
// reject.
|
||||
// TODO after making ruwuma's notifications value a BTreeMap
|
||||
|
||||
// For each entry being added to, or changed in, the users property:
|
||||
// If the new value is greater than the sender’s current power level, reject.
|
||||
for (user_id, new_value) in content.users.iter() {
|
||||
let old_value = current_power_levels.users.get(user_id);
|
||||
if old_value != Some(new_value) {
|
||||
if *new_value > *sender_pl {
|
||||
return Err(Error::AuthConditionFailed(format!(
|
||||
"sender cannot raise user level for {} to {}",
|
||||
user_id, new_value
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
+149
-175
@@ -8,9 +8,6 @@ mod room_version;
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
||||
|
||||
#[cfg(test)]
|
||||
mod benches;
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
cmp::{Ordering, Reverse},
|
||||
@@ -18,30 +15,31 @@ use std::{
|
||||
hash::{BuildHasher, Hash},
|
||||
};
|
||||
|
||||
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future};
|
||||
use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use itertools::Itertools;
|
||||
use ruma::{
|
||||
EventId, Int, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
|
||||
events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::member::{MembershipState, RoomMemberEventContent},
|
||||
},
|
||||
int,
|
||||
room::member::{MembershipState, RoomMemberEventContent}, StateEventType,
|
||||
TimelineEventType,
|
||||
}, int, EventId, Int, MilliSecondsSinceUnixEpoch,
|
||||
OwnedEventId,
|
||||
RoomVersionId,
|
||||
};
|
||||
use serde_json::from_str as from_json_str;
|
||||
|
||||
pub(crate) use self::error::{Error, InvalidPduSnafu, NotFoundSnafu};
|
||||
pub(crate) use self::error::Error;
|
||||
use self::power_levels::PowerLevelsContentFields;
|
||||
pub use self::{
|
||||
event_auth::{auth_check, auth_types_for_event},
|
||||
room_version::RoomVersion,
|
||||
};
|
||||
use super::{Event, StateKey};
|
||||
pub use self::{event_auth::iterative_auth_checks::auth_check, room_version::RoomVersion};
|
||||
use crate::utils::TryFutureExtExt;
|
||||
use crate::{
|
||||
debug, debug_error,
|
||||
state_res::room_version::StateResolutionVersion,
|
||||
debug, err, error as log_error, matrix::{Event, StateKey},
|
||||
state_res::{
|
||||
event_auth::auth_events::auth_types_for_event, room_version::StateResolutionVersion,
|
||||
},
|
||||
trace,
|
||||
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt},
|
||||
utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, WidebandExt},
|
||||
warn,
|
||||
Pdu,
|
||||
};
|
||||
|
||||
/// A mapping of event type and state_key to some value `T`, usually an
|
||||
@@ -75,23 +73,20 @@ type Result<T, E = Error> = crate::Result<T, E>;
|
||||
/// event is part of the same room.
|
||||
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
|
||||
//#[tracing::instrument(level event_fetch))]
|
||||
pub async fn resolve<'a, Pdu, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
|
||||
pub async fn resolve<'a, Sets, SetIter, Hasher, FE, FR, Exists>(
|
||||
room_version: &RoomVersionId,
|
||||
state_sets: Sets,
|
||||
auth_chain_sets: &'a [HashSet<OwnedEventId, Hasher>],
|
||||
event_fetch: &Fetch,
|
||||
event_fetch: &FE,
|
||||
event_exists: &Exists,
|
||||
) -> Result<StateMap<OwnedEventId>>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> FetchFut + Sync,
|
||||
FetchFut: Future<Output = Option<Pdu>> + Send,
|
||||
Exists: Fn(OwnedEventId) -> ExistsFut + Sync,
|
||||
ExistsFut: Future<Output = bool> + Send,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
Exists: AsyncFn(OwnedEventId) -> bool + Sync,
|
||||
Sets: IntoIterator<IntoIter = SetIter> + Send,
|
||||
SetIter: Iterator<Item = &'a StateMap<OwnedEventId>> + Clone + Send,
|
||||
Hasher: BuildHasher + Send + Sync,
|
||||
Pdu: Event + Clone + Send + Sync,
|
||||
for<'b> &'b Pdu: Event + Send,
|
||||
{
|
||||
use RoomVersionId::*;
|
||||
let stateres_version = match room_version {
|
||||
@@ -118,10 +113,7 @@ where
|
||||
let csg = calculate_conflicted_subgraph(&conflicting, event_fetch)
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
InvalidPduSnafu {
|
||||
message: "Failed to calculate conflicted subgraph",
|
||||
}
|
||||
.build()
|
||||
Error::InvalidPdu("Failed to calculate conflicted subgraph".to_owned())
|
||||
})?;
|
||||
debug!(count = csg.len(), "conflicted subgraph");
|
||||
trace!(set = ?csg, "conflicted subgraph");
|
||||
@@ -152,11 +144,10 @@ where
|
||||
let control_events: Vec<_> = all_conflicted
|
||||
.iter()
|
||||
.stream()
|
||||
.broad_filter_map(async |id| {
|
||||
event_fetch(id.clone())
|
||||
.wide_filter_map(async |id| {
|
||||
is_power_event_id(id, &event_fetch)
|
||||
.await
|
||||
.filter(|event| is_power_event(&event))
|
||||
.map(|_| id.clone())
|
||||
.then_some(id.clone())
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
@@ -173,7 +164,7 @@ where
|
||||
// Sequentially auth check each control event.
|
||||
let resolved_control = iterative_auth_check(
|
||||
&room_version,
|
||||
sorted_control_levels.iter().stream().map(AsRef::as_ref),
|
||||
sorted_control_levels.iter().stream().map(ToOwned::to_owned),
|
||||
initial_state,
|
||||
&event_fetch,
|
||||
)
|
||||
@@ -213,7 +204,7 @@ where
|
||||
|
||||
let mut resolved_state = iterative_auth_check(
|
||||
&room_version,
|
||||
sorted_left_events.iter().stream().map(AsRef::as_ref),
|
||||
sorted_left_events.iter().stream(),
|
||||
resolved_control, // The control events are added to the final resolved state
|
||||
&event_fetch,
|
||||
)
|
||||
@@ -277,14 +268,12 @@ where
|
||||
}
|
||||
|
||||
/// Calculate the conflicted subgraph
|
||||
async fn calculate_conflicted_subgraph<F, Fut, E>(
|
||||
async fn calculate_conflicted_subgraph<FE>(
|
||||
conflicted: &StateMap<Vec<OwnedEventId>>,
|
||||
fetch_event: &F,
|
||||
fetch_event: &FE,
|
||||
) -> Option<HashSet<OwnedEventId>>
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send + Sync,
|
||||
FE: AsyncFn(OwnedEventId) -> Result<Option<Pdu>> + Sync,
|
||||
{
|
||||
let conflicted_events: HashSet<_> = conflicted.values().flatten().cloned().collect();
|
||||
let mut subgraph: HashSet<OwnedEventId> = HashSet::new();
|
||||
@@ -316,12 +305,19 @@ where
|
||||
continue;
|
||||
}
|
||||
trace!(event_id = event_id.as_str(), "fetching event for its auth events");
|
||||
let evt = fetch_event(event_id.clone()).await;
|
||||
let evt = fetch_event(event_id.clone())
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
log_error!(
|
||||
"error fetching event {} for conflicted state subgraph: {}",
|
||||
event_id,
|
||||
e
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
.flatten();
|
||||
if evt.is_none() {
|
||||
tracing::error!(
|
||||
"could not fetch event {} to calculate conflicted subgraph",
|
||||
event_id
|
||||
);
|
||||
err!("could not fetch event {} to calculate conflicted subgraph", event_id);
|
||||
path.pop();
|
||||
continue;
|
||||
}
|
||||
@@ -366,15 +362,14 @@ where
|
||||
/// The power level is negative because a higher power level is equated to an
|
||||
/// earlier (further back in time) origin server timestamp.
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
async fn reverse_topological_power_sort<E, F, Fut>(
|
||||
async fn reverse_topological_power_sort<FE, FR>(
|
||||
events_to_sort: Vec<OwnedEventId>,
|
||||
auth_diff: &HashSet<OwnedEventId>,
|
||||
fetch_event: &F,
|
||||
fetch_event: &FE,
|
||||
) -> Result<Vec<OwnedEventId>>
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send + Sync,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
{
|
||||
debug!("reverse topological sort of power events");
|
||||
|
||||
@@ -409,11 +404,11 @@ where
|
||||
let fetcher = async |event_id: OwnedEventId| {
|
||||
let pl = *event_to_pl
|
||||
.get(&event_id)
|
||||
.ok_or_else(|| NotFoundSnafu { message: "" }.build())?;
|
||||
.ok_or_else(|| Error::NotFound(String::new()))?;
|
||||
|
||||
let ev = fetch_event(event_id)
|
||||
.await
|
||||
.ok_or_else(|| NotFoundSnafu { message: "" }.build())?;
|
||||
let ev = fetch_event(&event_id)
|
||||
.await?
|
||||
.ok_or_else(|| Error::NotFound(String::new()))?;
|
||||
|
||||
Ok((pl, ev.origin_server_ts()))
|
||||
};
|
||||
@@ -551,18 +546,14 @@ where
|
||||
/// Do NOT use this any where but topological sort, we find the power level for
|
||||
/// the eventId at the eventId's generation (we walk backwards to `EventId`s
|
||||
/// most recent previous power level event).
|
||||
async fn get_power_level_for_sender<E, F, Fut>(
|
||||
event_id: &EventId,
|
||||
fetch_event: &F,
|
||||
) -> serde_json::Result<Int>
|
||||
async fn get_power_level_for_sender<FE, FR>(event_id: &EventId, fetch_event: &FE) -> Result<Int>
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
{
|
||||
debug!("fetch event ({event_id}) senders power level");
|
||||
|
||||
let event = fetch_event(event_id.to_owned()).await;
|
||||
let event = fetch_event(event_id).await?;
|
||||
|
||||
let auth_events = event.as_ref().map(Event::auth_events);
|
||||
|
||||
@@ -570,7 +561,7 @@ where
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.stream()
|
||||
.broadn_filter_map(5, |aid| fetch_event(aid.to_owned()))
|
||||
.broad_filter_map(|aid| fetch_event(aid).unwrap_or_default())
|
||||
.ready_find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, ""))
|
||||
.await;
|
||||
|
||||
@@ -601,30 +592,24 @@ where
|
||||
/// the the `fetch_event` closure and verify each event using the
|
||||
/// `event_auth::auth_check` function.
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
async fn iterative_auth_check<'a, E, F, Fut, S>(
|
||||
async fn iterative_auth_check<FE, FR, S>(
|
||||
room_version: &RoomVersion,
|
||||
events_to_check: S,
|
||||
unconflicted_state: StateMap<OwnedEventId>,
|
||||
fetch_event: &F,
|
||||
fetch_event: &FE,
|
||||
) -> Result<StateMap<OwnedEventId>>
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
S: Stream<Item = &'a EventId> + Send + 'a,
|
||||
E: Event + Clone + Send + Sync,
|
||||
for<'b> &'b E: Event + Send,
|
||||
FE: Fn(&EventId) -> FR,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send + Sync,
|
||||
S: Stream<Item = OwnedEventId> + Send,
|
||||
{
|
||||
debug!("starting iterative auth check");
|
||||
|
||||
let events_to_check: Vec<_> = events_to_check
|
||||
.map(Result::Ok)
|
||||
.broad_and_then(async |event_id| {
|
||||
fetch_event(event_id.to_owned()).await.ok_or_else(|| {
|
||||
NotFoundSnafu {
|
||||
message: format!("Failed to find {event_id}"),
|
||||
}
|
||||
.build()
|
||||
})
|
||||
.map(Ok::<OwnedEventId, Error>)
|
||||
.broad_and_then(async |event_id| match fetch_event(&event_id).await {
|
||||
| Ok(Some(e)) => Ok(e),
|
||||
| _ => Err(Error::NotFound(format!("could not find {event_id}")))?,
|
||||
})
|
||||
.try_collect()
|
||||
.boxed()
|
||||
@@ -637,16 +622,20 @@ where
|
||||
|
||||
let auth_event_ids: HashSet<OwnedEventId> = events_to_check
|
||||
.iter()
|
||||
.flat_map(|event: &E| event.auth_events().map(ToOwned::to_owned))
|
||||
.flat_map(|event: &Pdu| event.auth_events().map(ToOwned::to_owned))
|
||||
.collect();
|
||||
|
||||
trace!(set = ?auth_event_ids, "auth event IDs to fetch");
|
||||
|
||||
let auth_events: HashMap<OwnedEventId, E> = auth_event_ids
|
||||
let auth_events: HashMap<OwnedEventId, Pdu> = auth_event_ids
|
||||
.into_iter()
|
||||
.stream()
|
||||
.broad_filter_map(fetch_event)
|
||||
.map(|auth_event| (auth_event.event_id().to_owned(), auth_event))
|
||||
.broad_filter_map(async |event_id| {
|
||||
fetch_event(&event_id)
|
||||
.await
|
||||
.map(|ev_opt| ev_opt.map(|ev| (event_id.clone(), ev)))
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.collect()
|
||||
.boxed()
|
||||
.await;
|
||||
@@ -663,32 +652,25 @@ where
|
||||
trace!(event_id = event.event_id().as_str(), "checking event");
|
||||
let state_key = event
|
||||
.state_key()
|
||||
.ok_or_else(|| InvalidPduSnafu { message: "State event had no state key" }.build())?;
|
||||
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
|
||||
|
||||
let member_event_content = match event.kind() {
|
||||
| TimelineEventType::RoomMember =>
|
||||
Some(event.get_content::<RoomMemberEventContent>().map_err(|e| {
|
||||
Error::InvalidPdu(format!("Failed to parse m.room.member content: {}", e))
|
||||
})?),
|
||||
| _ => None,
|
||||
};
|
||||
let auth_types = auth_types_for_event(
|
||||
event.event_type(),
|
||||
event.sender(),
|
||||
Some(state_key),
|
||||
event.content(),
|
||||
room_version,
|
||||
event.kind(),
|
||||
event.state_key().map(StateKey::from_str).as_ref(),
|
||||
event.sender(),
|
||||
member_event_content,
|
||||
)?;
|
||||
trace!(list = ?auth_types, event_id = event.event_id().as_str(), "auth types for event");
|
||||
|
||||
let mut auth_state = StateMap::new();
|
||||
if room_version.room_ids_as_hashes {
|
||||
trace!("room version uses hashed IDs, manually fetching create event");
|
||||
let create_event_id_raw = event.room_id_or_hash().as_str().replace('!', "$");
|
||||
let create_event_id = EventId::parse(&create_event_id_raw).map_err(|e| {
|
||||
InvalidPduSnafu {
|
||||
message: format!("Failed to parse create event ID from room ID/hash: {e}"),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let create_event = fetch_event(create_event_id.into()).await.ok_or_else(|| {
|
||||
NotFoundSnafu { message: "Failed to find create event" }.build()
|
||||
})?;
|
||||
auth_state.insert(create_event.event_type().with_state_key(""), create_event);
|
||||
}
|
||||
let mut auth_state = StateMap::with_capacity(event.auth_events.len());
|
||||
for aid in event.auth_events() {
|
||||
if let Some(ev) = auth_events.get(aid) {
|
||||
//TODO: synapse checks "rejected_reason" which is most likely related to
|
||||
@@ -697,7 +679,7 @@ where
|
||||
auth_state.insert(
|
||||
ev.event_type()
|
||||
.with_state_key(ev.state_key().ok_or_else(|| {
|
||||
InvalidPduSnafu { message: "State event had no state key" }.build()
|
||||
Error::InvalidPdu("State event had no state key".to_owned())
|
||||
})?),
|
||||
ev.clone(),
|
||||
);
|
||||
@@ -714,7 +696,13 @@ where
|
||||
if let Some(event) = auth_events.get(ev_id) {
|
||||
Some((key, event.clone()))
|
||||
} else {
|
||||
Some((key, fetch_event(ev_id.clone()).await?))
|
||||
match fetch_event(ev_id).await {
|
||||
| Ok(Some(event)) => Some((key, event)),
|
||||
| _ => {
|
||||
warn!(event_id = ev_id.as_str(), "unable to fetch auth event");
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.ready_for_each(|(key, event)| {
|
||||
@@ -726,30 +714,16 @@ where
|
||||
|
||||
debug!(event_id = event.event_id().as_str(), "Running auth checks");
|
||||
|
||||
// The key for this is (eventType + a state_key of the signed token not sender)
|
||||
// so search for it
|
||||
let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
|
||||
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
|
||||
});
|
||||
|
||||
let fetch_state = |ty: &StateEventType, key: &str| {
|
||||
future::ready(
|
||||
auth_state
|
||||
.get(&ty.with_state_key(key))
|
||||
.map(ToOwned::to_owned),
|
||||
)
|
||||
let fetch_state = async |t: (StateEventType, StateKey)| {
|
||||
Ok(auth_state
|
||||
.get(&t.0.with_state_key(t.1.as_str()))
|
||||
.map(ToOwned::to_owned))
|
||||
};
|
||||
|
||||
let auth_result = auth_check(
|
||||
room_version,
|
||||
&event,
|
||||
current_third_party,
|
||||
fetch_state,
|
||||
&fetch_state(&StateEventType::RoomCreate, "")
|
||||
.await
|
||||
.expect("create event must exist"),
|
||||
)
|
||||
.await;
|
||||
let create_event = fetch_state((StateEventType::RoomCreate, StateKey::new())).await?;
|
||||
let auth_result =
|
||||
auth_check(room_version, &event, fetch_event, &fetch_state, create_event.as_ref())
|
||||
.await;
|
||||
|
||||
match auth_result {
|
||||
| Ok(true) => {
|
||||
@@ -769,7 +743,7 @@ where
|
||||
warn!("event {} failed the authentication check", event.event_id());
|
||||
},
|
||||
| Err(e) => {
|
||||
debug_error!("event {} failed the authentication check: {e}", event.event_id());
|
||||
log_error!("event {} failed the authentication check: {e}", event.event_id());
|
||||
return Err(e);
|
||||
},
|
||||
}
|
||||
@@ -788,15 +762,14 @@ where
|
||||
/// after the most recent are depth 0, the events before (with the first power
|
||||
/// level as a parent) will be marked as depth 1. depth 1 is "older" than depth
|
||||
/// 0.
|
||||
async fn mainline_sort<E, F, Fut>(
|
||||
async fn mainline_sort<FE, FR>(
|
||||
to_sort: &[OwnedEventId],
|
||||
resolved_power_level: Option<OwnedEventId>,
|
||||
fetch_event: &F,
|
||||
fetch_event: &FE,
|
||||
) -> Result<Vec<OwnedEventId>>
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Clone + Send + Sync,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
{
|
||||
debug!("mainline sort of events");
|
||||
|
||||
@@ -810,15 +783,15 @@ where
|
||||
while let Some(p) = pl {
|
||||
mainline.push(p.clone());
|
||||
|
||||
let event = fetch_event(p.clone())
|
||||
.await
|
||||
.ok_or_else(|| NotFoundSnafu { message: format!("Failed to find {p}") }.build())?;
|
||||
let event = fetch_event(&p)
|
||||
.await?
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
|
||||
|
||||
pl = None;
|
||||
for aid in event.auth_events() {
|
||||
let ev = fetch_event(aid.to_owned()).await.ok_or_else(|| {
|
||||
NotFoundSnafu { message: format!("Failed to find {aid}") }.build()
|
||||
})?;
|
||||
let ev = fetch_event(aid)
|
||||
.await?
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
|
||||
|
||||
if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
|
||||
pl = Some(aid.to_owned());
|
||||
@@ -838,7 +811,11 @@ where
|
||||
.iter()
|
||||
.stream()
|
||||
.broad_filter_map(async |ev_id| {
|
||||
fetch_event(ev_id.clone()).await.map(|event| (event, ev_id))
|
||||
fetch_event(ev_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|event| (event, ev_id))
|
||||
})
|
||||
.broad_filter_map(|(event, ev_id)| {
|
||||
get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event)
|
||||
@@ -860,15 +837,14 @@ where
|
||||
|
||||
/// Get the mainline depth from the `mainline_map` or finds a power_level event
|
||||
/// that has an associated mainline depth.
|
||||
async fn get_mainline_depth<E, F, Fut>(
|
||||
mut event: Option<E>,
|
||||
async fn get_mainline_depth<FE, FR>(
|
||||
mut event: Option<Pdu>,
|
||||
mainline_map: &HashMap<OwnedEventId, usize>,
|
||||
fetch_event: &F,
|
||||
fetch_event: &FE,
|
||||
) -> Result<usize>
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send + Sync,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
{
|
||||
while let Some(sort_ev) = event {
|
||||
debug!(event_id = sort_ev.event_id().as_str(), "mainline");
|
||||
@@ -880,9 +856,9 @@ where
|
||||
|
||||
event = None;
|
||||
for aid in sort_ev.auth_events() {
|
||||
let aev = fetch_event(aid.to_owned()).await.ok_or_else(|| {
|
||||
NotFoundSnafu { message: format!("Failed to find {aid}") }.build()
|
||||
})?;
|
||||
let aev = fetch_event(aid)
|
||||
.await?
|
||||
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
|
||||
|
||||
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
|
||||
event = Some(aev);
|
||||
@@ -894,20 +870,19 @@ where
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
|
||||
async fn add_event_and_auth_chain_to_graph<FE, FR>(
|
||||
graph: &mut HashMap<OwnedEventId, HashSet<OwnedEventId>>,
|
||||
event_id: OwnedEventId,
|
||||
auth_diff: &HashSet<OwnedEventId>,
|
||||
fetch_event: &F,
|
||||
fetch_event: &FE,
|
||||
) where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send + Sync,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
{
|
||||
let mut state = vec![event_id];
|
||||
while let Some(eid) = state.pop() {
|
||||
graph.entry(eid.clone()).or_default();
|
||||
let event = fetch_event(eid.clone()).await;
|
||||
let event = fetch_event(&eid).await.ok().flatten();
|
||||
let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten();
|
||||
|
||||
// Prefer the store to event as the store filters dedups the events
|
||||
@@ -926,15 +901,13 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn is_power_event_id<E, F, Fut>(event_id: &EventId, fetch: &F) -> bool
|
||||
async fn is_power_event_id<FE, FR>(event_id: &EventId, fetch: &FE) -> bool
|
||||
where
|
||||
F: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Option<E>> + Send,
|
||||
E: Event + Send,
|
||||
FE: Fn(&EventId) -> FR + Sync,
|
||||
FR: Future<Output = Result<Option<Pdu>, Error>> + Send,
|
||||
{
|
||||
match fetch(event_id.to_owned()).await.as_ref() {
|
||||
| Some(state) => is_power_event(state),
|
||||
match fetch(event_id).await.as_ref() {
|
||||
| Ok(Some(state)) => is_power_event(state),
|
||||
| _ => false,
|
||||
}
|
||||
}
|
||||
@@ -991,26 +964,27 @@ where
|
||||
mod tests {
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use itertools::Itertools;
|
||||
use maplit::{hashmap, hashset};
|
||||
use rand::seq::SliceRandom;
|
||||
use ruma::{
|
||||
MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
|
||||
events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
},
|
||||
int, uint,
|
||||
room::join_rules::{JoinRule, RoomJoinRulesEventContent}, StateEventType,
|
||||
TimelineEventType,
|
||||
}, int, uint,
|
||||
MilliSecondsSinceUnixEpoch,
|
||||
OwnedEventId, RoomVersionId,
|
||||
};
|
||||
use serde_json::{json, value::to_raw_value as to_raw_json_value};
|
||||
|
||||
use super::{
|
||||
StateMap, is_power_event,
|
||||
room_version::RoomVersion,
|
||||
is_power_event, room_version::RoomVersion,
|
||||
test_utils::{
|
||||
INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id,
|
||||
member_content_ban, member_content_join, room_id, to_init_pdu_event, to_pdu_event,
|
||||
zara,
|
||||
alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join,
|
||||
room_id, to_init_pdu_event, to_pdu_event, zara, TestStore,
|
||||
INITIAL_EVENTS,
|
||||
},
|
||||
StateMap,
|
||||
};
|
||||
use crate::{
|
||||
debug,
|
||||
@@ -1040,13 +1014,13 @@ mod tests {
|
||||
.map(|pdu| pdu.event_id.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let fetcher = |id| ready(events.get(&id).cloned());
|
||||
let fetcher = |id| ready(Ok(events.get(id).cloned()));
|
||||
let sorted_power_events =
|
||||
super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resolved_power = super::iterative_auth_check(
|
||||
let resolved_power = super::auth_check(
|
||||
&RoomVersion::V6,
|
||||
sorted_power_events.iter().map(AsRef::as_ref).stream(),
|
||||
HashMap::new(), // unconflicted events
|
||||
@@ -1058,7 +1032,7 @@ mod tests {
|
||||
// don't remove any events so we know it sorts them all correctly
|
||||
let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
|
||||
|
||||
events_to_sort.shuffle(&mut rand::rng());
|
||||
events_to_sort.shuffle(&mut rand::thread_rng());
|
||||
|
||||
let power_level = resolved_power
|
||||
.get(&(StateEventType::RoomPowerLevels, "".into()))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use ruma::RoomVersionId;
|
||||
|
||||
use super::{Result, error::UnsupportedSnafu};
|
||||
use super::{Error, Result};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::exhaustive_enums)]
|
||||
@@ -163,11 +163,7 @@ impl RoomVersion {
|
||||
| RoomVersionId::V10 => Self::V10,
|
||||
| RoomVersionId::V11 => Self::V11,
|
||||
| RoomVersionId::V12 => Self::V12,
|
||||
| ver =>
|
||||
return Err(UnsupportedSnafu {
|
||||
version: format!("found version `{ver}`"),
|
||||
}
|
||||
.build()),
|
||||
| ver => return Err(Error::Unsupported(format!("found version `{ver}`"))),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ use serde_json::{
|
||||
value::{RawValue as RawJsonValue, to_raw_value as to_raw_json_value},
|
||||
};
|
||||
|
||||
use super::{auth_types_for_event, error::NotFoundSnafu};
|
||||
use super::auth_types_for_event;
|
||||
use crate::{
|
||||
Result, RoomVersion, info,
|
||||
matrix::{Event, EventTypeExt, Pdu, StateMap, pdu::EventHash},
|
||||
@@ -232,7 +232,7 @@ impl<E: Event + Clone> TestStore<E> {
|
||||
self.0
|
||||
.get(event_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| NotFoundSnafu { message: format!("{event_id} not found") }.build())
|
||||
.ok_or_else(|| super::Error::NotFound(format!("{event_id} not found")))
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,9 @@ pub mod utils;
|
||||
|
||||
pub use ::arrayvec;
|
||||
pub use ::http;
|
||||
pub use ::paste;
|
||||
pub use ::ruma;
|
||||
pub use ::smallstr;
|
||||
pub use ::smallvec;
|
||||
pub use ::snafu;
|
||||
pub use ::toml;
|
||||
pub use ::tracing;
|
||||
pub use config::Config;
|
||||
|
||||
@@ -28,7 +28,7 @@ fn init_argon() -> Argon2<'static> {
|
||||
}
|
||||
|
||||
pub(super) fn password(password: &str) -> Result<String> {
|
||||
let salt = SaltString::generate(rand_core::OsRng);
|
||||
let salt = SaltString::generate(rand::thread_rng());
|
||||
ARGON
|
||||
.get_or_init(init_argon)
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
|
||||
+10
-7
@@ -4,16 +4,16 @@ use std::{
|
||||
};
|
||||
|
||||
use arrayvec::ArrayString;
|
||||
use rand::{RngExt, seq::SliceRandom};
|
||||
use rand::{Rng, seq::SliceRandom, thread_rng};
|
||||
|
||||
pub fn shuffle<T>(vec: &mut [T]) {
|
||||
let mut rng = rand::rng();
|
||||
let mut rng = thread_rng();
|
||||
vec.shuffle(&mut rng);
|
||||
}
|
||||
|
||||
pub fn string(length: usize) -> String {
|
||||
rand::rng()
|
||||
.sample_iter(&rand::distr::Alphanumeric)
|
||||
thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(length)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
@@ -22,8 +22,8 @@ pub fn string(length: usize) -> String {
|
||||
#[inline]
|
||||
pub fn string_array<const LENGTH: usize>() -> ArrayString<LENGTH> {
|
||||
let mut ret = ArrayString::<LENGTH>::new();
|
||||
rand::rng()
|
||||
.sample_iter(&rand::distr::Alphanumeric)
|
||||
thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(LENGTH)
|
||||
.map(char::from)
|
||||
.for_each(|c| ret.push(c));
|
||||
@@ -40,4 +40,7 @@ pub fn time_from_now_secs(range: Range<u64>) -> SystemTime {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn secs(range: Range<u64>) -> Duration { Duration::from_secs(rand::random_range(range)) }
|
||||
pub fn secs(range: Range<u64>) -> Duration {
|
||||
let mut rng = thread_rng();
|
||||
Duration::from_secs(rng.gen_range(range))
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
use std::{cell::Cell, fmt::Debug, path::PathBuf, sync::LazyLock};
|
||||
|
||||
use snafu::IntoError;
|
||||
|
||||
use crate::{Result, is_equal_to};
|
||||
|
||||
type Id = usize;
|
||||
@@ -144,9 +142,7 @@ pub fn getcpu() -> Result<usize> {
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
#[inline]
|
||||
pub fn getcpu() -> Result<usize> {
|
||||
Err(crate::error::IoSnafu.into_error(std::io::ErrorKind::Unsupported.into()))
|
||||
}
|
||||
pub fn getcpu() -> Result<usize> { Err(crate::Error::Io(std::io::ErrorKind::Unsupported.into())) }
|
||||
|
||||
fn query_cores_available() -> impl Iterator<Item = Id> {
|
||||
core_affinity::get_core_ids()
|
||||
|
||||
+7
-12
@@ -255,10 +255,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
|
||||
| "$serde_json::private::RawValue" => visitor.visit_map(self),
|
||||
| "Cbor" => visitor
|
||||
.visit_newtype_struct(&mut minicbor_serde::Deserializer::new(self.record_trail()))
|
||||
.map_err(|e| {
|
||||
let message: std::borrow::Cow<'static, str> = e.to_string().into();
|
||||
conduwuit_core::error::SerdeDeSnafu { message }.build()
|
||||
}),
|
||||
.map_err(|e| Self::Error::SerdeDe(e.to_string().into())),
|
||||
|
||||
| _ => visitor.visit_newtype_struct(self),
|
||||
}
|
||||
@@ -316,10 +313,9 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
|
||||
|
||||
let end = self.pos.saturating_add(BYTES).min(self.buf.len());
|
||||
let bytes: ArrayVec<u8, BYTES> = self.buf[self.pos..end].try_into()?;
|
||||
let bytes = bytes.into_inner().map_err(|_| {
|
||||
let message: std::borrow::Cow<'static, str> = "i64 buffer underflow".into();
|
||||
conduwuit_core::error::SerdeDeSnafu { message }.build()
|
||||
})?;
|
||||
let bytes = bytes
|
||||
.into_inner()
|
||||
.map_err(|_| Self::Error::SerdeDe("i64 buffer underflow".into()))?;
|
||||
|
||||
self.inc_pos(BYTES);
|
||||
visitor.visit_i64(i64::from_be_bytes(bytes))
|
||||
@@ -349,10 +345,9 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
|
||||
|
||||
let end = self.pos.saturating_add(BYTES).min(self.buf.len());
|
||||
let bytes: ArrayVec<u8, BYTES> = self.buf[self.pos..end].try_into()?;
|
||||
let bytes = bytes.into_inner().map_err(|_| {
|
||||
let message: std::borrow::Cow<'static, str> = "u64 buffer underflow".into();
|
||||
conduwuit_core::error::SerdeDeSnafu { message }.build()
|
||||
})?;
|
||||
let bytes = bytes
|
||||
.into_inner()
|
||||
.map_err(|_| Self::Error::SerdeDe("u64 buffer underflow".into()))?;
|
||||
|
||||
self.inc_pos(BYTES);
|
||||
visitor.visit_u64(u64::from_be_bytes(bytes))
|
||||
|
||||
+1
-4
@@ -199,10 +199,7 @@ impl<W: Write> ser::Serializer for &mut Serializer<'_, W> {
|
||||
|
||||
value
|
||||
.serialize(&mut Serializer::new(&mut Writer::new(&mut self.out)))
|
||||
.map_err(|e| {
|
||||
let message: std::borrow::Cow<'static, str> = e.to_string().into();
|
||||
conduwuit_core::error::SerdeSerSnafu { message }.build()
|
||||
})
|
||||
.map_err(|e| Self::Error::SerdeSer(e.to_string().into()))
|
||||
},
|
||||
| _ => unhandled!("Unrecognized serialization Newtype {name:?}"),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{Router, response::IntoResponse};
|
||||
use conduwuit::Error;
|
||||
@@ -18,10 +18,5 @@ pub(crate) fn build(services: &Arc<Services>) -> (Router, Guard) {
|
||||
}
|
||||
|
||||
async fn not_found(_uri: Uri) -> impl IntoResponse {
|
||||
Error::Request {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: Cow::Borrowed("Not Found"),
|
||||
code: StatusCode::NOT_FOUND,
|
||||
backtrace: None,
|
||||
}
|
||||
Error::Request(ErrorKind::Unrecognized, "Not Found".into(), StatusCode::NOT_FOUND)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ use std::{sync::Arc, time::Duration};
|
||||
use async_trait::async_trait;
|
||||
use conduwuit::{Result, Server, debug, error, warn};
|
||||
use database::{Deserialized, Map};
|
||||
use rand::Rng;
|
||||
use ruma::events::{Mentions, room::message::RoomMessageEventContent};
|
||||
use serde::Deserialize;
|
||||
use tokio::{
|
||||
@@ -99,7 +100,8 @@ impl crate::Service for Service {
|
||||
}
|
||||
|
||||
let first_check_jitter = {
|
||||
let jitter_percent = rand::random_range(-50.0..=10.0);
|
||||
let mut rng = rand::thread_rng();
|
||||
let jitter_percent = rng.gen_range(-50.0..=10.0);
|
||||
self.interval.mul_f64(1.0 + jitter_percent / 100.0)
|
||||
};
|
||||
|
||||
|
||||
@@ -147,11 +147,11 @@ impl Service {
|
||||
// same appservice)
|
||||
if let Ok(existing) = self.find_from_token(®istration.as_token).await {
|
||||
if existing.registration.id != registration.id {
|
||||
return Err!(Request(InvalidParam(
|
||||
return Err(err!(Request(InvalidParam(
|
||||
"Cannot register appservice: Token is already used by appservice '{}'. \
|
||||
Please generate a different token.",
|
||||
existing.registration.id
|
||||
)));
|
||||
))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,10 +163,10 @@ impl Service {
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
return Err!(Request(InvalidParam(
|
||||
return Err(err!(Request(InvalidParam(
|
||||
"Cannot register appservice: The provided token is already in use by a user \
|
||||
device. Please generate a different token for the appservice."
|
||||
)));
|
||||
))));
|
||||
}
|
||||
|
||||
self.db
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{fmt::Debug, mem};
|
||||
|
||||
use bytes::Bytes;
|
||||
use conduwuit::{
|
||||
Err, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err,
|
||||
Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err,
|
||||
error::inspect_debug_log, implement, trace,
|
||||
};
|
||||
use http::{HeaderValue, header::AUTHORIZATION};
|
||||
@@ -179,7 +179,10 @@ async fn into_http_response(
|
||||
|
||||
debug!("Got {status:?} for {method} {url}");
|
||||
if !status.is_success() {
|
||||
return Err!(Federation(dest.to_owned(), RumaError::from_http_response(http_response),));
|
||||
return Err(Error::Federation(
|
||||
dest.to_owned(),
|
||||
RumaError::from_http_response(http_response),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(http_response)
|
||||
|
||||
@@ -35,7 +35,7 @@ pub async fn fetch_remote_thumbnail(
|
||||
.fetch_thumbnail_authenticated(mxc, user, server, timeout_ms, dim)
|
||||
.await;
|
||||
|
||||
if let Err(Error::Request { kind: NotFound, .. }) = &result {
|
||||
if let Err(Error::Request(NotFound, ..)) = &result {
|
||||
return self
|
||||
.fetch_thumbnail_unauthenticated(mxc, user, server, timeout_ms, dim)
|
||||
.await;
|
||||
@@ -67,7 +67,7 @@ pub async fn fetch_remote_content(
|
||||
);
|
||||
});
|
||||
|
||||
if let Err(Error::Request { kind: Unrecognized, .. }) = &result {
|
||||
if let Err(Error::Request(Unrecognized, ..)) = &result {
|
||||
return self
|
||||
.fetch_content_unauthenticated(mxc, user, server, timeout_ms)
|
||||
.await;
|
||||
|
||||
@@ -112,14 +112,7 @@ where
|
||||
{
|
||||
let event_fetch = |event_id| self.event_fetch(event_id);
|
||||
let event_exists = |event_id| self.event_exists(event_id);
|
||||
Ok(
|
||||
state_res::resolve(
|
||||
room_version,
|
||||
state_sets,
|
||||
auth_chain_sets,
|
||||
&event_fetch,
|
||||
&event_exists,
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
state_res::resolve(room_version, state_sets, auth_chain_sets, &event_fetch, &event_exists)
|
||||
.map_err(|e| err!(error!("State resolution failed: {e:?}")))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::{
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use conduwuit::{Err, Error, Result};
|
||||
use conduwuit::{Error, Result};
|
||||
use ruma::{UInt, api::client::error::ErrorKind};
|
||||
|
||||
use crate::rooms::short::ShortRoomId;
|
||||
@@ -57,7 +57,7 @@ impl FromStr for PaginationToken {
|
||||
if let Some(token) = pag_tok() {
|
||||
Ok(token)
|
||||
} else {
|
||||
Err!(BadRequest(ErrorKind::InvalidParam, "invalid token"))
|
||||
Err(Error::BadRequest(ErrorKind::InvalidParam, "invalid token"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,7 +75,10 @@ pub async fn create_hash_and_sign_event(
|
||||
let content: RoomCreateEventContent = serde_json::from_str(content.get())?;
|
||||
Ok(content.room_version)
|
||||
} else {
|
||||
Err!(InconsistentRoomState("non-create event for room of unknown version", room_id))
|
||||
Err(Error::InconsistentRoomState(
|
||||
"non-create event for room of unknown version",
|
||||
room_id,
|
||||
))
|
||||
}
|
||||
}
|
||||
let PduBuilder {
|
||||
@@ -272,9 +275,7 @@ pub async fn create_hash_and_sign_event(
|
||||
.hash_and_sign_event(&mut pdu_json, &room_version_id)
|
||||
{
|
||||
return match e {
|
||||
| Error::Signatures { source, .. }
|
||||
if matches!(source, ruma::signatures::Error::PduSize) =>
|
||||
{
|
||||
| Error::Signatures(ruma::signatures::Error::PduSize) => {
|
||||
Err!(Request(TooLarge("Message/PDU is too long (exceeds 65535 bytes)")))
|
||||
},
|
||||
| _ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
|
||||
|
||||
@@ -385,13 +385,11 @@ fn num_senders(args: &crate::Args<'_>) -> usize {
|
||||
const MIN_SENDERS: usize = 1;
|
||||
// Limit the number of senders to the number of workers threads or number of
|
||||
// cores, conservatively.
|
||||
let mut max_senders = args.server.metrics.num_workers();
|
||||
|
||||
// Work around some platforms not returning the number of cores.
|
||||
let num_cores = available_parallelism();
|
||||
if num_cores > 0 {
|
||||
max_senders = max_senders.min(num_cores);
|
||||
}
|
||||
let max_senders = args
|
||||
.server
|
||||
.metrics
|
||||
.num_workers()
|
||||
.min(available_parallelism());
|
||||
|
||||
// If the user doesn't override the default 0, this is intended to then default
|
||||
// to 1 for now as multiple senders is experimental.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use conduwuit::{
|
||||
Err, Result, SyncRwLock, err, error, implement, utils,
|
||||
Err, Error, Result, SyncRwLock, err, error, implement, utils,
|
||||
utils::{hash, string::EMPTY},
|
||||
};
|
||||
use database::{Deserialized, Json, Map};
|
||||
@@ -117,7 +117,7 @@ pub async fn try_auth(
|
||||
} else if let Some(username) = user {
|
||||
username
|
||||
} else {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
));
|
||||
@@ -125,7 +125,7 @@ pub async fn try_auth(
|
||||
|
||||
#[cfg(not(feature = "element_hacks"))]
|
||||
let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier else {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
));
|
||||
@@ -135,7 +135,7 @@ pub async fn try_auth(
|
||||
username.clone(),
|
||||
self.services.globals.server_name(),
|
||||
)
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")))?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
|
||||
|
||||
// Check if the access token being used matches the credentials used for UIAA
|
||||
if user_id.localpart() != user_id_from_username.localpart() {
|
||||
|
||||
@@ -184,12 +184,6 @@ impl Service {
|
||||
password: Option<&str>,
|
||||
origin: Option<&str>,
|
||||
) -> Result<()> {
|
||||
if !self.services.globals.user_is_local(user_id)
|
||||
&& (password.is_some() || origin.is_some())
|
||||
{
|
||||
return Err!("Cannot create a nonlocal user with a set password or origin");
|
||||
}
|
||||
|
||||
self.db
|
||||
.userid_origin
|
||||
.insert(user_id, origin.unwrap_or("password"));
|
||||
@@ -761,13 +755,13 @@ impl Service {
|
||||
.keys
|
||||
.into_values();
|
||||
|
||||
let self_signing_key_id = self_signing_key_ids.next().ok_or(err!(BadRequest(
|
||||
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Self signing key contained no key.",
|
||||
)))?;
|
||||
))?;
|
||||
|
||||
if self_signing_key_ids.next().is_some() {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Self signing key contained more than one key.",
|
||||
));
|
||||
@@ -1439,13 +1433,13 @@ pub fn parse_master_key(
|
||||
|
||||
let master_key = master_key
|
||||
.deserialize()
|
||||
.map_err(|_| err!(BadRequest(ErrorKind::InvalidParam, "Invalid master key")))?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
||||
let mut master_key_ids = master_key.keys.values();
|
||||
let master_key_id = master_key_ids
|
||||
.next()
|
||||
.ok_or(err!(BadRequest(ErrorKind::InvalidParam, "Master key contained no key.")))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
|
||||
if master_key_ids.next().is_some() {
|
||||
return Err!(BadRequest(
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Master key contained more than one key.",
|
||||
));
|
||||
|
||||
+1
-1
@@ -25,7 +25,7 @@ axum.workspace = true
|
||||
futures.workspace = true
|
||||
tracing.workspace = true
|
||||
rand.workspace = true
|
||||
snafu.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
+4
-12
@@ -8,7 +8,6 @@ use axum::{
|
||||
};
|
||||
use conduwuit_build_metadata::{GIT_REMOTE_COMMIT_URL, GIT_REMOTE_WEB_URL, version_tag};
|
||||
use conduwuit_service::state;
|
||||
use snafu::{IntoError, prelude::*};
|
||||
|
||||
pub fn build() -> Router<state::State> {
|
||||
Router::<state::State>::new()
|
||||
@@ -49,17 +48,10 @@ async fn logo_handler() -> impl IntoResponse {
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum WebError {
|
||||
#[snafu(display("Failed to render template: {source}"))]
|
||||
Render {
|
||||
source: askama::Error,
|
||||
backtrace: snafu::Backtrace,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<askama::Error> for WebError {
|
||||
fn from(source: askama::Error) -> Self { RenderSnafu.into_error(source) }
|
||||
#[error("Failed to render template: {0}")]
|
||||
Render(#[from] askama::Error),
|
||||
}
|
||||
|
||||
impl IntoResponse for WebError {
|
||||
@@ -74,7 +66,7 @@ impl IntoResponse for WebError {
|
||||
let nonce = rand::random::<u64>().to_string();
|
||||
|
||||
let status = match &self {
|
||||
| Self::Render { .. } => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
| Self::Render(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
let tmpl = Error { nonce: &nonce, err: self };
|
||||
if let Ok(body) = tmpl.render() {
|
||||
|
||||
@@ -12,8 +12,8 @@ Server Error
|
||||
</h1>
|
||||
|
||||
{%- match err -%}
|
||||
{% when WebError::Render { source, .. } -%}
|
||||
<pre>{{ source }}</pre>
|
||||
{% when WebError::Render(err) -%}
|
||||
<pre>{{ err }}</pre>
|
||||
{% else -%} <p>An error occurred</p>
|
||||
{%- endmatch -%}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user