Compare commits

..

23 Commits

Author SHA1 Message Date
June Clementine Strawberry 8658a4c2d0 misc nix CI fixes that might speed it up a bit
Signed-off-by: June Clementine Strawberry <strawberry@puppygock.gay>
Signed-off-by: strawberry <strawberry@puppygock.gay>
2025-01-29 17:04:49 -05:00
Jason Volk eb7d893c86 fix malloc_conf feature-awareness
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 06:37:30 +00:00
Jason Volk 936161d89e reduce bottommost compression underrides
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 03:09:13 +00:00
Jason Volk 329925c661 additional info level span adjustments
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 03:09:13 +00:00
Jason Volk af399fd517 flatten state accessor iterations
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk ad0b0af955 combine state_accessor data into mod
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 2c5af902a3 support executing configurable admin commands via SIGUSR2
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 2f449ba47d support reloading config via SIGUSR1
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk a567e314e9 simplify shutdown signal handlers
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk ed3cd99781 abstract the config reload checks
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 99fe88c21e use smallvec for the edu sending event buffer
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk ffd0fd4242 pipeline pdu fetch for federation sending destination
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk b2a565b0b4 propagate better error from server.check_running() 2025-01-29 01:18:08 +00:00
Jason Volk c516a8df3e fanout edu processing
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 94d786ac12 process rooms and edus concurrently
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 677316631a pipeline prologue of handle_incoming_pdu
simplify room_version/first_pdu_in_room argument passing

Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 2b730a30ad add broad_flat_map
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-29 01:18:08 +00:00
Jason Volk 98f9570547 add option to disable rocksdb checksums
reference runtime state for default option initialization

Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-28 18:38:47 +00:00
Jason Volk 13335042b7 enable the paranoid-checks options in debug mode
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-28 18:36:00 +00:00
Jason Volk 6db8df5e23 skip redundant acl check when sender is origin
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-28 18:36:00 +00:00
Jason Volk d0b4a619af furnish batch interface with trait
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-28 18:36:00 +00:00
Jason Volk 4a2d0d35bc split federation request from sending service
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-28 18:36:00 +00:00
Jason Volk 3e0ff2dc84 simplify references to server_name
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-28 18:36:00 +00:00
90 changed files with 1512 additions and 1206 deletions
+2 -2
View File
@@ -88,8 +88,8 @@ jobs:
ssh -q website "echo test" || ssh -q website "echo test"
echo "Creating commit rev directory on web server"
ssh -q website "rm -rf /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/" || ssh -q website "rm -rf /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/"
ssh -q website "mkdir -v /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/" || ssh -q website "mkdir -v /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/"
ssh -q website "rm -rf /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/" || ssh -q website "rm -rf /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/" || true
ssh -q website "mkdir -v /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/" || ssh -q website "mkdir -v /var/www/girlboss.ceo/~strawberry/conduwuit/ci-bins/${WEBSERVER_DIR_NAME}/" || true
echo "SSH_WEBSITE=1" >> "$GITHUB_ENV"
Generated
+2
View File
@@ -685,6 +685,7 @@ dependencies = [
"http-body-util",
"hyper",
"ipaddress",
"itertools 0.13.0",
"log",
"rand",
"reqwest",
@@ -844,6 +845,7 @@ dependencies = [
"serde_json",
"serde_yaml",
"sha2",
"smallvec",
"termimad",
"tokio",
"tracing",
+19
View File
@@ -897,6 +897,13 @@
#
#rocksdb_paranoid_file_checks = false
# Enables or disables checksum verification in rocksdb at runtime.
# Checksums are usually hardware accelerated with low overhead; they are
# enabled in rocksdb by default. Older or slower platforms may see gains
# from disabling.
#
#rocksdb_checksums = true
# Database repair mode (for RocksDB SST corruption).
#
# Use this option when the server reports corruption while running or
@@ -1355,6 +1362,13 @@
#
#admin_execute_errors_ignore = false
# List of admin commands to execute on SIGUSR2.
#
# Similar to admin_execute, but these commands are executed when the
# server receives SIGUSR2 on supporting platforms.
#
#admin_signal_execute = []
# Controls the max log level for admin command log captures (logs
# generated from running admin commands). Defaults to "info" on release
# builds, else "debug" on debug builds.
@@ -1517,6 +1531,11 @@
#
#listening = true
# Enables configuration reload when the server receives SIGUSR1 on
# supporting platforms.
#
#config_reload_signal = true
[global.tls]
# Path to a valid TLS certificate file.
+3 -39
View File
@@ -15,8 +15,8 @@
let
main' = main.override {
#profile = "test";
profile = "release-debuginfo";
profile = "test";
#profile = "release-debuginfo";
all_features = true;
disable_release_max_log_level = true;
disable_features = [
@@ -55,42 +55,7 @@ let
#--exit-on-first-error=yes \
#--error-exitcode=1 \
# valgrind only works in non-static ocntexts
start = if !stdenv.hostPlatform.isStatic then writeShellScriptBin "start" ''
set -euxo pipefail
${lib.getExe openssl} genrsa -out private_key.key 2048
${lib.getExe openssl} req \
-new \
-sha256 \
-key private_key.key \
-subj "/C=US/ST=CA/O=MyOrg, Inc./CN=$SERVER_NAME" \
-out signing_request.csr
cp ${./v3.ext} v3.ext
echo "DNS.1 = $SERVER_NAME" >> v3.ext
echo "IP.1 = $(${lib.getExe gawk} 'END{print $1}' /etc/hosts)" \
>> v3.ext
${lib.getExe openssl} x509 \
-req \
-extfile v3.ext \
-in signing_request.csr \
-CA /complement/ca/ca.crt \
-CAkey /complement/ca/ca.key \
-CAcreateserial \
-out certificate.crt \
-days 1 \
-sha256
${lib.getExe' coreutils "env"} \
CONDUWUIT_SERVER_NAME="$SERVER_NAME" \
TMPDIR="/" \
${lib.getExe' valgrind "valgrind"} \
--leak-check=no \
--undef-value-errors=no \
--exit-on-first-error=yes \
--error-exitcode=1 \
${lib.getExe main'}
'' else writeShellScriptBin "start" ''
start = writeShellScriptBin "start" ''
set -euxo pipefail
${lib.getExe openssl} genrsa -out private_key.key 2048
@@ -135,7 +100,6 @@ dockerTools.buildImage {
coreutils
main'
start
valgrind
];
};
+11 -14
View File
@@ -9,7 +9,7 @@ use conduwuit::{
debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, PduId,
RawPduId, Result,
};
use futures::{FutureExt, StreamExt};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{
api::{client::error::ErrorKind, federation::event::get_room_state},
events::room::message::RoomMessageEventContent,
@@ -327,11 +327,10 @@ pub(super) async fn get_room_state(
.services
.rooms
.state_accessor
.room_state_full(&room_id)
.await?
.values()
.map(PduEvent::to_state_event)
.collect();
.room_state_full_pdus(&room_id)
.map_ok(PduEvent::into_state_event)
.try_collect()
.await?;
if room_state.is_empty() {
return Ok(RoomMessageEventContent::text_plain(
@@ -554,7 +553,7 @@ pub(super) async fn first_pdu_in_room(
.services
.rooms
.state_cache
.server_in_room(&self.services.server.config.server_name, &room_id)
.server_in_room(&self.services.server.name, &room_id)
.await
{
return Ok(RoomMessageEventContent::text_plain(
@@ -583,7 +582,7 @@ pub(super) async fn latest_pdu_in_room(
.services
.rooms
.state_cache
.server_in_room(&self.services.server.config.server_name, &room_id)
.server_in_room(&self.services.server.name, &room_id)
.await
{
return Ok(RoomMessageEventContent::text_plain(
@@ -613,7 +612,7 @@ pub(super) async fn force_set_room_state_from_server(
.services
.rooms
.state_cache
.server_in_room(&self.services.server.config.server_name, &room_id)
.server_in_room(&self.services.server.name, &room_id)
.await
{
return Ok(RoomMessageEventContent::text_plain(
@@ -756,8 +755,7 @@ pub(super) async fn get_signing_keys(
notary: Option<Box<ServerName>>,
query: bool,
) -> Result<RoomMessageEventContent> {
let server_name =
server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into());
let server_name = server_name.unwrap_or_else(|| self.services.server.name.clone().into());
if let Some(notary) = notary {
let signing_keys = self
@@ -793,8 +791,7 @@ pub(super) async fn get_verify_keys(
&self,
server_name: Option<Box<ServerName>>,
) -> Result<RoomMessageEventContent> {
let server_name =
server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into());
let server_name = server_name.unwrap_or_else(|| self.services.server.name.clone().into());
let keys = self
.services
@@ -824,7 +821,7 @@ pub(super) async fn resolve_true_destination(
));
}
if server_name == self.services.server.config.server_name {
if server_name == self.services.server.name {
return Ok(RoomMessageEventContent::text_plain(
"Not allowed to send federation requests to ourselves. Please use `get-pdu` for \
fetching local PDUs.",
+1 -1
View File
@@ -92,7 +92,7 @@ pub(super) async fn remote_user_in_rooms(
&self,
user_id: Box<UserId>,
) -> Result<RoomMessageEventContent> {
if user_id.server_name() == self.services.server.config.server_name {
if user_id.server_name() == self.services.server.name {
return Ok(RoomMessageEventContent::text_plain(
"User belongs to our server, please use `list-joined-rooms` user admin command \
instead.",
+2 -7
View File
@@ -1,6 +1,6 @@
use std::{fmt::Write, path::PathBuf, sync::Arc};
use conduwuit::{info, utils::time, warn, Config, Err, Result};
use conduwuit::{info, utils::time, warn, Err, Result};
use ruma::events::room::message::RoomMessageEventContent;
use crate::admin_command;
@@ -33,12 +33,7 @@ pub(super) async fn reload_config(
path: Option<PathBuf>,
) -> Result<RoomMessageEventContent> {
let path = path.as_deref().into_iter();
let config = Config::load(path).and_then(|raw| Config::new(&raw))?;
if config.server_name != self.services.server.config.server_name {
return Err!("You can't change the server name.");
}
let _old = self.services.server.config.update(config)?;
self.services.config.reload(path)?;
Ok(RoomMessageEventContent::text_plain("Successfully reconfigured."))
}
+1
View File
@@ -50,6 +50,7 @@ http.workspace = true
http-body-util.workspace = true
hyper.workspace = true
ipaddress.workspace = true
itertools.workspace = true
log.workspace = true
rand.workspace = true
reqwest.workspace = true
+18 -10
View File
@@ -1,6 +1,6 @@
use axum::extract::State;
use conduwuit::{
at, deref_at, err, ref_at,
at, err, ref_at,
utils::{
future::TryExtExt,
stream::{BroadbandExt, ReadyExt, TryIgnore, WidebandExt},
@@ -10,10 +10,10 @@ use conduwuit::{
};
use futures::{
future::{join, join3, try_join3, OptionFuture},
FutureExt, StreamExt, TryFutureExt,
FutureExt, StreamExt, TryFutureExt, TryStreamExt,
};
use ruma::{api::client::context::get_context, events::StateEventType, OwnedEventId, UserId};
use service::rooms::{lazy_loading, lazy_loading::Options};
use service::rooms::{lazy_loading, lazy_loading::Options, short::ShortStateKey};
use crate::{
client::message::{event_filter, ignored_filter, lazy_loading_witness, visibility_filter},
@@ -132,21 +132,29 @@ pub(crate) async fn get_context_route(
.state_accessor
.pdu_shortstatehash(state_at)
.or_else(|_| services.rooms.state.get_room_shortstatehash(room_id))
.and_then(|shortstatehash| services.rooms.state_accessor.state_full_ids(shortstatehash))
.map_ok(|shortstatehash| {
services
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.map(Ok)
})
.map_err(|e| err!(Database("State not found: {e}")))
.try_flatten_stream()
.try_collect()
.boxed();
let (lazy_loading_witnessed, state_ids) = join(lazy_loading_witnessed, state_ids).await;
let state_ids = state_ids?;
let state_ids: Vec<(ShortStateKey, OwnedEventId)> = state_ids?;
let shortstatekeys = state_ids.iter().map(at!(0)).stream();
let shorteventids = state_ids.iter().map(ref_at!(1)).stream();
let lazy_loading_witnessed = lazy_loading_witnessed.unwrap_or_default();
let shortstatekeys = state_ids.iter().stream().map(deref_at!(0));
let state: Vec<_> = services
.rooms
.short
.multi_get_statekey_from_short(shortstatekeys)
.zip(state_ids.iter().stream().map(at!(1)))
.zip(shorteventids)
.ready_filter_map(|item| Some((item.0.ok()?, item.1)))
.ready_filter_map(|((event_type, state_key), event_id)| {
if filter.lazy_load_options.is_enabled()
@@ -162,9 +170,9 @@ pub(crate) async fn get_context_route(
Some(event_id)
})
.broad_filter_map(|event_id: &OwnedEventId| {
services.rooms.timeline.get_pdu(event_id).ok()
services.rooms.timeline.get_pdu(event_id.as_ref()).ok()
})
.map(|pdu| pdu.to_state_event())
.map(PduEvent::into_state_event)
.collect()
.await;
+8 -10
View File
@@ -8,14 +8,14 @@ use std::{
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{
debug, debug_info, debug_warn, err, info,
at, debug, debug_info, debug_warn, err, info,
pdu::{gen_event_id_canonical_json, PduBuilder},
result::FlatOk,
trace,
utils::{self, shuffle, IterStream, ReadyExt},
warn, Err, PduEvent, Result,
};
use futures::{join, FutureExt, StreamExt};
use futures::{join, FutureExt, StreamExt, TryFutureExt};
use ruma::{
api::{
client::{
@@ -765,11 +765,12 @@ pub(crate) async fn get_member_events_route(
.rooms
.state_accessor
.room_state_full(&body.room_id)
.await?
.iter()
.filter(|(key, _)| key.0 == StateEventType::RoomMember)
.map(|(_, pdu)| pdu.to_member_event())
.collect(),
.ready_filter_map(Result::ok)
.ready_filter(|((ty, _), _)| *ty == StateEventType::RoomMember)
.map(at!(1))
.map(PduEvent::into_member_event)
.collect()
.await,
})
}
@@ -1707,9 +1708,6 @@ pub async fn leave_room(
room_id: &RoomId,
reason: Option<String>,
) -> Result<()> {
//use conduwuit::utils::stream::OptionStream;
use futures::TryFutureExt;
// Ask a remote server if we don't have this room and are not knocking on it
if !services
.rooms
+3 -3
View File
@@ -6,9 +6,9 @@ use conduwuit::{
stream::{BroadbandExt, TryIgnore, WidebandExt},
IterStream, ReadyExt,
},
Event, PduCount, Result,
Event, PduCount, PduEvent, Result,
};
use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt};
use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt, TryFutureExt};
use ruma::{
api::{
client::{filter::RoomEventFilter, message::get_message_events},
@@ -220,8 +220,8 @@ async fn get_member_event(
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.map_ok(PduEvent::into_state_event)
.await
.map(|member_event| member_event.to_state_event())
.ok()
}
+1 -1
View File
@@ -37,7 +37,7 @@ pub(crate) async fn create_openid_token_route(
Ok(account::request_openid_token::v3::Response {
access_token,
token_type: TokenType::Bearer,
matrix_server_name: services.server.config.server_name.clone(),
matrix_server_name: services.server.name.clone(),
expires_in: Duration::from_secs(expires_in),
})
}
+1 -1
View File
@@ -50,7 +50,7 @@ pub(crate) async fn report_room_route(
if !services
.rooms
.state_cache
.server_in_room(&services.server.config.server_name, &body.room_id)
.server_in_room(&services.server.name, &body.room_id)
.await
{
return Err!(Request(NotFound(
+1 -1
View File
@@ -71,7 +71,7 @@ pub(crate) async fn create_room_route(
let room_id: OwnedRoomId = if let Some(custom_room_id) = &body.room_id {
custom_room_id_check(&services, custom_room_id)?
} else {
RoomId::new(&services.server.config.server_name)
RoomId::new(&services.server.name)
};
// check if room ID doesn't already exist instead of erroring on auth check
+4 -5
View File
@@ -2,7 +2,7 @@ use axum::extract::State;
use conduwuit::{
at,
utils::{stream::TryTools, BoolExt},
Err, Result,
Err, PduEvent, Result,
};
use futures::TryStreamExt;
use ruma::api::client::room::initial_sync::v3::{PaginationChunk, Request, Response};
@@ -39,10 +39,9 @@ pub(crate) async fn room_initial_sync_route(
.rooms
.state_accessor
.room_state_full_pdus(room_id)
.await?
.into_iter()
.map(|pdu| pdu.to_state_event())
.collect();
.map_ok(PduEvent::into_state_event)
.try_collect()
.await?;
let messages = PaginationChunk {
start: events.last().map(at!(0)).as_ref().map(ToString::to_string),
+6 -6
View File
@@ -7,7 +7,7 @@ use conduwuit::{
utils::{stream::ReadyExt, IterStream},
Err, PduEvent, Result,
};
use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt};
use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{
api::client::search::search_events::{
self,
@@ -181,15 +181,15 @@ async fn category_room_events(
}
async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result<RoomState> {
let state_map = services
let state = services
.rooms
.state_accessor
.room_state_full(room_id)
.room_state_full_pdus(room_id)
.map_ok(PduEvent::into_state_event)
.try_collect()
.await?;
let state_events = state_map.values().map(PduEvent::to_state_event).collect();
Ok(state_events)
Ok(state)
}
async fn check_room_visible(
+6 -6
View File
@@ -1,5 +1,6 @@
use axum::extract::State;
use conduwuit::{err, pdu::PduBuilder, utils::BoolExt, Err, PduEvent, Result};
use futures::TryStreamExt;
use ruma::{
api::client::state::{get_state_events, get_state_events_for_key, send_state_event},
events::{
@@ -82,11 +83,10 @@ pub(crate) async fn get_state_events_route(
room_state: services
.rooms
.state_accessor
.room_state_full(&body.room_id)
.await?
.values()
.map(PduEvent::to_state_event)
.collect(),
.room_state_full_pdus(&body.room_id)
.map_ok(PduEvent::into_state_event)
.try_collect()
.await?,
})
}
@@ -133,7 +133,7 @@ pub(crate) async fn get_state_events_for_key_route(
Ok(get_state_events_for_key::v3::Response {
content: event_format.or(|| event.get_content_as_value()),
event: event_format.then(|| event.to_state_event_value()),
event: event_format.then(|| event.into_state_event_value()),
})
}
+27 -21
View File
@@ -28,7 +28,7 @@ use conduwuit_service::{
};
use futures::{
future::{join, join3, join4, join5, try_join, try_join4, OptionFuture},
FutureExt, StreamExt, TryFutureExt,
FutureExt, StreamExt, TryFutureExt, TryStreamExt,
};
use ruma::{
api::client::{
@@ -503,16 +503,20 @@ async fn handle_left_room(
let mut left_state_events = Vec::new();
let since_shortstatehash = services
.rooms
.user
.get_token_shortstatehash(room_id, since)
.await;
let since_shortstatehash = services.rooms.user.get_token_shortstatehash(room_id, since);
let since_state_ids = match since_shortstatehash {
| Ok(s) => services.rooms.state_accessor.state_full_ids(s).await?,
| Err(_) => HashMap::new(),
};
let since_state_ids: HashMap<_, OwnedEventId> = since_shortstatehash
.map_ok(|since_shortstatehash| {
services
.rooms
.state_accessor
.state_full_ids(since_shortstatehash)
.map(Ok)
})
.try_flatten_stream()
.try_collect()
.await
.unwrap_or_default();
let Ok(left_event_id): Result<OwnedEventId> = services
.rooms
@@ -534,11 +538,12 @@ async fn handle_left_room(
return Ok(None);
};
let mut left_state_ids = services
let mut left_state_ids: HashMap<_, _> = services
.rooms
.state_accessor
.state_full_ids(left_shortstatehash)
.await?;
.collect()
.await;
let leave_shortstatekey = services
.rooms
@@ -960,19 +965,18 @@ async fn calculate_state_initial(
current_shortstatehash: ShortStateHash,
witness: Option<&Witness>,
) -> Result<StateChanges> {
let state_events = services
let (shortstatekeys, event_ids): (Vec<_>, Vec<_>) = services
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let shortstatekeys = state_events.keys().copied().stream();
.unzip()
.await;
let state_events = services
.rooms
.short
.multi_get_statekey_from_short(shortstatekeys)
.zip(state_events.values().cloned().stream())
.multi_get_statekey_from_short(shortstatekeys.into_iter().stream())
.zip(event_ids.into_iter().stream())
.ready_filter_map(|item| Some((item.0.ok()?, item.1)))
.ready_filter_map(|((event_type, state_key), event_id)| {
let lazy_load_enabled = filter.room.state.lazy_load_options.is_enabled()
@@ -1036,17 +1040,19 @@ async fn calculate_state_incremental(
let current_state_ids = services
.rooms
.state_accessor
.state_full_ids(current_shortstatehash);
.state_full_ids(current_shortstatehash)
.collect();
let since_state_ids = services
.rooms
.state_accessor
.state_full_ids(since_shortstatehash);
.state_full_ids(since_shortstatehash)
.collect();
let (current_state_ids, since_state_ids): (
HashMap<_, OwnedEventId>,
HashMap<_, OwnedEventId>,
) = try_join(current_state_ids, since_state_ids).await?;
) = join(current_state_ids, since_state_ids).await;
current_state_ids
.iter()
+5 -3
View File
@@ -241,13 +241,15 @@ pub(crate) async fn sync_events_v4_route(
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
.collect()
.await;
let since_state_ids = services
let since_state_ids: HashMap<_, _> = services
.rooms
.state_accessor
.state_full_ids(since_shortstatehash)
.await?;
.collect()
.await;
for (key, id) in current_state_ids {
if since_state_ids.get(&key) != Some(&id) {
+5 -3
View File
@@ -748,13 +748,15 @@ async fn collect_e2ee<'a>(
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
.collect()
.await;
let since_state_ids = services
let since_state_ids: HashMap<_, _> = services
.rooms
.state_accessor
.state_full_ids(since_shortstatehash)
.await?;
.collect()
.await;
for (key, id) in current_state_ids {
if since_state_ids.get(&key) != Some(&id) {
+16 -12
View File
@@ -10,6 +10,7 @@ use ruma::{
},
to_device::DeviceIdOrAllDevices,
};
use service::sending::EduBuf;
use crate::Ruma;
@@ -42,18 +43,21 @@ pub(crate) async fn send_event_to_device_route(
messages.insert(target_user_id.clone(), map);
let count = services.globals.next_count()?;
services.sending.send_edu_server(
target_user_id.server_name(),
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
DirectDeviceContent {
sender: sender_user.clone(),
ev_type: body.event_type.clone(),
message_id: count.to_string().into(),
messages,
},
))
.expect("DirectToDevice EDU can be serialized"),
)?;
let mut buf = EduBuf::new();
serde_json::to_writer(
&mut buf,
&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
sender: sender_user.clone(),
ev_type: body.event_type.clone(),
message_id: count.to_string().into(),
messages,
}),
)
.expect("DirectToDevice EDU can be serialized");
services
.sending
.send_edu_server(target_user_id.server_name(), buf)?;
continue;
}
+1 -1
View File
@@ -38,7 +38,7 @@ pub(crate) async fn turn_server_route(
let user = body.sender_user.unwrap_or_else(|| {
UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
&services.server.config.server_name,
&services.server.name,
)
.unwrap()
});
+2 -2
View File
@@ -13,7 +13,7 @@ use crate::{Error, Result, Ruma};
/// # `POST /_matrix/federation/v1/publicRooms`
///
/// Lists the public rooms on this server.
#[tracing::instrument(skip_all, fields(%client), name = "publicrooms")]
#[tracing::instrument(name = "publicrooms", level = "debug", skip_all, fields(%client))]
pub(crate) async fn get_public_rooms_filtered_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
@@ -51,7 +51,7 @@ pub(crate) async fn get_public_rooms_filtered_route(
/// # `GET /_matrix/federation/v1/publicRooms`
///
/// Lists the public rooms on this server.
#[tracing::instrument(skip_all, fields(%client), "publicrooms")]
#[tracing::instrument(name = "publicrooms", level = "debug", skip_all, fields(%client))]
pub(crate) async fn get_public_rooms_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
+363 -261
View File
@@ -3,17 +3,27 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduwuit::{
debug, debug_warn, err, error, result::LogErr, trace, utils::ReadyExt, warn, Err, Error,
Result,
debug,
debug::INFO_SPAN_LEVEL,
debug_warn, err, error,
result::LogErr,
trace,
utils::{
stream::{automatic_width, BroadbandExt, TryBroadbandExt},
IterStream, ReadyExt,
},
warn, Err, Error, Result,
};
use futures::{FutureExt, StreamExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use itertools::Itertools;
use ruma::{
api::{
client::error::ErrorKind,
federation::transactions::{
edu::{
DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent,
ReceiptContent, SigningKeyUpdateContent, TypingContent,
PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, SigningKeyUpdateContent,
TypingContent,
},
send_transaction_message,
},
@@ -21,27 +31,28 @@ use ruma::{
events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
serde::Raw,
to_device::DeviceIdOrAllDevices,
OwnedEventId, ServerName,
CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName, UserId,
};
use serde_json::value::RawValue as RawJsonValue;
use service::{
sending::{EDU_LIMIT, PDU_LIMIT},
Services,
};
use utils::millis_since_unix_epoch;
use crate::{
utils::{self},
Ruma,
};
type ResolvedMap = BTreeMap<OwnedEventId, Result<()>>;
type ResolvedMap = BTreeMap<OwnedEventId, Result>;
type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.
#[tracing::instrument(
name = "send",
level = "debug",
name = "txn",
level = INFO_SPAN_LEVEL,
skip_all,
fields(
%client,
@@ -73,91 +84,41 @@ pub(crate) async fn send_transaction_message_route(
let txn_start_time = Instant::now();
trace!(
pdus = ?body.pdus.len(),
edus = ?body.edus.len(),
pdus = body.pdus.len(),
edus = body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = ?body.transaction_id,
origin =?body.origin(),
"Starting txn",
);
let resolved_map =
handle_pdus(&services, &client, &body.pdus, body.origin(), &txn_start_time)
.boxed()
.await?;
let pdus = body
.pdus
.iter()
.stream()
.broad_then(|pdu| services.rooms.event_handler.parse_incoming_pdu(pdu))
.inspect_err(|e| debug_warn!("Could not parse PDU: {e}"))
.ready_filter_map(Result::ok);
handle_edus(&services, &client, &body.edus, body.origin())
.boxed()
.await;
let edus = body
.edus
.iter()
.map(|edu| edu.json().get())
.map(serde_json::from_str)
.filter_map(Result::ok)
.stream();
let results = handle(&services, &client, body.origin(), txn_start_time, pdus, edus).await?;
debug!(
pdus = ?body.pdus.len(),
edus = ?body.edus.len(),
pdus = body.pdus.len(),
edus = body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = ?body.transaction_id,
origin =?body.origin(),
"Finished txn",
);
Ok(send_transaction_message::v1::Response {
pdus: resolved_map
.into_iter()
.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
.collect(),
})
}
async fn handle_pdus(
services: &Services,
_client: &IpAddr,
pdus: &[Box<RawJsonValue>],
origin: &ServerName,
txn_start_time: &Instant,
) -> Result<ResolvedMap> {
let mut parsed_pdus = Vec::with_capacity(pdus.len());
for pdu in pdus {
parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await {
| Ok(t) => t,
| Err(e) => {
debug_warn!("Could not parse PDU: {e}");
continue;
},
});
// We do not add the event_id field to the pdu here because of signature
// and hashes checks
}
let mut resolved_map = BTreeMap::new();
for (event_id, value, room_id) in parsed_pdus {
services.server.check_running()?;
let pdu_start_time = Instant::now();
let mutex_lock = services
.rooms
.event_handler
.mutex_federation
.lock(&room_id)
.await;
let result = services
.rooms
.event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, true)
.boxed()
.await
.map(|_| ());
drop(mutex_lock);
debug!(
pdu_elapsed = ?pdu_start_time.elapsed(),
txn_elapsed = ?txn_start_time.elapsed(),
"Finished PDU {event_id}",
);
resolved_map.insert(event_id, result);
}
for (id, result) in &resolved_map {
for (id, result) in &results {
if let Err(e) = result {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!("Incoming PDU failed {id}: {e:?}");
@@ -165,39 +126,117 @@ async fn handle_pdus(
}
}
Ok(resolved_map)
Ok(send_transaction_message::v1::Response {
pdus: results
.into_iter()
.map(|(e, r)| (e, r.map_err(error::sanitized_message)))
.collect(),
})
}
async fn handle_edus(
async fn handle(
services: &Services,
client: &IpAddr,
edus: &[Raw<Edu>],
origin: &ServerName,
) {
for edu in edus
.iter()
.filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok())
{
match edu {
| Edu::Presence(presence) => {
handle_edu_presence(services, client, origin, presence).await;
},
| Edu::Receipt(receipt) =>
handle_edu_receipt(services, client, origin, receipt).await,
| Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await,
| Edu::DeviceListUpdate(content) => {
handle_edu_device_list_update(services, client, origin, content).await;
},
| Edu::DirectToDevice(content) => {
handle_edu_direct_to_device(services, client, origin, content).await;
},
| Edu::SigningKeyUpdate(content) => {
handle_edu_signing_key_update(services, client, origin, content).await;
},
| Edu::_Custom(ref _custom) => {
debug_warn!(?edus, "received custom/unknown EDU");
},
}
started: Instant,
pdus: impl Stream<Item = Pdu> + Send,
edus: impl Stream<Item = Edu> + Send,
) -> Result<ResolvedMap> {
// group pdus by room
let pdus = pdus
.collect()
.map(|mut pdus: Vec<_>| {
pdus.sort_by(|(room_a, ..), (room_b, ..)| room_a.cmp(room_b));
pdus.into_iter()
.into_grouping_map_by(|(room_id, ..)| room_id.clone())
.collect()
})
.await;
// we can evaluate rooms concurrently
let results: ResolvedMap = pdus
.into_iter()
.try_stream()
.broad_and_then(|(room_id, pdus): (_, Vec<_>)| {
handle_room(services, client, origin, started, room_id, pdus.into_iter())
.map_ok(Vec::into_iter)
.map_ok(IterStream::try_stream)
})
.try_flatten()
.try_collect()
.boxed()
.await?;
// evaluate edus after pdus, at least for now.
edus.for_each_concurrent(automatic_width(), |edu| handle_edu(services, client, origin, edu))
.boxed()
.await;
Ok(results)
}
async fn handle_room(
services: &Services,
_client: &IpAddr,
origin: &ServerName,
txn_start_time: Instant,
room_id: OwnedRoomId,
pdus: impl Iterator<Item = Pdu> + Send,
) -> Result<Vec<(OwnedEventId, Result)>> {
let _room_lock = services
.rooms
.event_handler
.mutex_federation
.lock(&room_id)
.await;
let room_id = &room_id;
pdus.try_stream()
.and_then(|(_, event_id, value)| async move {
services.server.check_running()?;
let pdu_start_time = Instant::now();
let result = services
.rooms
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.await
.map(|_| ());
debug!(
pdu_elapsed = ?pdu_start_time.elapsed(),
txn_elapsed = ?txn_start_time.elapsed(),
"Finished PDU {event_id}",
);
Ok((event_id, result))
})
.try_collect()
.await
}
async fn handle_edu(services: &Services, client: &IpAddr, origin: &ServerName, edu: Edu) {
match edu {
| Edu::Presence(presence) if services.server.config.allow_incoming_presence =>
handle_edu_presence(services, client, origin, presence).await,
| Edu::Receipt(receipt) if services.server.config.allow_incoming_read_receipts =>
handle_edu_receipt(services, client, origin, receipt).await,
| Edu::Typing(typing) if services.server.config.allow_incoming_typing =>
handle_edu_typing(services, client, origin, typing).await,
| Edu::DeviceListUpdate(content) =>
handle_edu_device_list_update(services, client, origin, content).await,
| Edu::DirectToDevice(content) =>
handle_edu_direct_to_device(services, client, origin, content).await,
| Edu::SigningKeyUpdate(content) =>
handle_edu_signing_key_update(services, client, origin, content).await,
| Edu::_Custom(ref _custom) => debug_warn!(?edu, "received custom/unknown EDU"),
| _ => trace!(?edu, "skipped"),
}
}
@@ -207,32 +246,41 @@ async fn handle_edu_presence(
origin: &ServerName,
presence: PresenceContent,
) {
if !services.globals.allow_incoming_presence() {
presence
.push
.into_iter()
.stream()
.for_each_concurrent(automatic_width(), |update| {
handle_edu_presence_update(services, origin, update)
})
.await;
}
async fn handle_edu_presence_update(
services: &Services,
origin: &ServerName,
update: PresenceUpdate,
) {
if update.user_id.server_name() != origin {
debug_warn!(
%update.user_id, %origin,
"received presence EDU for user not belonging to origin"
);
return;
}
for update in presence.push {
if update.user_id.server_name() != origin {
debug_warn!(
%update.user_id, %origin,
"received presence EDU for user not belonging to origin"
);
continue;
}
services
.presence
.set_presence(
&update.user_id,
&update.presence,
Some(update.currently_active),
Some(update.last_active_ago),
update.status_msg.clone(),
)
.await
.log_err()
.ok();
}
services
.presence
.set_presence(
&update.user_id,
&update.presence,
Some(update.currently_active),
Some(update.last_active_ago),
update.status_msg.clone(),
)
.await
.log_err()
.ok();
}
async fn handle_edu_receipt(
@@ -241,66 +289,94 @@ async fn handle_edu_receipt(
origin: &ServerName,
receipt: ReceiptContent,
) {
if !services.globals.allow_incoming_read_receipts() {
receipt
.receipts
.into_iter()
.stream()
.for_each_concurrent(automatic_width(), |(room_id, room_updates)| {
handle_edu_receipt_room(services, origin, room_id, room_updates)
})
.await;
}
async fn handle_edu_receipt_room(
services: &Services,
origin: &ServerName,
room_id: OwnedRoomId,
room_updates: ReceiptMap,
) {
if services
.rooms
.event_handler
.acl_check(origin, &room_id)
.await
.is_err()
{
debug_warn!(
%origin, %room_id,
"received read receipt EDU from ACL'd server"
);
return;
}
for (room_id, room_updates) in receipt.receipts {
if services
.rooms
.event_handler
.acl_check(origin, &room_id)
.await
.is_err()
{
debug_warn!(
%origin, %room_id,
"received read receipt EDU from ACL'd server"
);
continue;
}
let room_id = &room_id;
room_updates
.read
.into_iter()
.stream()
.for_each_concurrent(automatic_width(), |(user_id, user_updates)| async move {
handle_edu_receipt_room_user(services, origin, room_id, &user_id, user_updates).await;
})
.await;
}
for (user_id, user_updates) in room_updates.read {
if user_id.server_name() != origin {
debug_warn!(
%user_id, %origin,
"received read receipt EDU for user not belonging to origin"
);
continue;
}
if services
.rooms
.state_cache
.room_members(&room_id)
.ready_any(|member| member.server_name() == user_id.server_name())
.await
{
for event_id in &user_updates.event_ids {
let user_receipts =
BTreeMap::from([(user_id.clone(), user_updates.data.clone())]);
let receipts = BTreeMap::from([(ReceiptType::Read, user_receipts)]);
let receipt_content = BTreeMap::from([(event_id.to_owned(), receipts)]);
let event = ReceiptEvent {
content: ReceiptEventContent(receipt_content),
room_id: room_id.clone(),
};
services
.rooms
.read_receipt
.readreceipt_update(&user_id, &room_id, &event)
.await;
}
} else {
debug_warn!(
%user_id, %room_id, %origin,
"received read receipt EDU from server who does not have a member in the room",
);
continue;
}
}
async fn handle_edu_receipt_room_user(
services: &Services,
origin: &ServerName,
room_id: &RoomId,
user_id: &UserId,
user_updates: ReceiptData,
) {
if user_id.server_name() != origin {
debug_warn!(
%user_id, %origin,
"received read receipt EDU for user not belonging to origin"
);
return;
}
if !services
.rooms
.state_cache
.server_in_room(origin, room_id)
.await
{
debug_warn!(
%user_id, %room_id, %origin,
"received read receipt EDU from server who does not have a member in the room",
);
return;
}
let data = &user_updates.data;
user_updates
.event_ids
.into_iter()
.stream()
.for_each_concurrent(automatic_width(), |event_id| async move {
let user_data = [(user_id.to_owned(), data.clone())];
let receipts = [(ReceiptType::Read, BTreeMap::from(user_data))];
let content = [(event_id.clone(), BTreeMap::from(receipts))];
services
.rooms
.read_receipt
.readreceipt_update(user_id, room_id, &ReceiptEvent {
content: ReceiptEventContent(content.into()),
room_id: room_id.to_owned(),
})
.await;
})
.await;
}
async fn handle_edu_typing(
@@ -309,10 +385,6 @@ async fn handle_edu_typing(
origin: &ServerName,
typing: TypingContent,
) {
if !services.server.config.allow_incoming_typing {
return;
}
if typing.user_id.server_name() != origin {
debug_warn!(
%typing.user_id, %origin,
@@ -335,41 +407,38 @@ async fn handle_edu_typing(
return;
}
if services
if !services
.rooms
.state_cache
.is_joined(&typing.user_id, &typing.room_id)
.await
{
if typing.typing {
let timeout = utils::millis_since_unix_epoch().saturating_add(
services
.server
.config
.typing_federation_timeout_s
.saturating_mul(1000),
);
services
.rooms
.typing
.typing_add(&typing.user_id, &typing.room_id, timeout)
.await
.log_err()
.ok();
} else {
services
.rooms
.typing
.typing_remove(&typing.user_id, &typing.room_id)
.await
.log_err()
.ok();
}
} else {
debug_warn!(
%typing.user_id, %typing.room_id, %origin,
"received typing EDU for user not in room"
);
return;
}
if typing.typing {
let secs = services.server.config.typing_federation_timeout_s;
let timeout = millis_since_unix_epoch().saturating_add(secs.saturating_mul(1000));
services
.rooms
.typing
.typing_add(&typing.user_id, &typing.room_id, timeout)
.await
.log_err()
.ok();
} else {
services
.rooms
.typing
.typing_remove(&typing.user_id, &typing.room_id)
.await
.log_err()
.ok();
}
}
@@ -398,7 +467,12 @@ async fn handle_edu_direct_to_device(
origin: &ServerName,
content: DirectDeviceContent,
) {
let DirectDeviceContent { sender, ev_type, message_id, messages } = content;
let DirectDeviceContent {
ref sender,
ref ev_type,
ref message_id,
messages,
} = content;
if sender.server_name() != origin {
debug_warn!(
@@ -411,60 +485,88 @@ async fn handle_edu_direct_to_device(
// Check if this is a new transaction id
if services
.transaction_ids
.existing_txnid(&sender, None, &message_id)
.existing_txnid(sender, None, message_id)
.await
.is_ok()
{
return;
}
for (target_user_id, map) in &messages {
for (target_device_id_maybe, event) in map {
let Ok(event) = event.deserialize_as().map_err(|e| {
err!(Request(InvalidParam(error!("To-Device event is invalid: {e}"))))
}) else {
continue;
};
let ev_type = ev_type.to_string();
match target_device_id_maybe {
| DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services
.users
.add_to_device_event(
&sender,
target_user_id,
target_device_id,
&ev_type,
event,
)
.await;
},
| DeviceIdOrAllDevices::AllDevices => {
let (sender, ev_type, event) = (&sender, &ev_type, &event);
services
.users
.all_device_ids(target_user_id)
.for_each(|target_device_id| {
services.users.add_to_device_event(
sender,
target_user_id,
target_device_id,
ev_type,
event.clone(),
)
})
.await;
},
}
}
}
// process messages concurrently for different users
let ev_type = ev_type.to_string();
messages
.into_iter()
.stream()
.for_each_concurrent(automatic_width(), |(target_user_id, map)| {
handle_edu_direct_to_device_user(services, target_user_id, sender, &ev_type, map)
})
.await;
// Save transaction id with empty data
services
.transaction_ids
.add_txnid(&sender, None, &message_id, &[]);
.add_txnid(sender, None, message_id, &[]);
}
async fn handle_edu_direct_to_device_user<Event: Send + Sync>(
services: &Services,
target_user_id: OwnedUserId,
sender: &UserId,
ev_type: &str,
map: BTreeMap<DeviceIdOrAllDevices, Raw<Event>>,
) {
for (target_device_id_maybe, event) in map {
let Ok(event) = event
.deserialize_as()
.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))
else {
continue;
};
handle_edu_direct_to_device_event(
services,
&target_user_id,
sender,
target_device_id_maybe,
ev_type,
event,
)
.await;
}
}
async fn handle_edu_direct_to_device_event(
services: &Services,
target_user_id: &UserId,
sender: &UserId,
target_device_id_maybe: DeviceIdOrAllDevices,
ev_type: &str,
event: serde_json::Value,
) {
match target_device_id_maybe {
| DeviceIdOrAllDevices::DeviceId(ref target_device_id) => {
services
.users
.add_to_device_event(sender, target_user_id, target_device_id, ev_type, event)
.await;
},
| DeviceIdOrAllDevices::AllDevices => {
services
.users
.all_device_ids(target_user_id)
.for_each(|target_device_id| {
services.users.add_to_device_event(
sender,
target_user_id,
target_device_id,
ev_type,
event.clone(),
)
})
.await;
},
}
}
async fn handle_edu_signing_key_update(
+8 -6
View File
@@ -1,10 +1,10 @@
#![allow(deprecated)]
use std::{borrow::Borrow, collections::HashMap};
use std::borrow::Borrow;
use axum::extract::State;
use conduwuit::{
err,
at, err,
pdu::gen_event_id_canonical_json,
utils::stream::{IterStream, TryBroadbandExt},
warn, Err, Result,
@@ -211,14 +211,16 @@ async fn create_join_event(
drop(mutex_lock);
let state_ids: HashMap<_, OwnedEventId> = services
let state_ids: Vec<OwnedEventId> = services
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?;
.map(at!(1))
.collect()
.await;
let state = state_ids
.values()
.iter()
.try_stream()
.broad_and_then(|event_id| services.rooms.timeline.get_pdu_json(event_id))
.broad_and_then(|pdu| {
@@ -231,7 +233,7 @@ async fn create_join_event(
.boxed()
.await?;
let starting_events = state_ids.values().map(Borrow::borrow);
let starting_events = state_ids.iter().map(Borrow::borrow);
let auth_chain = services
.rooms
.auth_chain
+4 -6
View File
@@ -1,7 +1,7 @@
use std::{borrow::Borrow, iter::once};
use axum::extract::State;
use conduwuit::{err, result::LogErr, utils::IterStream, Result};
use conduwuit::{at, err, utils::IterStream, Result};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{api::federation::event::get_room_state, OwnedEventId};
@@ -35,11 +35,9 @@ pub(crate) async fn get_room_state_route(
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await
.log_err()
.map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))?
.into_values()
.collect();
.map(at!(1))
.collect()
.await;
let pdus = state_ids
.iter()
+4 -5
View File
@@ -1,7 +1,7 @@
use std::{borrow::Borrow, iter::once};
use axum::extract::State;
use conduwuit::{err, Result};
use conduwuit::{at, err, Result};
use futures::StreamExt;
use ruma::{api::federation::event::get_room_state_ids, OwnedEventId};
@@ -36,10 +36,9 @@ pub(crate) async fn get_room_state_ids_route(
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await
.map_err(|_| err!(Request(NotFound("State ids not found"))))?
.into_values()
.collect();
.map(at!(1))
.collect()
.await;
let auth_chain_ids = services
.rooms
+19 -12
View File
@@ -8,6 +8,7 @@ use std::{
};
use arrayvec::ArrayVec;
use const_str::concat_bytes;
use tikv_jemalloc_ctl as mallctl;
use tikv_jemalloc_sys as ffi;
use tikv_jemallocator as jemalloc;
@@ -20,18 +21,24 @@ use crate::{
#[cfg(feature = "jemalloc_conf")]
#[unsafe(no_mangle)]
pub static malloc_conf: &[u8] = b"\
metadata_thp:always\
,percpu_arena:percpu\
,background_thread:true\
,max_background_threads:-1\
,lg_extent_max_active_fit:4\
,oversize_threshold:16777216\
,tcache_max:2097152\
,dirty_decay_ms:16000\
,muzzy_decay_ms:144000\
,prof_active:false\
\0";
pub static malloc_conf: &[u8] = concat_bytes!(
"lg_extent_max_active_fit:4",
",oversize_threshold:16777216",
",tcache_max:2097152",
",dirty_decay_ms:16000",
",muzzy_decay_ms:144000",
",percpu_arena:percpu",
",metadata_thp:always",
",background_thread:true",
",max_background_threads:-1",
MALLOC_CONF_PROF,
0
);
#[cfg(all(feature = "jemalloc_conf", feature = "jemalloc_prof"))]
const MALLOC_CONF_PROF: &str = ",prof_active:false";
#[cfg(all(feature = "jemalloc_conf", not(feature = "jemalloc_prof")))]
const MALLOC_CONF_PROF: &str = "";
#[global_allocator]
static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc;
+17 -1
View File
@@ -6,8 +6,24 @@ use figment::Figment;
use super::DEPRECATED_KEYS;
use crate::{debug, debug_info, debug_warn, error, warn, Config, Err, Result, Server};
/// Performs check() with additional checks specific to reloading old config
/// with new config.
pub fn reload(old: &Config, new: &Config) -> Result {
check(new)?;
if new.server_name != old.server_name {
return Err!(Config(
"server_name",
"You can't change the server's name from {:?}.",
old.server_name
));
}
Ok(())
}
#[allow(clippy::cognitive_complexity)]
pub fn check(config: &Config) -> Result<()> {
pub fn check(config: &Config) -> Result {
if cfg!(debug_assertions) {
warn!("Note: conduwuit was built without optimisations (i.e. debug build)");
}
+25
View File
@@ -1049,6 +1049,15 @@ pub struct Config {
#[serde(default)]
pub rocksdb_paranoid_file_checks: bool,
/// Enables or disables checksum verification in rocksdb at runtime.
/// Checksums are usually hardware accelerated with low overhead; they are
/// enabled in rocksdb by default. Older or slower platforms may see gains
/// from disabling.
///
/// default: true
#[serde(default = "true_fn")]
pub rocksdb_checksums: bool,
/// Database repair mode (for RocksDB SST corruption).
///
/// Use this option when the server reports corruption while running or
@@ -1545,6 +1554,15 @@ pub struct Config {
#[serde(default)]
pub admin_execute_errors_ignore: bool,
/// List of admin commands to execute on SIGUSR2.
///
/// Similar to admin_execute, but these commands are executed when the
/// server receives SIGUSR2 on supporting platforms.
///
/// default: []
#[serde(default)]
pub admin_signal_execute: Vec<String>,
/// Controls the max log level for admin command log captures (logs
/// generated from running admin commands). Defaults to "info" on release
/// builds, else "debug" on debug builds.
@@ -1733,6 +1751,13 @@ pub struct Config {
#[serde(default = "true_fn")]
pub listening: bool,
/// Enables configuration reload when the server receives SIGUSR1 on
/// supporting platforms.
///
/// default: true
#[serde(default = "true_fn")]
pub config_reload_signal: bool,
#[serde(flatten)]
#[allow(clippy::zero_sized_map_values)]
// this is a catchall, the map shouldn't be zero at runtime
+7
View File
@@ -4,6 +4,7 @@ use std::{any::Any, panic};
// Export debug proc_macros
pub use conduwuit_macros::recursion_depth;
use tracing::Level;
// Export all of the ancillary tools from here as well.
pub use crate::{result::DebugInspect, utils::debug::*};
@@ -51,6 +52,12 @@ macro_rules! debug_info {
}
}
pub const INFO_SPAN_LEVEL: Level = if cfg!(debug_assertions) {
Level::INFO
} else {
Level::DEBUG
};
pub fn set_panic_trap() {
let next = panic::take_hook();
panic::set_hook(Box::new(move |info| {
+1
View File
@@ -106,6 +106,7 @@ pub(super) fn io_error_code(kind: std::io::ErrorKind) -> StatusCode {
| ErrorKind::TimedOut => StatusCode::GATEWAY_TIMEOUT,
| ErrorKind::FileTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
| ErrorKind::StorageFull => StatusCode::INSUFFICIENT_STORAGE,
| ErrorKind::Interrupted => StatusCode::SERVICE_UNAVAILABLE,
| _ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
+6 -6
View File
@@ -116,7 +116,7 @@ pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> {
#[must_use]
#[implement(super::Pdu)]
pub fn to_state_event_value(&self) -> JsonValue {
pub fn into_state_event_value(self) -> JsonValue {
let mut json = json!({
"content": self.content,
"type": self.kind,
@@ -127,7 +127,7 @@ pub fn to_state_event_value(&self) -> JsonValue {
"state_key": self.state_key,
});
if let Some(unsigned) = &self.unsigned {
if let Some(unsigned) = self.unsigned {
json["unsigned"] = json!(unsigned);
}
@@ -136,8 +136,8 @@ pub fn to_state_event_value(&self) -> JsonValue {
#[must_use]
#[implement(super::Pdu)]
pub fn to_state_event(&self) -> Raw<AnyStateEvent> {
serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works")
pub fn into_state_event(self) -> Raw<AnyStateEvent> {
serde_json::from_value(self.into_state_event_value()).expect("Raw::from_value always works")
}
#[must_use]
@@ -188,7 +188,7 @@ pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent
#[must_use]
#[implement(super::Pdu)]
pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
pub fn into_member_event(self) -> Raw<StateEvent<RoomMemberEventContent>> {
let mut json = json!({
"content": self.content,
"type": self.kind,
@@ -200,7 +200,7 @@ pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
"state_key": self.state_key,
});
if let Some(unsigned) = &self.unsigned {
if let Some(unsigned) = self.unsigned {
json["unsigned"] = json!(unsigned);
}
+12 -3
View File
@@ -6,12 +6,17 @@ use std::{
time::SystemTime,
};
use ruma::OwnedServerName;
use tokio::{runtime, sync::broadcast};
use crate::{config, config::Config, err, log::Log, metrics::Metrics, Err, Result};
use crate::{config, config::Config, log::Log, metrics::Metrics, Err, Result};
/// Server runtime state; public portion
pub struct Server {
/// Configured name of server. This is the same as the one in the config
/// but developers can (and should) reference this string instead.
pub name: OwnedServerName,
/// Server-wide configuration instance
pub config: config::Manager,
@@ -46,6 +51,7 @@ impl Server {
#[must_use]
pub fn new(config: Config, runtime: Option<runtime::Handle>, log: Log) -> Self {
Self {
name: config.server_name.clone(),
config: config::Manager::new(config),
started: SystemTime::now(),
stopping: AtomicBool::new(false),
@@ -106,7 +112,7 @@ impl Server {
}
#[inline]
pub async fn until_shutdown(self: Arc<Self>) {
pub async fn until_shutdown(self: &Arc<Self>) {
while self.running() {
self.signal.subscribe().recv().await.ok();
}
@@ -121,9 +127,12 @@ impl Server {
#[inline]
pub fn check_running(&self) -> Result {
use std::{io, io::ErrorKind::Interrupted};
self.running()
.then_some(())
.ok_or_else(|| err!(debug_warn!("Server is shutting down.")))
.ok_or_else(|| io::Error::new(Interrupted, "Server shutting down"))
.map_err(Into::into)
}
#[inline]
+28
View File
@@ -35,6 +35,13 @@ where
Fut: Future<Output = Option<U>> + Send,
U: Send;
fn broadn_flat_map<F, Fut, U, N>(self, n: N, f: F) -> impl Stream<Item = U> + Send
where
N: Into<Option<usize>>,
F: Fn(Item) -> Fut + Send,
Fut: Stream<Item = U> + Send + Unpin,
U: Send;
fn broadn_then<F, Fut, U, N>(self, n: N, f: F) -> impl Stream<Item = U> + Send
where
N: Into<Option<usize>>,
@@ -70,6 +77,16 @@ where
self.broadn_filter_map(None, f)
}
#[inline]
fn broad_flat_map<F, Fut, U>(self, f: F) -> impl Stream<Item = U> + Send
where
F: Fn(Item) -> Fut + Send,
Fut: Stream<Item = U> + Send + Unpin,
U: Send,
{
self.broadn_flat_map(None, f)
}
#[inline]
fn broad_then<F, Fut, U>(self, f: F) -> impl Stream<Item = U> + Send
where
@@ -122,6 +139,17 @@ where
.ready_filter_map(identity)
}
#[inline]
fn broadn_flat_map<F, Fut, U, N>(self, n: N, f: F) -> impl Stream<Item = U> + Send
where
N: Into<Option<usize>>,
F: Fn(Item) -> Fut + Send,
Fut: Stream<Item = U> + Send + Unpin,
U: Send,
{
self.flat_map_unordered(n.into().unwrap_or_else(automatic_width), f)
}
#[inline]
fn broadn_then<F, Fut, U, N>(self, n: N, f: F) -> impl Stream<Item = U> + Send
where
+1
View File
@@ -32,6 +32,7 @@ use crate::{
pub struct Engine {
pub(super) read_only: bool,
pub(super) secondary: bool,
pub(crate) checksums: bool,
corks: AtomicU32,
pub(crate) db: Db,
pub(crate) pool: Arc<Pool>,
+7
View File
@@ -72,6 +72,13 @@ fn descriptor_cf_options(
opts.set_options_from_string("{{arena_block_size=2097152;}}")
.map_err(map_err)?;
#[cfg(debug_assertions)]
opts.set_options_from_string(
"{{paranoid_checks=true;paranoid_file_checks=true;force_consistency_checks=true;\
verify_sst_unique_id_in_manifest=true;}}",
)
.map_err(map_err)?;
Ok(opts)
}
+6 -6
View File
@@ -83,7 +83,7 @@ pub(crate) static RANDOM: Descriptor = Descriptor {
write_size: 1024 * 1024 * 32,
cache_shards: 128,
compression_level: -3,
bottommost_level: Some(4),
bottommost_level: Some(-1),
compressed_index: true,
..BASE
};
@@ -94,8 +94,8 @@ pub(crate) static SEQUENTIAL: Descriptor = Descriptor {
level_size: 1024 * 1024 * 32,
file_size: 1024 * 1024 * 2,
cache_shards: 128,
compression_level: -1,
bottommost_level: Some(6),
compression_level: -2,
bottommost_level: Some(-1),
compression_shape: [0, 0, 1, 1, 1, 1, 1],
compressed_index: false,
..BASE
@@ -111,7 +111,7 @@ pub(crate) static RANDOM_SMALL: Descriptor = Descriptor {
block_size: 512,
cache_shards: 64,
compression_level: -4,
bottommost_level: Some(1),
bottommost_level: Some(-1),
compression_shape: [0, 0, 0, 0, 0, 1, 1],
compressed_index: false,
..RANDOM
@@ -126,8 +126,8 @@ pub(crate) static SEQUENTIAL_SMALL: Descriptor = Descriptor {
block_size: 512,
cache_shards: 64,
block_index_hashing: Some(false),
compression_level: -2,
bottommost_level: Some(4),
compression_level: -4,
bottommost_level: Some(-2),
compression_shape: [0, 0, 0, 0, 1, 1, 1],
compressed_index: false,
..SEQUENTIAL
+1
View File
@@ -58,6 +58,7 @@ pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<S
Ok(Arc::new(Self {
read_only: config.rocksdb_read_only,
secondary: config.rocksdb_secondary,
checksums: config.rocksdb_checksums,
corks: AtomicU32::new(0),
pool: ctx.pool.clone(),
db,
+6 -3
View File
@@ -9,6 +9,8 @@ mod keys_from;
mod keys_prefix;
mod open;
mod options;
mod qry;
mod qry_batch;
mod remove;
mod rev_keys;
mod rev_keys_from;
@@ -37,6 +39,7 @@ pub(crate) use self::options::{
cache_iter_options_default, cache_read_options_default, iter_options_default,
read_options_default, write_options_default,
};
pub use self::{get_batch::Get, qry_batch::Qry};
use crate::{watchers::Watchers, Engine};
pub struct Map {
@@ -56,9 +59,9 @@ impl Map {
db: db.clone(),
cf: open::open(db, name),
watchers: Watchers::default(),
write_options: write_options_default(),
read_options: read_options_default(),
cache_read_options: cache_read_options_default(),
write_options: write_options_default(db),
read_options: read_options_default(db),
cache_read_options: cache_read_options_default(db),
}))
}
+1 -51
View File
@@ -1,65 +1,15 @@
use std::{convert::AsRef, fmt::Debug, io::Write, sync::Arc};
use std::{convert::AsRef, fmt::Debug, sync::Arc};
use arrayvec::ArrayVec;
use conduwuit::{err, implement, utils::result::MapExpect, Err, Result};
use futures::{future::ready, Future, FutureExt, TryFutureExt};
use rocksdb::{DBPinnableSlice, ReadOptions};
use serde::Serialize;
use tokio::task;
use crate::{
keyval::KeyBuf,
ser,
util::{is_incomplete, map_err, or_else},
Handle,
};
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is serialized into an allocated buffer to perform
/// the query.
#[implement(super::Map)]
#[inline]
pub fn qry<K>(self: &Arc<Self>, key: &K) -> impl Future<Output = Result<Handle<'_>>> + Send
where
K: Serialize + ?Sized + Debug,
{
let mut buf = KeyBuf::new();
self.bqry(key, &mut buf)
}
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is serialized into a fixed-sized buffer to perform
/// the query. The maximum size is supplied as const generic parameter.
#[implement(super::Map)]
#[inline]
pub fn aqry<const MAX: usize, K>(
self: &Arc<Self>,
key: &K,
) -> impl Future<Output = Result<Handle<'_>>> + Send
where
K: Serialize + ?Sized + Debug,
{
let mut buf = ArrayVec::<u8, MAX>::new();
self.bqry(key, &mut buf)
}
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is serialized into a user-supplied Writer.
#[implement(super::Map)]
#[tracing::instrument(skip(self, buf), level = "trace")]
pub fn bqry<K, B>(
self: &Arc<Self>,
key: &K,
buf: &mut B,
) -> impl Future<Output = Result<Handle<'_>>> + Send
where
K: Serialize + ?Sized + Debug,
B: Write + AsRef<[u8]>,
{
let key = ser::serialize(buf, key).expect("failed to serialize query key");
self.get(key)
}
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is referenced directly to perform the query.
#[implement(super::Map)]
+18 -27
View File
@@ -1,4 +1,4 @@
use std::{convert::AsRef, fmt::Debug, sync::Arc};
use std::{convert::AsRef, sync::Arc};
use conduwuit::{
implement,
@@ -10,43 +10,34 @@ use conduwuit::{
};
use futures::{Stream, StreamExt, TryStreamExt};
use rocksdb::{DBPinnableSlice, ReadOptions};
use serde::Serialize;
use super::get::{cached_handle_from, handle_from};
use crate::{keyval::KeyBuf, ser, Handle};
use crate::Handle;
#[implement(super::Map)]
#[tracing::instrument(skip(self, keys), level = "trace")]
pub fn qry_batch<'a, S, K>(
self: &'a Arc<Self>,
keys: S,
) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a
pub trait Get<'a, K, S>
where
Self: Sized,
S: Stream<Item = K> + Send + 'a,
K: Serialize + Debug + 'a,
K: AsRef<[u8]> + Send + Sync + 'a,
{
use crate::pool::Get;
fn get(self, map: &'a Arc<super::Map>) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a;
}
keys.ready_chunks(automatic_amplification())
.widen_then(automatic_width(), |chunk| {
let keys = chunk
.iter()
.map(ser::serialize_to::<KeyBuf, _>)
.map(|result| result.expect("failed to serialize query key"))
.map(Into::into)
.collect();
self.db
.pool
.execute_get(Get { map: self.clone(), key: keys, res: None })
})
.map_ok(|results| results.into_iter().stream())
.try_flatten()
impl<'a, K, S> Get<'a, K, S> for S
where
Self: Sized,
S: Stream<Item = K> + Send + 'a,
K: AsRef<[u8]> + Send + Sync + 'a,
{
#[inline]
fn get(self, map: &'a Arc<super::Map>) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a {
map.get_batch(self)
}
}
#[implement(super::Map)]
#[tracing::instrument(skip(self, keys), level = "trace")]
pub fn get_batch<'a, S, K>(
pub(crate) fn get_batch<'a, S, K>(
self: &'a Arc<Self>,
keys: S,
) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a
+1 -1
View File
@@ -22,7 +22,7 @@ where
pub fn raw_keys(self: &Arc<Self>) -> impl Stream<Item = Result<Key<'_>>> + Send {
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_fwd(None);
+1 -1
View File
@@ -53,7 +53,7 @@ where
{
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
return stream::Keys::<'_>::from(state.init_fwd(from.as_ref().into())).boxed();
+29 -21
View File
@@ -1,35 +1,43 @@
use std::sync::Arc;
use rocksdb::{ReadOptions, ReadTier, WriteOptions};
#[inline]
pub(crate) fn iter_options_default() -> ReadOptions {
let mut options = read_options_default();
options.set_background_purge_on_iterator_cleanup(true);
//options.set_pin_data(true);
options
}
use crate::Engine;
#[inline]
pub(crate) fn cache_iter_options_default() -> ReadOptions {
let mut options = cache_read_options_default();
options.set_background_purge_on_iterator_cleanup(true);
//options.set_pin_data(true);
options
}
#[inline]
pub(crate) fn cache_read_options_default() -> ReadOptions {
let mut options = read_options_default();
pub(crate) fn cache_iter_options_default(db: &Arc<Engine>) -> ReadOptions {
let mut options = iter_options_default(db);
options.set_read_tier(ReadTier::BlockCache);
options.fill_cache(false);
options
}
#[inline]
pub(crate) fn read_options_default() -> ReadOptions {
let mut options = ReadOptions::default();
options.set_total_order_seek(true);
pub(crate) fn iter_options_default(db: &Arc<Engine>) -> ReadOptions {
let mut options = read_options_default(db);
options.set_background_purge_on_iterator_cleanup(true);
options
}
#[inline]
pub(crate) fn write_options_default() -> WriteOptions { WriteOptions::default() }
pub(crate) fn cache_read_options_default(db: &Arc<Engine>) -> ReadOptions {
let mut options = read_options_default(db);
options.set_read_tier(ReadTier::BlockCache);
options.fill_cache(false);
options
}
#[inline]
pub(crate) fn read_options_default(db: &Arc<Engine>) -> ReadOptions {
let mut options = ReadOptions::default();
options.set_total_order_seek(true);
if !db.checksums {
options.set_verify_checksums(false);
}
options
}
#[inline]
pub(crate) fn write_options_default(_db: &Arc<Engine>) -> WriteOptions { WriteOptions::default() }
+54
View File
@@ -0,0 +1,54 @@
use std::{convert::AsRef, fmt::Debug, io::Write, sync::Arc};
use arrayvec::ArrayVec;
use conduwuit::{implement, Result};
use futures::Future;
use serde::Serialize;
use crate::{keyval::KeyBuf, ser, Handle};
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is serialized into an allocated buffer to perform
/// the query.
#[implement(super::Map)]
#[inline]
pub fn qry<K>(self: &Arc<Self>, key: &K) -> impl Future<Output = Result<Handle<'_>>> + Send
where
K: Serialize + ?Sized + Debug,
{
let mut buf = KeyBuf::new();
self.bqry(key, &mut buf)
}
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is serialized into a fixed-sized buffer to perform
/// the query. The maximum size is supplied as const generic parameter.
#[implement(super::Map)]
#[inline]
pub fn aqry<const MAX: usize, K>(
self: &Arc<Self>,
key: &K,
) -> impl Future<Output = Result<Handle<'_>>> + Send
where
K: Serialize + ?Sized + Debug,
{
let mut buf = ArrayVec::<u8, MAX>::new();
self.bqry(key, &mut buf)
}
/// Fetch a value from the database into cache, returning a reference-handle
/// asynchronously. The key is serialized into a user-supplied Writer.
#[implement(super::Map)]
#[tracing::instrument(skip(self, buf), level = "trace")]
pub fn bqry<K, B>(
self: &Arc<Self>,
key: &K,
buf: &mut B,
) -> impl Future<Output = Result<Handle<'_>>> + Send
where
K: Serialize + ?Sized + Debug,
B: Write + AsRef<[u8]>,
{
let key = ser::serialize(buf, key).expect("failed to serialize query key");
self.get(key)
}
+63
View File
@@ -0,0 +1,63 @@
use std::{fmt::Debug, sync::Arc};
use conduwuit::{
implement,
utils::{
stream::{automatic_amplification, automatic_width, WidebandExt},
IterStream,
},
Result,
};
use futures::{Stream, StreamExt, TryStreamExt};
use serde::Serialize;
use crate::{keyval::KeyBuf, ser, Handle};
pub trait Qry<'a, K, S>
where
S: Stream<Item = K> + Send + 'a,
K: Serialize + Debug,
{
fn qry(self, map: &'a Arc<super::Map>) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a;
}
impl<'a, K, S> Qry<'a, K, S> for S
where
Self: 'a,
S: Stream<Item = K> + Send + 'a,
K: Serialize + Debug + 'a,
{
#[inline]
fn qry(self, map: &'a Arc<super::Map>) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a {
map.qry_batch(self)
}
}
#[implement(super::Map)]
#[tracing::instrument(skip(self, keys), level = "trace")]
pub(crate) fn qry_batch<'a, S, K>(
self: &'a Arc<Self>,
keys: S,
) -> impl Stream<Item = Result<Handle<'_>>> + Send + 'a
where
S: Stream<Item = K> + Send + 'a,
K: Serialize + Debug + 'a,
{
use crate::pool::Get;
keys.ready_chunks(automatic_amplification())
.widen_then(automatic_width(), |chunk| {
let keys = chunk
.iter()
.map(ser::serialize_to::<KeyBuf, _>)
.map(|result| result.expect("failed to serialize query key"))
.map(Into::into)
.collect();
self.db
.pool
.execute_get(Get { map: self.clone(), key: keys, res: None })
})
.map_ok(|results| results.into_iter().stream())
.try_flatten()
}
+1 -1
View File
@@ -22,7 +22,7 @@ where
pub fn rev_raw_keys(self: &Arc<Self>) -> impl Stream<Item = Result<Key<'_>>> + Send {
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_rev(None);
+1 -1
View File
@@ -61,7 +61,7 @@ where
{
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
return stream::KeysRev::<'_>::from(state.init_rev(from.as_ref().into())).boxed();
+2 -2
View File
@@ -31,7 +31,7 @@ where
pub fn rev_raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>> + Send {
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_rev(None);
@@ -66,7 +66,7 @@ pub fn rev_raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>
fields(%map),
)]
pub(super) fn is_cached(map: &Arc<super::Map>) -> bool {
let opts = super::cache_iter_options_default();
let opts = super::cache_iter_options_default(&map.db);
let state = stream::State::new(map, opts).init_rev(None);
!state.is_incomplete()
+2 -2
View File
@@ -80,7 +80,7 @@ where
{
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
let state = state.init_rev(from.as_ref().into());
@@ -118,7 +118,7 @@ pub(super) fn is_cached<P>(map: &Arc<super::Map>, from: &P) -> bool
where
P: AsRef<[u8]> + ?Sized,
{
let cache_opts = super::cache_iter_options_default();
let cache_opts = super::cache_iter_options_default(&map.db);
let cache_status = stream::State::new(map, cache_opts)
.init_rev(from.as_ref().into())
.status();
+2 -2
View File
@@ -30,7 +30,7 @@ where
pub fn raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>> + Send {
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_fwd(None);
@@ -65,7 +65,7 @@ pub fn raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>> +
fields(%map),
)]
pub(super) fn is_cached(map: &Arc<super::Map>) -> bool {
let opts = super::cache_iter_options_default();
let opts = super::cache_iter_options_default(&map.db);
let state = stream::State::new(map, opts).init_fwd(None);
!state.is_incomplete()
+2 -2
View File
@@ -77,7 +77,7 @@ where
{
use crate::pool::Seek;
let opts = super::iter_options_default();
let opts = super::iter_options_default(&self.db);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
let state = state.init_fwd(from.as_ref().into());
@@ -115,7 +115,7 @@ pub(super) fn is_cached<P>(map: &Arc<super::Map>, from: &P) -> bool
where
P: AsRef<[u8]> + ?Sized,
{
let opts = super::cache_iter_options_default();
let opts = super::cache_iter_options_default(&map.db);
let state = stream::State::new(map, opts).init_fwd(from.as_ref().into());
!state.is_incomplete()
+1 -1
View File
@@ -30,7 +30,7 @@ pub use self::{
deserialized::Deserialized,
handle::Handle,
keyval::{serialize_key, serialize_val, KeyVal, Slice},
map::{compact, Map},
map::{compact, Get, Map, Qry},
ser::{serialize, serialize_to, serialize_to_vec, Cbor, Interfix, Json, Separator, SEP},
};
pub(crate) use self::{
+3 -3
View File
@@ -46,14 +46,14 @@ impl Server {
.and_then(|raw| crate::clap::update(raw, args))
.and_then(|raw| Config::new(&raw))?;
#[cfg(feature = "sentry_telemetry")]
let sentry_guard = crate::sentry::init(&config);
let (tracing_reload_handle, tracing_flame_guard, capture) =
crate::logging::init(&config)?;
config.check()?;
#[cfg(feature = "sentry_telemetry")]
let sentry_guard = crate::sentry::init(&config);
#[cfg(unix)]
sys::maximize_fd_limit()
.expect("Unable to increase maximum soft and hard file descriptor limit");
+4
View File
@@ -16,6 +16,8 @@ pub(super) async fn signal(server: Arc<Server>) {
let mut quit = unix::signal(SignalKind::quit()).expect("SIGQUIT handler");
let mut term = unix::signal(SignalKind::terminate()).expect("SIGTERM handler");
let mut usr1 = unix::signal(SignalKind::user_defined1()).expect("SIGUSR1 handler");
let mut usr2 = unix::signal(SignalKind::user_defined2()).expect("SIGUSR2 handler");
loop {
trace!("Installed signal handlers");
let sig: &'static str;
@@ -23,6 +25,8 @@ pub(super) async fn signal(server: Arc<Server>) {
_ = signal::ctrl_c() => { sig = "SIGINT"; },
_ = quit.recv() => { sig = "SIGQUIT"; },
_ = term.recv() => { sig = "SIGTERM"; },
_ = usr1.recv() => { sig = "SIGUSR1"; },
_ = usr2.recv() => { sig = "SIGUSR2"; },
}
warn!("Received {sig}");
+7 -20
View File
@@ -9,6 +9,7 @@ use std::{
use axum_server::Handle as ServerHandle;
use conduwuit::{debug, debug_error, debug_info, error, info, Error, Result, Server};
use futures::FutureExt;
use service::Services;
use tokio::{
sync::broadcast::{self, Sender},
@@ -109,28 +110,14 @@ pub(crate) async fn stop(services: Arc<Services>) -> Result<()> {
#[tracing::instrument(skip_all)]
async fn signal(server: Arc<Server>, tx: Sender<()>, handle: axum_server::Handle) {
loop {
let sig: &'static str = server
.signal
.subscribe()
.recv()
.await
.expect("channel error");
if !server.running() {
handle_shutdown(&server, &tx, &handle, sig).await;
break;
}
}
server
.clone()
.until_shutdown()
.then(move |()| handle_shutdown(server, tx, handle))
.await;
}
async fn handle_shutdown(
server: &Arc<Server>,
tx: &Sender<()>,
handle: &axum_server::Handle,
sig: &str,
) {
debug!("Received signal {sig}");
async fn handle_shutdown(server: Arc<Server>, tx: Sender<()>, handle: axum_server::Handle) {
if let Err(e) = tx.send(()) {
error!("failed sending shutdown transaction to channel: {e}");
}
+1
View File
@@ -74,6 +74,7 @@ serde_json.workspace = true
serde.workspace = true
serde_yaml.workspace = true
sha2.workspace = true
smallvec.workspace = true
termimad.workspace = true
termimad.optional = true
tokio.workspace = true
@@ -2,6 +2,8 @@ use conduwuit::{debug, debug_info, error, implement, info, Err, Result};
use ruma::events::room::message::RoomMessageEventContent;
use tokio::time::{sleep, Duration};
pub(super) const SIGNAL: &str = "SIGUSR2";
/// Possibly spawn the terminal console at startup if configured.
#[implement(super::Service)]
pub(super) async fn console_auto_start(&self) {
@@ -22,7 +24,7 @@ pub(super) async fn console_auto_stop(&self) {
/// Execute admin commands after startup
#[implement(super::Service)]
pub(super) async fn startup_execute(&self) -> Result<()> {
pub(super) async fn startup_execute(&self) -> Result {
// List of comamnds to execute
let commands = &self.services.server.config.admin_execute;
@@ -36,7 +38,7 @@ pub(super) async fn startup_execute(&self) -> Result<()> {
sleep(Duration::from_millis(500)).await;
for (i, command) in commands.iter().enumerate() {
if let Err(e) = self.startup_execute_command(i, command.clone()).await {
if let Err(e) = self.execute_command(i, command.clone()).await {
if !errors {
return Err(e);
}
@@ -59,16 +61,38 @@ pub(super) async fn startup_execute(&self) -> Result<()> {
Ok(())
}
/// Execute one admin command after startup
/// Execute admin commands after signal
#[implement(super::Service)]
async fn startup_execute_command(&self, i: usize, command: String) -> Result<()> {
debug!("Startup command #{i}: executing {command:?}");
pub(super) async fn signal_execute(&self) -> Result {
// List of comamnds to execute
let commands = self.services.server.config.admin_signal_execute.clone();
// When true, errors are ignored and execution continues.
let ignore_errors = self.services.server.config.admin_execute_errors_ignore;
for (i, command) in commands.iter().enumerate() {
if let Err(e) = self.execute_command(i, command.clone()).await {
if !ignore_errors {
return Err(e);
}
}
tokio::task::yield_now().await;
}
Ok(())
}
/// Execute one admin command after startup or signal
#[implement(super::Service)]
async fn execute_command(&self, i: usize, command: String) -> Result {
debug!("Execute command #{i}: executing {command:?}");
match self.command_in_place(command, None).await {
| Ok(Some(output)) => Self::startup_command_output(i, &output),
| Err(output) => Self::startup_command_error(i, &output),
| Ok(Some(output)) => Self::execute_command_output(i, &output),
| Err(output) => Self::execute_command_error(i, &output),
| Ok(None) => {
info!("Startup command #{i} completed (no output).");
info!("Execute command #{i} completed (no output).");
Ok(())
},
}
@@ -76,28 +100,28 @@ async fn startup_execute_command(&self, i: usize, command: String) -> Result<()>
#[cfg(feature = "console")]
#[implement(super::Service)]
fn startup_command_output(i: usize, content: &RoomMessageEventContent) -> Result<()> {
debug_info!("Startup command #{i} completed:");
fn execute_command_output(i: usize, content: &RoomMessageEventContent) -> Result {
debug_info!("Execute command #{i} completed:");
super::console::print(content.body());
Ok(())
}
#[cfg(feature = "console")]
#[implement(super::Service)]
fn startup_command_error(i: usize, content: &RoomMessageEventContent) -> Result<()> {
fn execute_command_error(i: usize, content: &RoomMessageEventContent) -> Result {
super::console::print_err(content.body());
Err!(debug_error!("Startup command #{i} failed."))
Err!(debug_error!("Execute command #{i} failed."))
}
#[cfg(not(feature = "console"))]
#[implement(super::Service)]
fn startup_command_output(i: usize, content: &RoomMessageEventContent) -> Result<()> {
info!("Startup command #{i} completed:\n{:#}", content.body());
fn execute_command_output(i: usize, content: &RoomMessageEventContent) -> Result {
info!("Execute command #{i} completed:\n{:#}", content.body());
Ok(())
}
#[cfg(not(feature = "console"))]
#[implement(super::Service)]
fn startup_command_error(i: usize, content: &RoomMessageEventContent) -> Result<()> {
Err!(error!("Startup command #{i} failed:\n{:#}", content.body()))
fn execute_command_error(i: usize, content: &RoomMessageEventContent) -> Result {
Err!(error!("Execute command #{i} failed:\n{:#}", content.body()))
}
+6 -2
View File
@@ -1,7 +1,7 @@
pub mod console;
mod create;
mod execute;
mod grant;
mod startup;
use std::{
future::Future,
@@ -183,7 +183,11 @@ impl Service {
.map(|complete| complete(command))
}
async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) {
async fn handle_signal(&self, sig: &'static str) {
if sig == execute::SIGNAL {
self.signal_execute().await.ok();
}
#[cfg(feature = "console")]
self.console.handle_signal(sig).await;
}
+62
View File
@@ -0,0 +1,62 @@
use std::{iter, ops::Deref, path::Path, sync::Arc};
use async_trait::async_trait;
use conduwuit::{
config::{check, Config},
error, implement, Result, Server,
};
pub struct Service {
server: Arc<Server>,
}
const SIGNAL: &str = "SIGUSR1";
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { server: args.server.clone() }))
}
async fn worker(self: Arc<Self>) -> Result {
while self.server.running() {
if self.server.signal.subscribe().recv().await == Ok(SIGNAL) {
if let Err(e) = self.handle_reload() {
error!("Failed to reload config: {e}");
}
}
}
Ok(())
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Deref for Service {
type Target = Arc<Config>;
#[inline]
fn deref(&self) -> &Self::Target { &self.server.config }
}
#[implement(Service)]
fn handle_reload(&self) -> Result {
if self.server.config.config_reload_signal {
self.reload(iter::empty())?;
}
Ok(())
}
#[implement(Service)]
pub fn reload<'a, I>(&self, paths: I) -> Result<Arc<Config>>
where
I: Iterator<Item = &'a Path>,
{
let old = self.server.config.clone();
let new = Config::load(paths).and_then(|raw| Config::new(&raw))?;
check::reload(&old, &new)?;
self.server.config.update(new)
}
@@ -1,9 +1,9 @@
use std::mem;
use std::{fmt::Debug, mem};
use bytes::Bytes;
use conduwuit::{
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace,
utils::string::EMPTY, Err, Error, Result,
debug, debug::INFO_SPAN_LEVEL, debug_error, debug_warn, err, error::inspect_debug_log,
implement, trace, utils::string::EMPTY, Err, Error, Result,
};
use http::{header::AUTHORIZATION, HeaderValue};
use ipaddress::IPAddress;
@@ -20,82 +20,110 @@ use ruma::{
use crate::resolver::actual::ActualDest;
impl super::Service {
#[tracing::instrument(
level = "debug"
/// Sends a request to a federation server
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "request", level = "debug")]
pub async fn execute<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
let client = &self.services.client.federation;
self.execute_on(client, dest, request).await
}
/// Like execute() but with a very large timeout
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "synapse", level = "debug")]
pub async fn execute_synapse<T>(
&self,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
let client = &self.services.client.synapse;
self.execute_on(client, dest, request).await
}
#[implement(super::Service)]
#[tracing::instrument(
name = "fed",
level = INFO_SPAN_LEVEL,
skip(self, client, request),
)]
pub async fn send<T>(
&self,
client: &Client,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
pub async fn execute_on<T>(
&self,
client: &Client,
dest: &ServerName,
request: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
{
if !self.services.server.config.allow_federation {
return Err!(Config("allow_federation", "Federation is disabled."));
}
if self
.services
.server
.config
.forbidden_remote_server_names
.contains(dest)
{
if !self.server.config.allow_federation {
return Err!(Config("allow_federation", "Federation is disabled."));
}
if self
.server
.config
.forbidden_remote_server_names
.contains(dest)
{
return Err!(Request(Forbidden(debug_warn!(
"Federation with {dest} is not allowed."
))));
}
let actual = self.services.resolver.get_actual_dest(dest).await?;
let request = into_http_request::<T>(&actual, request)?;
let request = self.prepare(dest, request)?;
self.execute::<T>(dest, &actual, request, client).await
return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
}
async fn execute<T>(
&self,
dest: &ServerName,
actual: &ActualDest,
request: Request,
client: &Client,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
{
let url = request.url().clone();
let method = request.method().clone();
let actual = self.services.resolver.get_actual_dest(dest).await?;
let request = into_http_request::<T>(&actual, request)?;
let request = self.prepare(dest, request)?;
self.perform::<T>(dest, &actual, request, client).await
}
debug!(?method, ?url, "Sending request");
match client.execute(request).await {
| Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
| Err(error) =>
Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
#[implement(super::Service)]
async fn perform<T>(
&self,
dest: &ServerName,
actual: &ActualDest,
request: Request,
client: &Client,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Send,
{
let url = request.url().clone();
let method = request.method().clone();
debug!(?method, ?url, "Sending request");
match client.execute(request).await {
| Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
| Err(error) =>
Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
}
}
#[implement(super::Service)]
fn prepare(&self, dest: &ServerName, mut request: http::Request<Vec<u8>>) -> Result<Request> {
self.sign_request(&mut request, dest);
let request = Request::try_from(request)?;
self.validate_url(request.url())?;
self.services.server.check_running()?;
Ok(request)
}
#[implement(super::Service)]
fn validate_url(&self, url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
self.services.resolver.validate_ip(&ip)?;
}
}
fn prepare(&self, dest: &ServerName, mut request: http::Request<Vec<u8>>) -> Result<Request> {
self.sign_request(&mut request, dest);
let request = Request::try_from(request)?;
self.validate_url(request.url())?;
self.server.check_running()?;
Ok(request)
}
fn validate_url(&self, url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
self.services.resolver.validate_ip(&ip)?;
}
}
Ok(())
}
Ok(())
}
async fn handle_response<T>(
@@ -195,7 +223,7 @@ fn sign_request(&self, http_request: &mut http::Request<Vec<u8>>, dest: &ServerN
type Value = CanonicalJsonValue;
type Object = CanonicalJsonObject;
let origin = self.services.globals.server_name();
let origin = &self.services.server.name;
let body = http_request.body();
let uri = http_request
.uri()
+33
View File
@@ -0,0 +1,33 @@
mod execute;
use std::sync::Arc;
use conduwuit::{Result, Server};
use crate::{client, resolver, server_keys, Dep};
pub struct Service {
services: Services,
}
struct Services {
server: Arc<Server>,
client: Dep<client::Service>,
resolver: Dep<resolver::Service>,
server_keys: Dep<server_keys::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
server: args.server.clone(),
client: args.depend::<client::Service>("client"),
resolver: args.depend::<resolver::Service>("resolver"),
server_keys: args.depend::<server_keys::Service>("server_keys"),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
+4 -4
View File
@@ -61,11 +61,11 @@ impl crate::Service for Service {
db,
server: args.server.clone(),
bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &config.server_name))
admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &args.server.name))
.expect("#admins:server_name is valid alias name"),
server_user: UserId::parse_with_server_name(
String::from("conduit"),
&config.server_name,
&args.server.name,
)
.expect("@conduit:server_name is valid"),
turn_secret,
@@ -107,7 +107,7 @@ impl Service {
pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) }
#[inline]
pub fn server_name(&self) -> &ServerName { self.server.config.server_name.as_ref() }
pub fn server_name(&self) -> &ServerName { self.server.name.as_ref() }
pub fn allow_registration(&self) -> bool { self.server.config.allow_registration }
@@ -207,7 +207,7 @@ impl Service {
#[inline]
pub fn server_is_ours(&self, server_name: &ServerName) -> bool {
server_name == self.server.config.server_name
server_name == self.server_name()
}
#[inline]
+4 -6
View File
@@ -218,8 +218,6 @@ async fn migrate(services: &Services) -> Result<()> {
}
async fn db_lt_12(services: &Services) -> Result<()> {
let config = &services.server.config;
for username in &services
.users
.list_local_users()
@@ -227,7 +225,8 @@ async fn db_lt_12(services: &Services) -> Result<()> {
.collect::<Vec<_>>()
.await
{
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
let user = match UserId::parse_with_server_name(username.as_str(), &services.server.name)
{
| Ok(u) => u,
| Err(e) => {
warn!("Invalid username {username}: {e}");
@@ -297,8 +296,6 @@ async fn db_lt_12(services: &Services) -> Result<()> {
}
async fn db_lt_13(services: &Services) -> Result<()> {
let config = &services.server.config;
for username in &services
.users
.list_local_users()
@@ -306,7 +303,8 @@ async fn db_lt_13(services: &Services) -> Result<()> {
.collect::<Vec<_>>()
.await
{
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
let user = match UserId::parse_with_server_name(username.as_str(), &services.server.name)
{
| Ok(u) => u,
| Err(e) => {
warn!("Invalid username {username}: {e}");
+2
View File
@@ -9,7 +9,9 @@ pub mod account_data;
pub mod admin;
pub mod appservice;
pub mod client;
pub mod config;
pub mod emergency;
pub mod federation;
pub mod globals;
pub mod key_backups;
pub mod media;
+1 -2
View File
@@ -401,8 +401,7 @@ impl super::Service {
}
fn validate_dest(&self, dest: &ServerName) -> Result<()> {
let config = &self.services.server.config;
if dest == config.server_name && !config.federation_loopback {
if dest == self.services.server.name && !self.services.server.config.federation_loopback {
return Err!("Won't send federation request to ourselves");
}
+1 -1
View File
@@ -150,7 +150,7 @@ impl Service {
let servers_contains_ours = || {
servers
.as_ref()
.is_some_and(|servers| servers.contains(&self.services.server.config.server_name))
.is_some_and(|servers| servers.contains(&self.services.server.name))
};
if !server_is_ours && !servers_contains_ours() {
@@ -10,10 +10,11 @@ use conduwuit::{
};
use futures::TryFutureExt;
use ruma::{
api::federation::event::get_event, CanonicalJsonValue, OwnedEventId, RoomId, RoomVersionId,
ServerName,
api::federation::event::get_event, CanonicalJsonValue, OwnedEventId, RoomId, ServerName,
};
use super::get_room_version_id;
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
/// it is appended to the outliers Tree.
///
@@ -30,7 +31,6 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
events: &'a [OwnedEventId],
create_event: &'a PduEvent,
room_id: &'a RoomId,
room_version_id: &'a RoomVersionId,
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| match self
.services
@@ -113,8 +113,13 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
{
| Ok(res) => {
debug!("Got {next_id} over federation");
let Ok(room_version_id) = get_room_version_id(create_event) else {
back_off((*next_id).to_owned());
continue;
};
let Ok((calculated_event_id, value)) =
pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
pdu::gen_event_id_canonical_json(&res.pdu, &room_version_id)
else {
back_off((*next_id).to_owned());
continue;
+4 -13
View File
@@ -8,8 +8,7 @@ use futures::{future, FutureExt};
use ruma::{
int,
state_res::{self},
uint, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, RoomVersionId,
ServerName,
uint, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName, UInt,
};
use super::check_room_id;
@@ -26,7 +25,7 @@ pub(super) async fn fetch_prev(
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
room_version_id: &RoomVersionId,
first_ts_in_room: UInt,
initial_set: Vec<OwnedEventId>,
) -> Result<(
Vec<OwnedEventId>,
@@ -36,21 +35,13 @@ pub(super) async fn fetch_prev(
let mut eventid_info = HashMap::new();
let mut todo_outlier_stack: VecDeque<OwnedEventId> = initial_set.into();
let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?;
let mut amount = 0;
while let Some(prev_event_id) = todo_outlier_stack.pop_front() {
self.services.server.check_running()?;
if let Some((pdu, mut json_opt)) = self
.fetch_and_handle_outliers(
origin,
&[prev_event_id.clone()],
create_event,
room_id,
room_version_id,
)
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id)
.boxed()
.await
.pop()
@@ -74,7 +65,7 @@ pub(super) async fn fetch_prev(
}
if let Some(json) = json_opt {
if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts {
if pdu.origin_server_ts > first_ts_in_room {
amount = amount.saturating_add(1);
for prev_prev in &pdu.prev_events {
if !graph.contains_key(prev_prev) {
@@ -4,7 +4,7 @@ use conduwuit::{debug, debug_warn, implement, Err, Error, PduEvent, Result};
use futures::FutureExt;
use ruma::{
api::federation::event::get_room_state_ids, events::StateEventType, EventId, OwnedEventId,
RoomId, RoomVersionId, ServerName,
RoomId, ServerName,
};
use crate::rooms::short::ShortStateKey;
@@ -23,7 +23,6 @@ pub(super) async fn fetch_state(
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
room_version_id: &RoomVersionId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
let res = self
@@ -38,7 +37,7 @@ pub(super) async fn fetch_state(
debug!("Fetching state events");
let state_vec = self
.fetch_and_handle_outliers(origin, &res.pdu_ids, create_event, room_id, room_version_id)
.fetch_and_handle_outliers(origin, &res.pdu_ids, create_event, room_id)
.boxed()
.await;
@@ -1,14 +1,15 @@
use std::{
collections::{hash_map, BTreeMap},
sync::Arc,
time::Instant,
};
use conduwuit::{debug, err, implement, warn, Err, Result};
use futures::{FutureExt, TryFutureExt};
use conduwuit::{debug, debug::INFO_SPAN_LEVEL, err, implement, warn, Err, Result};
use futures::{
future::{try_join5, OptionFuture},
FutureExt,
};
use ruma::{events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId};
use super::{check_room_id, get_room_version_id};
use crate::rooms::timeline::RawPduId;
/// When receiving an event one needs to:
@@ -41,7 +42,7 @@ use crate::rooms::timeline::RawPduId;
#[implement(super::Service)]
#[tracing::instrument(
name = "pdu",
level = "debug",
level = INFO_SPAN_LEVEL,
skip_all,
fields(%room_id, %event_id),
)]
@@ -59,19 +60,13 @@ pub async fn handle_incoming_pdu<'a>(
}
// 1.1 Check the server is in the room
if !self.services.metadata.exists(room_id).await {
return Err!(Request(NotFound("Room is unknown to this server")));
}
let meta_exists = self.services.metadata.exists(room_id).map(Ok);
// 1.2 Check if the room is disabled
if self.services.metadata.is_disabled(room_id).await {
return Err!(Request(Forbidden(
"Federation of this room is currently disabled on this server."
)));
}
let is_disabled = self.services.metadata.is_disabled(room_id).map(Ok);
// 1.3.1 Check room ACL on origin field/server
self.acl_check(origin, room_id).await?;
let origin_acl_check = self.acl_check(origin, room_id);
// 1.3.2 Check room ACL on sender's server name
let sender: &UserId = value
@@ -79,34 +74,53 @@ pub async fn handle_incoming_pdu<'a>(
.try_into()
.map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?;
self.acl_check(sender.server_name(), room_id).await?;
let sender_acl_check: OptionFuture<_> = sender
.server_name()
.ne(origin)
.then(|| self.acl_check(sender.server_name(), room_id))
.into();
// Fetch create event
let create_event = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.map_ok(Arc::new)
.await?;
let create_event =
self.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "");
// Procure the room version
let room_version_id = get_room_version_id(&create_event)?;
let (meta_exists, is_disabled, (), (), create_event) = try_join5(
meta_exists,
is_disabled,
origin_acl_check,
sender_acl_check.map(|o| o.unwrap_or(Ok(()))),
create_event,
)
.await?;
let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?;
if !meta_exists {
return Err!(Request(NotFound("Room is unknown to this server")));
}
if is_disabled {
return Err!(Request(Forbidden("Federation of this room is disabled by this server.")));
}
let (incoming_pdu, val) = self
.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false)
.boxed()
.await?;
check_room_id(room_id, &incoming_pdu)?;
// 8. if not timeline event: stop
if !is_timeline_event {
return Ok(None);
}
// Skip old events
if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts {
let first_ts_in_room = self
.services
.timeline
.first_pdu_in_room(room_id)
.await?
.origin_server_ts;
if incoming_pdu.origin_server_ts < first_ts_in_room {
return Ok(None);
}
@@ -117,7 +131,7 @@ pub async fn handle_incoming_pdu<'a>(
origin,
&create_event,
room_id,
&room_version_id,
first_ts_in_room,
incoming_pdu.prev_events.clone(),
)
.await?;
@@ -132,7 +146,7 @@ pub async fn handle_incoming_pdu<'a>(
room_id,
&mut eventid_info,
&create_event,
&first_pdu_in_room,
first_ts_in_room,
&prev_id,
)
.await
@@ -84,7 +84,6 @@ pub(super) async fn handle_outlier_pdu<'a>(
&incoming_pdu.auth_events,
create_event,
room_id,
&room_version_id,
))
.await;
}
@@ -5,16 +5,17 @@ use std::{
};
use conduwuit::{
debug, implement, utils::continue_exponential_backoff_secs, Err, PduEvent, Result,
debug, debug::INFO_SPAN_LEVEL, implement, utils::continue_exponential_backoff_secs, Err,
PduEvent, Result,
};
use ruma::{CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName};
use ruma::{CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName, UInt};
#[implement(super::Service)]
#[allow(clippy::type_complexity)]
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(
name = "prev",
level = "debug",
level = INFO_SPAN_LEVEL,
skip_all,
fields(%prev_id),
)]
@@ -27,8 +28,8 @@ pub(super) async fn handle_prev_pdu<'a>(
OwnedEventId,
(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>),
>,
create_event: &Arc<PduEvent>,
first_pdu_in_room: &PduEvent,
create_event: &PduEvent,
first_ts_in_room: UInt,
prev_id: &EventId,
) -> Result {
// Check for disabled again because it might have changed
@@ -62,7 +63,7 @@ pub(super) async fn handle_prev_pdu<'a>(
if let Some((pdu, json)) = eventid_info.remove(prev_id) {
// Skip old events
if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts {
if pdu.origin_server_ts < first_ts_in_room {
return Ok(());
}
@@ -2,11 +2,10 @@ use conduwuit::{err, implement, pdu::gen_event_id_canonical_json, result::FlatOk
use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId};
use serde_json::value::RawValue as RawJsonValue;
type Parsed = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
#[implement(super::Service)]
pub async fn parse_incoming_pdu(
&self,
pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<Parsed> {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()).map_err(|e| {
err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}")))
})?;
@@ -28,5 +27,5 @@ pub async fn parse_incoming_pdu(
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
})?;
Ok((event_id, value, room_id))
Ok((room_id, event_id, value))
}
@@ -33,11 +33,12 @@ pub async fn resolve_state(
.await
.map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?;
let current_state_ids = self
let current_state_ids: HashMap<_, _> = self
.services
.state_accessor
.state_full_ids(current_sstatehash)
.await?;
.collect()
.await;
let fork_states = [current_state_ids, incoming_state];
let auth_chain_sets: Vec<HashSet<OwnedEventId>> = fork_states
@@ -31,15 +31,12 @@ pub(super) async fn state_at_incoming_degree_one(
return Ok(None);
};
let Ok(mut state) = self
let mut state: HashMap<_, _> = self
.services
.state_accessor
.state_full_ids(prev_event_sstatehash)
.await
.log_err()
else {
return Ok(None);
};
.collect()
.await;
debug!("Using cached state");
let prev_pdu = self
@@ -103,14 +100,12 @@ pub(super) async fn state_at_incoming_resolved(
let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len());
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
for (sstatehash, prev_event) in extremity_sstatehashes {
let Ok(mut leaf_state) = self
let mut leaf_state: HashMap<_, _> = self
.services
.state_accessor
.state_full_ids(sstatehash)
.await
else {
continue;
};
.collect()
.await;
if let Some(state_key) = &prev_event.state_key {
let shortstatekey = self
@@ -63,7 +63,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
if state_at_incoming_event.is_none() {
state_at_incoming_event = self
.fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id)
.fetch_state(origin, create_event, room_id, &incoming_pdu.event_id)
.await?;
}
+6 -4
View File
@@ -7,7 +7,7 @@ use conduwuit::{
utils::{stream::TryIgnore, IterStream, ReadyExt},
Result,
};
use database::{Database, Deserialized, Handle, Interfix, Map};
use database::{Database, Deserialized, Handle, Interfix, Map, Qry};
use futures::{pin_mut, Stream, StreamExt};
use ruma::{api::client::filter::LazyLoadOptions, DeviceId, OwnedUserId, RoomId, UserId};
@@ -115,9 +115,11 @@ where
let make_key =
|sender: &'a UserId| -> Key<'a> { (ctx.user_id, ctx.device_id, ctx.room_id, sender) };
self.db
.lazyloadedids
.qry_batch(senders.clone().stream().map(make_key))
senders
.clone()
.stream()
.map(make_key)
.qry(&self.db.lazyloadedids)
.map(into_status)
.zip(senders.stream())
.map(move |(status, sender)| {
+9 -10
View File
@@ -2,7 +2,7 @@ use std::{borrow::Borrow, fmt::Debug, mem::size_of_val, sync::Arc};
pub use conduwuit::pdu::{ShortEventId, ShortId, ShortRoomId};
use conduwuit::{err, implement, utils, utils::IterStream, Result};
use database::{Deserialized, Map};
use database::{Deserialized, Get, Map, Qry};
use futures::{Stream, StreamExt};
use ruma::{events::StateEventType, EventId, RoomId};
use serde::Deserialize;
@@ -67,9 +67,10 @@ pub fn multi_get_or_create_shorteventid<'a, I>(
where
I: Iterator<Item = &'a EventId> + Clone + Debug + Send + 'a,
{
self.db
.eventid_shorteventid
.get_batch(event_ids.clone().stream())
event_ids
.clone()
.stream()
.get(&self.db.eventid_shorteventid)
.zip(event_ids.into_iter().stream())
.map(|(result, event_id)| match result {
| Ok(ref short) => utils::u64_from_u8(short),
@@ -171,9 +172,8 @@ where
Id: for<'de> Deserialize<'de> + Sized + ToOwned + 'a,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db
.shorteventid_eventid
.qry_batch(shorteventid)
shorteventid
.qry(&self.db.shorteventid_eventid)
.map(Deserialized::deserialized)
}
@@ -204,9 +204,8 @@ pub fn multi_get_statekey_from_short<'a, S>(
where
S: Stream<Item = ShortStateKey> + Send + 'a,
{
self.db
.shortstatekey_statekey
.qry_batch(shortstatekey)
shortstatekey
.qry(&self.db.shortstatekey_statekey)
.map(Deserialized::deserialized)
}
+3 -3
View File
@@ -268,7 +268,7 @@ impl Service {
}
/// Gets the summary of a space using solely federation
#[tracing::instrument(skip(self))]
#[tracing::instrument(level = "debug", skip(self))]
async fn get_summary_and_children_federation(
&self,
current_room: &OwnedRoomId,
@@ -624,8 +624,8 @@ impl Service {
.services
.state_accessor
.state_full_ids(current_shortstatehash)
.await
.map_err(|e| err!(Database("State in space not found: {e}")))?;
.collect()
.await;
let mut children_pdus = Vec::with_capacity(state.len());
for (key, id) in state {
-253
View File
@@ -1,253 +0,0 @@
use std::{borrow::Borrow, collections::HashMap, sync::Arc};
use conduwuit::{
at, err,
utils::stream::{BroadbandExt, IterStream, ReadyExt},
PduEvent, Result,
};
use database::{Deserialized, Map};
use futures::{FutureExt, StreamExt, TryFutureExt};
use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId};
use serde::Deserialize;
use crate::{
rooms,
rooms::{
short::{ShortEventId, ShortStateHash, ShortStateKey},
state_compressor::parse_compressed_state_event,
},
Dep,
};
pub(super) struct Data {
shorteventid_shortstatehash: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}
}
pub(super) async fn state_full(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
let state = self
.state_full_pdus(shortstatehash)
.await?
.into_iter()
.filter_map(|pdu| Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)))
.collect();
Ok(state)
}
pub(super) async fn state_full_pdus(
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<PduEvent>> {
let short_ids = self.state_full_shortids(shortstatehash).await?;
let full_pdus = self
.services
.short
.multi_get_eventid_from_short(short_ids.into_iter().map(at!(1)).stream())
.ready_filter_map(Result::ok)
.broad_filter_map(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await.ok()
})
.collect()
.await;
Ok(full_pdus)
}
pub(super) async fn state_full_ids<Id>(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Id>>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
let short_ids = self.state_full_shortids(shortstatehash).await?;
let full_ids = self
.services
.short
.multi_get_eventid_from_short(short_ids.iter().map(at!(1)).stream())
.zip(short_ids.iter().stream().map(at!(0)))
.ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?)))
.collect()
.boxed()
.await;
Ok(full_ids)
}
pub(super) async fn state_full_shortids(
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
let shortids = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database("Missing state IDs: {e}")))?
.pop()
.expect("there is always one layer")
.full_state
.iter()
.copied()
.map(parse_compressed_state_event)
.collect();
Ok(shortids)
}
/// Returns a single EventId from `room_id` with key
/// (`event_type`,`state_key`).
pub(super) async fn state_get_id<Id>(
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
let shortstatekey = self
.services
.short
.get_shortstatekey(event_type, state_key)
.await?;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
.pop()
.expect("there is always one layer")
.full_state;
let compressed = full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.ok_or(err!(Database("No shortstatekey in compressed state")))?;
let (_, shorteventid) = parse_compressed_state_event(*compressed);
self.services
.short
.get_eventid_from_short(shorteventid)
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn state_get(
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await
})
.await
}
/// Returns the state hash for this pdu.
pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
const BUFSIZE: usize = size_of::<ShortEventId>();
self.services
.short
.get_shorteventid(event_id)
.and_then(|shorteventid| {
self.shorteventid_shortstatehash
.aqry::<BUFSIZE, _>(&shorteventid)
})
.await
.deserialized()
}
/// Returns the full room state.
pub(super) async fn room_state_full(
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_full(shortstatehash))
.map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}")))
.await
}
/// Returns the full room state's pdus.
#[allow(unused_qualifications)] // async traits
pub(super) async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<PduEvent>> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_full_pdus(shortstatehash))
.map_err(|e| err!(Database("Missing state pdus for {room_id:?}: {e:?}")))
.await
}
/// Returns a single EventId from `room_id` with key
/// (`event_type`,`state_key`).
pub(super) async fn room_state_get_id<Id>(
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key))
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get(
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
.await
}
}
+180 -50
View File
@@ -1,23 +1,22 @@
mod data;
use std::{
borrow::Borrow,
collections::HashMap,
fmt::Write,
sync::{Arc, Mutex as StdMutex, Mutex},
};
use conduwuit::{
err, error,
at, err, error,
pdu::PduBuilder,
utils,
utils::{
math::{usize_from_f64, Expected},
ReadyExt,
stream::BroadbandExt,
IterStream, ReadyExt,
},
Err, Error, PduEvent, Result,
};
use futures::StreamExt;
use database::{Deserialized, Map};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use lru_cache::LruCache;
use ruma::{
events::{
@@ -38,33 +37,40 @@ use ruma::{
},
room::RoomType,
space::SpaceRoomJoinRule,
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName,
OwnedUserId, RoomId, ServerName, UserId,
EventEncryptionAlgorithm, EventId, JsOption, OwnedEventId, OwnedRoomAliasId, OwnedRoomId,
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use serde::Deserialize;
use self::data::Data;
use crate::{
rooms,
rooms::{
short::{ShortEventId, ShortStateHash, ShortStateKey},
state::RoomMutexGuard,
state_compressor::parse_compressed_state_event,
},
Dep,
};
pub struct Service {
services: Services,
db: Data,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, ShortStateHash), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, ShortStateHash), bool>>,
services: Services,
db: Data,
}
struct Services {
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
timeline: Dep<rooms::timeline::Service>,
}
struct Data {
shorteventid_shortstatehash: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
@@ -74,17 +80,23 @@ impl crate::Service for Service {
f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self {
services: Services {
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
server_visibility_cache_capacity,
)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
user_visibility_cache_capacity,
)?)),
services: Services {
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
},
db: Data {
shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(),
},
}))
}
@@ -130,33 +142,74 @@ impl crate::Service for Service {
}
impl Service {
pub fn state_full(
&self,
shortstatehash: ShortStateHash,
) -> impl Stream<Item = ((StateEventType, String), PduEvent)> + Send + '_ {
self.state_full_pdus(shortstatehash)
.ready_filter_map(|pdu| {
Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu))
})
}
pub fn state_full_pdus(
&self,
shortstatehash: ShortStateHash,
) -> impl Stream<Item = PduEvent> + Send + '_ {
let short_ids = self
.state_full_shortids(shortstatehash)
.map(|result| result.expect("missing shortstatehash"))
.map(Vec::into_iter)
.map(|iter| iter.map(at!(1)))
.map(IterStream::stream)
.flatten_stream()
.boxed();
self.services
.short
.multi_get_eventid_from_short(short_ids)
.ready_filter_map(Result::ok)
.broad_filter_map(move |event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await.ok()
})
}
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_full_ids<Id>(
&self,
pub fn state_full_ids<'a, Id>(
&'a self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Id>>
) -> impl Stream<Item = (ShortStateKey, Id)> + Send + 'a
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db.state_full_ids::<Id>(shortstatehash).await
}
let shortids = self
.state_full_shortids(shortstatehash)
.map(|result| result.expect("missing shortstatehash"))
.map(|vec| vec.into_iter().unzip())
.boxed()
.shared();
#[inline]
pub async fn state_full_shortids(
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
self.db.state_full_shortids(shortstatehash).await
}
let shortstatekeys = shortids
.clone()
.map(at!(0))
.map(Vec::into_iter)
.map(IterStream::stream)
.flatten_stream();
pub async fn state_full(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.state_full(shortstatehash).await
let shorteventids = shortids
.map(at!(1))
.map(Vec::into_iter)
.map(IterStream::stream)
.flatten_stream();
self.services
.short
.multi_get_eventid_from_short(shorteventids)
.zip(shortstatekeys)
.ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?)))
}
/// Returns a single EventId from `room_id` with key (`event_type`,
@@ -172,22 +225,69 @@ impl Service {
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db
.state_get_id(shortstatehash, event_type, state_key)
let shortstatekey = self
.services
.short
.get_shortstatekey(event_type, state_key)
.await?;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
.pop()
.expect("there is always one layer")
.full_state;
let compressed = full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.ok_or(err!(Database("No shortstatekey in compressed state")))?;
let (_, shorteventid) = parse_compressed_state_event(*compressed);
self.services
.short
.get_eventid_from_short(shorteventid)
.await
}
#[inline]
pub async fn state_full_shortids(
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
let shortids = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database("Missing state IDs: {e}")))?
.pop()
.expect("there is always one layer")
.full_state
.iter()
.copied()
.map(parse_compressed_state_event)
.collect();
Ok(shortids)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[inline]
pub async fn state_get(
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.db
.state_get(shortstatehash, event_type, state_key)
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await
})
.await
}
@@ -375,22 +475,46 @@ impl Service {
/// Returns the state hash for this pdu.
pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
self.db.pdu_shortstatehash(event_id).await
const BUFSIZE: usize = size_of::<ShortEventId>();
self.services
.short
.get_shorteventid(event_id)
.and_then(|shorteventid| {
self.db
.shorteventid_shortstatehash
.aqry::<BUFSIZE, _>(&shorteventid)
})
.await
.deserialized()
}
/// Returns the full room state.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_full(
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.room_state_full(room_id).await
pub fn room_state_full<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = Result<((StateEventType, String), PduEvent)>> + Send + 'a {
self.services
.state
.get_room_shortstatehash(room_id)
.map_ok(|shortstatehash| self.state_full(shortstatehash).map(Ok))
.map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}")))
.try_flatten_stream()
}
/// Returns the full room state pdus
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<PduEvent>> {
self.db.room_state_full_pdus(room_id).await
pub fn room_state_full_pdus<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = Result<PduEvent>> + Send + 'a {
self.services
.state
.get_room_shortstatehash(room_id)
.map_ok(|shortstatehash| self.state_full_pdus(shortstatehash).map(Ok))
.map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}")))
.try_flatten_stream()
}
/// Returns a single EventId from `room_id` with key (`event_type`,
@@ -406,8 +530,10 @@ impl Service {
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db
.room_state_get_id(room_id, event_type, state_key)
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key))
.await
}
@@ -420,7 +546,11 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.db.room_state_get(room_id, event_type, state_key).await
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
+1 -1
View File
@@ -1166,7 +1166,7 @@ impl Service {
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
let (event_id, value, room_id) =
let (room_id, event_id, value) =
self.services.event_handler.parse_incoming_pdu(&pdu).await?;
// Lock so we cannot backfill the same pdu twice at the same time
+7 -6
View File
@@ -13,7 +13,7 @@ use ruma::{
};
use tokio::sync::{broadcast, RwLock};
use crate::{globals, sending, users, Dep};
use crate::{globals, sending, sending::EduBuf, users, Dep};
pub struct Service {
server: Arc<Server>,
@@ -228,12 +228,13 @@ impl Service {
return Ok(());
}
let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing));
let content = TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing);
let edu = Edu::Typing(content);
self.services
.sending
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))
.await?;
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &edu).expect("Serialized Edu::Typing");
self.services.sending.send_edu_room(room_id, buf).await?;
Ok(())
}
+3 -3
View File
@@ -202,7 +202,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
if value.is_empty() {
SendingEvent::Pdu(event.into())
} else {
SendingEvent::Edu(value.to_vec())
SendingEvent::Edu(value.into())
},
)
} else if key.starts_with(b"$") {
@@ -230,7 +230,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
SendingEvent::Pdu(event.into())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value.to_vec())
SendingEvent::Edu(value.into())
},
)
} else {
@@ -252,7 +252,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
if value.is_empty() {
SendingEvent::Pdu(event.into())
} else {
SendingEvent::Edu(value.to_vec())
SendingEvent::Edu(value.into())
},
)
})
+24 -19
View File
@@ -1,7 +1,6 @@
mod appservice;
mod data;
mod dest;
mod send;
mod sender;
use std::{
@@ -22,6 +21,7 @@ use ruma::{
api::{appservice::Registration, OutgoingRequest},
RoomId, ServerName, UserId,
};
use smallvec::SmallVec;
use tokio::task::JoinSet;
use self::data::Data;
@@ -30,8 +30,8 @@ pub use self::{
sender::{EDU_LIMIT, PDU_LIMIT},
};
use crate::{
account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId,
server_keys, users, Dep,
account_data, client, federation, globals, presence, pusher, rooms,
rooms::timeline::RawPduId, users, Dep,
};
pub struct Service {
@@ -44,7 +44,6 @@ pub struct Service {
struct Services {
client: Dep<client::Service>,
globals: Dep<globals::Service>,
resolver: Dep<resolver::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
user: Dep<rooms::user::Service>,
@@ -55,7 +54,7 @@ struct Services {
account_data: Dep<account_data::Service>,
appservice: Dep<crate::appservice::Service>,
pusher: Dep<pusher::Service>,
server_keys: Dep<server_keys::Service>,
federation: Dep<federation::Service>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -69,10 +68,16 @@ struct Msg {
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SendingEvent {
Pdu(RawPduId), // pduid
Edu(Vec<u8>), // pdu json
Edu(EduBuf), // edu json
Flush, // none
}
pub type EduBuf = SmallVec<[u8; EDU_BUF_CAP]>;
pub type EduVec = SmallVec<[EduBuf; EDU_VEC_CAP]>;
const EDU_BUF_CAP: usize = 128;
const EDU_VEC_CAP: usize = 1;
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@@ -83,7 +88,6 @@ impl crate::Service for Service {
services: Services {
client: args.depend::<client::Service>("client"),
globals: args.depend::<globals::Service>("globals"),
resolver: args.depend::<resolver::Service>("resolver"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
user: args.depend::<rooms::user::Service>("rooms::user"),
@@ -94,7 +98,7 @@ impl crate::Service for Service {
account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<crate::appservice::Service>("appservice"),
pusher: args.depend::<pusher::Service>("pusher"),
server_keys: args.depend::<server_keys::Service>("server_keys"),
federation: args.depend::<federation::Service>("federation"),
},
channels: (0..num_senders).map(|_| loole::unbounded()).collect(),
}))
@@ -180,7 +184,6 @@ impl Service {
where
S: Stream<Item = &'a ServerName> + Send + 'a,
{
let _cork = self.db.db.cork();
let requests = servers
.map(|server| {
(Destination::Federation(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
@@ -188,6 +191,7 @@ impl Service {
.collect::<Vec<_>>()
.await;
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
@@ -198,7 +202,7 @@ impl Service {
}
#[tracing::instrument(skip(self, server, serialized), level = "debug")]
pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> {
pub fn send_edu_server(&self, server: &ServerName, serialized: EduBuf) -> Result {
let dest = Destination::Federation(server.to_owned());
let event = SendingEvent::Edu(serialized);
let _cork = self.db.db.cork();
@@ -211,7 +215,7 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, serialized), level = "debug")]
pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> {
pub async fn send_edu_room(&self, room_id: &RoomId, serialized: EduBuf) -> Result {
let servers = self
.services
.state_cache
@@ -222,11 +226,10 @@ impl Service {
}
#[tracing::instrument(skip(self, servers, serialized), level = "debug")]
pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec<u8>) -> Result<()>
pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: EduBuf) -> Result
where
S: Stream<Item = &'a ServerName> + Send + 'a,
{
let _cork = self.db.db.cork();
let requests = servers
.map(|server| {
(
@@ -237,6 +240,7 @@ impl Service {
.collect::<Vec<_>>()
.await;
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
@@ -277,7 +281,7 @@ impl Service {
}
/// Sends a request to a federation server
#[tracing::instrument(skip_all, name = "request", level = "debug")]
#[inline]
pub async fn send_federation_request<T>(
&self,
dest: &ServerName,
@@ -286,12 +290,11 @@ impl Service {
where
T: OutgoingRequest + Debug + Send,
{
let client = &self.services.client.federation;
self.send(client, dest, request).await
self.services.federation.execute(dest, request).await
}
/// Like send_federation_request() but with a very large timeout
#[tracing::instrument(skip_all, name = "synapse", level = "debug")]
#[inline]
pub async fn send_synapse_request<T>(
&self,
dest: &ServerName,
@@ -300,8 +303,10 @@ impl Service {
where
T: OutgoingRequest + Debug + Send,
{
let client = &self.services.client.synapse;
self.send(client, dest, request).await
self.services
.federation
.execute_synapse(dest, request)
.await
}
/// Sends a request to an appservice
+87 -84
View File
@@ -8,12 +8,12 @@ use std::{
time::{Duration, Instant},
};
use base64::{engine::general_purpose, Engine as _};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use conduwuit::{
debug, err, error,
result::LogErr,
trace,
utils::{calculate_hash, continue_exponential_backoff_secs, ReadyExt},
utils::{calculate_hash, continue_exponential_backoff_secs, stream::IterStream, ReadyExt},
warn, Error, Result,
};
use futures::{
@@ -38,12 +38,16 @@ use ruma::{
push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent,
GlobalAccountDataEventType,
},
push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName,
push,
serde::Raw,
uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName,
OwnedUserId, RoomId, RoomVersionId, ServerName, UInt,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service};
use super::{
appservice, data::QueueItem, Destination, EduBuf, EduVec, Msg, SendingEvent, Service,
};
#[derive(Debug)]
enum TransactionStatus {
@@ -311,7 +315,12 @@ impl Service {
if let Destination::Federation(server_name) = dest {
if let Ok((select_edus, last_count)) = self.select_edus(server_name).await {
debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit");
events.extend(select_edus.into_iter().map(SendingEvent::Edu));
let select_edus = select_edus
.into_iter()
.map(Into::into)
.map(SendingEvent::Edu);
events.extend(select_edus);
self.db.set_latest_educount(server_name, last_count);
}
}
@@ -355,7 +364,7 @@ impl Service {
level = "debug",
skip_all,
)]
async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> {
// selection window
let since = self.db.get_latest_educount(server_name).await;
let since_upper = self.services.globals.current_count()?;
@@ -403,8 +412,8 @@ impl Service {
since: (u64, u64),
max_edu_count: &AtomicU64,
events_len: &AtomicUsize,
) -> Vec<Vec<u8>> {
let mut events = Vec::new();
) -> EduVec {
let mut events = EduVec::new();
let server_rooms = self.services.state_cache.server_rooms(server_name);
pin_mut!(server_rooms);
@@ -439,10 +448,11 @@ impl Service {
keys: None,
});
let edu = serde_json::to_vec(&edu)
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &edu)
.expect("failed to serialize device list update to JSON");
events.push(edu);
events.push(buf);
if events_len.fetch_add(1, Ordering::Relaxed) >= SELECT_EDU_LIMIT - 1 {
return events;
}
@@ -463,7 +473,7 @@ impl Service {
server_name: &ServerName,
since: (u64, u64),
max_edu_count: &AtomicU64,
) -> Option<Vec<u8>> {
) -> Option<EduBuf> {
let server_rooms = self.services.state_cache.server_rooms(server_name);
pin_mut!(server_rooms);
@@ -485,10 +495,11 @@ impl Service {
let receipt_content = Edu::Receipt(ReceiptContent { receipts });
let receipt_content = serde_json::to_vec(&receipt_content)
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &receipt_content)
.expect("Failed to serialize Receipt EDU to JSON vec");
Some(receipt_content)
Some(buf)
}
/// Look for read receipts in this room
@@ -567,7 +578,7 @@ impl Service {
server_name: &ServerName,
since: (u64, u64),
max_edu_count: &AtomicU64,
) -> Option<Vec<u8>> {
) -> Option<EduBuf> {
let presence_since = self.services.presence.presence_since(since.0);
pin_mut!(presence_since);
@@ -626,14 +637,15 @@ impl Service {
push: presence_updates.into_values().collect(),
});
let presence_content = serde_json::to_vec(&presence_content)
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &presence_content)
.expect("failed to serialize Presence EDU to JSON");
Some(presence_content)
Some(buf)
}
fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingFuture<'_> {
//debug_assert!(!events.is_empty(), "sending empty transaction");
debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
| Destination::Federation(server) =>
self.send_events_dest_federation(server, events).boxed(),
@@ -698,7 +710,7 @@ impl Service {
| SendingEvent::Flush => None,
}));
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash);
let txn_id = &*URL_SAFE_NO_PAD.encode(txn_hash);
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty
// transaction");
@@ -796,81 +808,72 @@ impl Service {
Ok(Destination::Push(user_id, pushkey))
}
#[tracing::instrument(
name = "fed",
level = "debug",
skip(self, events),
fields(
events = %events.len(),
),
)]
async fn send_events_dest_federation(
&self,
server: OwnedServerName,
events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Pdu(_)))
.count(),
);
let mut edu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Edu(_)))
.count(),
);
let pdus: Vec<_> = events
.iter()
.filter_map(|pdu| match pdu {
| SendingEvent::Pdu(pdu) => Some(pdu),
| _ => None,
})
.stream()
.then(|pdu_id| self.services.timeline.get_pdu_json_from_id(pdu_id))
.ready_filter_map(Result::ok)
.then(|pdu| self.convert_to_outgoing_federation_event(pdu))
.collect()
.await;
for event in &events {
match event {
// TODO: check room version and remove event_id if needed
| SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await {
pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await);
}
},
| SendingEvent::Edu(edu) =>
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
},
| SendingEvent::Flush => {}, // flush only; no new content
let edus: Vec<Raw<Edu>> = events
.iter()
.filter_map(|edu| match edu {
| SendingEvent::Edu(edu) => Some(edu.as_ref()),
| _ => None,
})
.map(serde_json::from_slice)
.filter_map(Result::ok)
.collect();
if pdus.is_empty() && edus.is_empty() {
return Ok(Destination::Federation(server));
}
let preimage = pdus
.iter()
.map(|raw| raw.get().as_bytes())
.chain(edus.iter().map(|raw| raw.json().get().as_bytes()));
let txn_hash = calculate_hash(preimage);
let txn_id = &*URL_SAFE_NO_PAD.encode(txn_hash);
let request = send_transaction_message::v1::Request {
transaction_id: txn_id.into(),
origin: self.server.name.clone(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdus,
edus,
};
let result = self
.services
.federation
.execute_on(&self.services.client.sender, &server, request)
.await;
for (event_id, result) in result.iter().flat_map(|resp| resp.pdus.iter()) {
if let Err(e) = result {
warn!(
%txn_id, %server,
"error sending PDU {event_id} to remote server: {e:?}"
);
}
}
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty
// transaction");
let txn_hash = calculate_hash(events.iter().filter_map(|e| match e {
| SendingEvent::Edu(b) => Some(&**b),
| SendingEvent::Pdu(b) => Some(b.as_ref()),
| SendingEvent::Flush => None,
}));
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash);
let request = send_transaction_message::v1::Request {
origin: self.server.config.server_name.clone(),
pdus: pdu_jsons,
edus: edu_jsons,
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
transaction_id: txn_id.into(),
};
let client = &self.services.client.sender;
self.send(client, &server, request)
.await
.inspect(|response| {
response
.pdus
.iter()
.filter(|(_, res)| res.is_err())
.for_each(
|(pdu_id, res)| warn!(%txn_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"),
);
})
.map_err(|e| (Destination::Federation(server.clone()), e))
.map(|_| Destination::Federation(server))
match result {
| Err(error) => Err((Destination::Federation(server), error)),
| Ok(_) => Ok(Destination::Federation(server)),
}
}
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
+5 -1
View File
@@ -10,7 +10,7 @@ use database::Database;
use tokio::sync::Mutex;
use crate::{
account_data, admin, appservice, client, emergency, globals, key_backups,
account_data, admin, appservice, client, config, emergency, federation, globals, key_backups,
manager::Manager,
media, presence, pusher, resolver, rooms, sending, server_keys, service,
service::{Args, Map, Service},
@@ -21,6 +21,7 @@ pub struct Services {
pub account_data: Arc<account_data::Service>,
pub admin: Arc<admin::Service>,
pub appservice: Arc<appservice::Service>,
pub config: Arc<config::Service>,
pub client: Arc<client::Service>,
pub emergency: Arc<emergency::Service>,
pub globals: Arc<globals::Service>,
@@ -30,6 +31,7 @@ pub struct Services {
pub pusher: Arc<pusher::Service>,
pub resolver: Arc<resolver::Service>,
pub rooms: rooms::Service,
pub federation: Arc<federation::Service>,
pub sending: Arc<sending::Service>,
pub server_keys: Arc<server_keys::Service>,
pub sync: Arc<sync::Service>,
@@ -67,6 +69,7 @@ impl Services {
appservice: build!(appservice::Service),
resolver: build!(resolver::Service),
client: build!(client::Service),
config: build!(config::Service),
emergency: build!(emergency::Service),
globals: build!(globals::Service),
key_backups: build!(key_backups::Service),
@@ -95,6 +98,7 @@ impl Services {
typing: build!(rooms::typing::Service),
user: build!(rooms::user::Service),
},
federation: build!(federation::Service),
sending: build!(sending::Service),
server_keys: build!(server_keys::Service),
sync: build!(sync::Service),
+2 -2
View File
@@ -97,8 +97,8 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
);
// Server shutdown
let server_shutdown = self.services.server.clone().until_shutdown().boxed();
futures.push(server_shutdown);
futures.push(self.services.server.until_shutdown().boxed());
if !self.services.server.running() {
return Ok(());
}