Compare commits

..

1 Commits

Author SHA1 Message Date
Ginger 77ae79396f feat: Use OptimisticTransactionDB for the database engine 2025-12-16 09:49:13 -05:00
45 changed files with 312 additions and 1467 deletions
Generated
+11 -11
View File
@@ -4063,7 +4063,7 @@ checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3"
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.10.1" version = "0.10.1"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"assign", "assign",
"js_int", "js_int",
@@ -4083,7 +4083,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-appservice-api" name = "ruma-appservice-api"
version = "0.10.0" version = "0.10.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@@ -4095,7 +4095,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-client-api" name = "ruma-client-api"
version = "0.18.0" version = "0.18.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"assign", "assign",
@@ -4118,7 +4118,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-common" name = "ruma-common"
version = "0.13.0" version = "0.13.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"base64 0.22.1", "base64 0.22.1",
@@ -4150,7 +4150,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events" name = "ruma-events"
version = "0.28.1" version = "0.28.1"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"indexmap", "indexmap",
@@ -4175,7 +4175,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-federation-api" name = "ruma-federation-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"bytes", "bytes",
"headers", "headers",
@@ -4197,7 +4197,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-validation" name = "ruma-identifiers-validation"
version = "0.9.5" version = "0.9.5"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"js_int", "js_int",
"thiserror 2.0.17", "thiserror 2.0.17",
@@ -4206,7 +4206,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identity-service-api" name = "ruma-identity-service-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@@ -4216,7 +4216,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-macros" name = "ruma-macros"
version = "0.13.0" version = "0.13.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"proc-macro-crate", "proc-macro-crate",
@@ -4231,7 +4231,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-push-gateway-api" name = "ruma-push-gateway-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@@ -4243,7 +4243,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.15.0" version = "0.15.0"
source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=27abe0dcd33fd4056efc94bab3582646b31b6ce9#27abe0dcd33fd4056efc94bab3582646b31b6ce9" source = "git+https://forgejo.ellis.link/continuwuation/ruwuma?rev=50b2a91b2ab8f9830eea80b9911e11234e0eac66#50b2a91b2ab8f9830eea80b9911e11234e0eac66"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"ed25519-dalek", "ed25519-dalek",
+1 -1
View File
@@ -351,7 +351,7 @@ version = "0.1.2"
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
[workspace.dependencies.ruma] [workspace.dependencies.ruma]
git = "https://forgejo.ellis.link/continuwuation/ruwuma" git = "https://forgejo.ellis.link/continuwuation/ruwuma"
rev = "27abe0dcd33fd4056efc94bab3582646b31b6ce9" rev = "50b2a91b2ab8f9830eea80b9911e11234e0eac66"
features = [ features = [
"compat", "compat",
"rand", "rand",
+2 -5
View File
@@ -586,13 +586,10 @@
#allow_unstable_room_versions = true #allow_unstable_room_versions = true
# Default room version continuwuity will create rooms with. # Default room version continuwuity will create rooms with.
# Note that this has to be a string since the room version is a string
# rather than an integer. Forgetting the quotes will make the server fail
# to start!
# #
# Per spec, room version "11" is the default. # Per spec, room version 11 is the default.
# #
#default_room_version = "11" #default_room_version = 11
# Enable OpenTelemetry OTLP tracing export. This replaces the deprecated # Enable OpenTelemetry OTLP tracing export. This replaces the deprecated
# Jaeger exporter. Traces will be sent via OTLP to a collector (such as # Jaeger exporter. Traces will be sent via OTLP to a collector (such as
+92 -118
View File
@@ -6,69 +6,6 @@
pkgs, pkgs,
... ...
}: }:
let
baseTestScript =
pkgs.writers.writePython3Bin "do_test" { libraries = [ pkgs.python3Packages.matrix-nio ]; }
''
import asyncio
import nio
async def main() -> None:
# Connect to continuwuity
client = nio.AsyncClient("http://continuwuity:6167", "alice")
# Register as user alice
response = await client.register("alice", "my-secret-password")
# Log in as user alice
response = await client.login("my-secret-password")
# Create a new room
response = await client.room_create(federate=False)
print("Matrix room create response:", response)
assert isinstance(response, nio.RoomCreateResponse)
room_id = response.room_id
# Join the room
response = await client.join(room_id)
print("Matrix join response:", response)
assert isinstance(response, nio.JoinResponse)
# Send a message to the room
response = await client.room_send(
room_id=room_id,
message_type="m.room.message",
content={
"msgtype": "m.text",
"body": "Hello continuwuity!"
}
)
print("Matrix room send response:", response)
assert isinstance(response, nio.RoomSendResponse)
# Sync responses
response = await client.sync(timeout=30000)
print("Matrix sync response:", response)
assert isinstance(response, nio.SyncResponse)
# Check the message was received by continuwuity
last_message = response.rooms.join[room_id].timeline.events[-1].body
assert last_message == "Hello continuwuity!"
# Leave the room
response = await client.room_leave(room_id)
print("Matrix room leave response:", response)
assert isinstance(response, nio.RoomLeaveResponse)
# Close the client
await client.close()
if __name__ == "__main__":
asyncio.run(main())
'';
in
{ {
# run some nixos tests as checks # run some nixos tests as checks
checks = lib.pipe self'.packages [ checks = lib.pipe self'.packages [
@@ -81,69 +18,106 @@
# this test was initially yoinked from # this test was initially yoinked from
# #
# https://github.com/NixOS/nixpkgs/blob/960ce26339661b1b69c6f12b9063ca51b688615f/nixos/tests/matrix/continuwuity.nix # https://github.com/NixOS/nixpkgs/blob/960ce26339661b1b69c6f12b9063ca51b688615f/nixos/tests/matrix/continuwuity.nix
(builtins.concatMap ( (builtins.map (name: {
name: name = "test-${name}";
builtins.map value = pkgs.testers.runNixOSTest {
( inherit name;
{ config, suffix }:
{
name = "test-${name}-${suffix}";
value = pkgs.testers.runNixOSTest {
inherit name;
nodes = { nodes = {
continuwuity = { continuwuity = {
services.matrix-continuwuity = { services.matrix-continuwuity = {
enable = true; enable = true;
package = self'.packages.${name}; package = self'.packages.${name};
settings = config; settings.global = {
extraEnvironment.RUST_BACKTRACE = "yes";
};
networking.firewall.allowedTCPPorts = [ 6167 ];
};
client.environment.systemPackages = [ baseTestScript ];
};
testScript = ''
start_all()
with subtest("start continuwuity"):
continuwuity.wait_for_unit("continuwuity.service")
continuwuity.wait_for_open_port(6167)
with subtest("ensure messages can be exchanged"):
client.succeed("${lib.getExe baseTestScript} >&2")
'';
};
}
)
[
{
suffix = "base";
config = {
global = {
server_name = name; server_name = name;
address = [ "0.0.0.0" ]; address = [ "0.0.0.0" ];
allow_registration = true; allow_registration = true;
yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse = true; yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse = true;
}; };
extraEnvironment.RUST_BACKTRACE = "yes";
}; };
} networking.firewall.allowedTCPPorts = [ 6167 ];
{ };
suffix = "with-room-version"; client =
config = { { pkgs, ... }:
global = { {
server_name = name; environment.systemPackages = [
address = [ "0.0.0.0" ]; (pkgs.writers.writePython3Bin "do_test" { libraries = [ pkgs.python3Packages.matrix-nio ]; } ''
allow_registration = true; import asyncio
yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse = true; import nio
default_room_version = "12";
};
async def main() -> None:
# Connect to continuwuity
client = nio.AsyncClient("http://continuwuity:6167", "alice")
# Register as user alice
response = await client.register("alice", "my-secret-password")
# Log in as user alice
response = await client.login("my-secret-password")
# Create a new room
response = await client.room_create(federate=False)
print("Matrix room create response:", response)
assert isinstance(response, nio.RoomCreateResponse)
room_id = response.room_id
# Join the room
response = await client.join(room_id)
print("Matrix join response:", response)
assert isinstance(response, nio.JoinResponse)
# Send a message to the room
response = await client.room_send(
room_id=room_id,
message_type="m.room.message",
content={
"msgtype": "m.text",
"body": "Hello continuwuity!"
}
)
print("Matrix room send response:", response)
assert isinstance(response, nio.RoomSendResponse)
# Sync responses
response = await client.sync(timeout=30000)
print("Matrix sync response:", response)
assert isinstance(response, nio.SyncResponse)
# Check the message was received by continuwuity
last_message = response.rooms.join[room_id].timeline.events[-1].body
assert last_message == "Hello continuwuity!"
# Leave the room
response = await client.room_leave(room_id)
print("Matrix room leave response:", response)
assert isinstance(response, nio.RoomLeaveResponse)
# Close the client
await client.close()
if __name__ == "__main__":
asyncio.run(main())
'')
];
}; };
} };
]
)) testScript = ''
start_all()
with subtest("start continuwuity"):
continuwuity.wait_for_unit("continuwuity.service")
continuwuity.wait_for_open_port(6167)
with subtest("ensure messages can be exchanged"):
client.succeed("do_test >&2")
'';
};
}))
builtins.listToAttrs builtins.listToAttrs
]; ];
}; };
+2 -2
View File
@@ -31,7 +31,7 @@ pub(super) async fn last(&self, room_id: OwnedRoomOrAliasId) -> Result {
.services .services
.rooms .rooms
.timeline .timeline
.last_timeline_count(&room_id) .last_timeline_count(None, &room_id)
.await?; .await?;
self.write_str(&format!("{result:#?}")).await self.write_str(&format!("{result:#?}")).await
@@ -52,7 +52,7 @@ pub(super) async fn pdus(
.services .services
.rooms .rooms
.timeline .timeline
.pdus_rev(&room_id, from) .pdus_rev(None, &room_id, from)
.try_take(limit.unwrap_or(3)) .try_take(limit.unwrap_or(3))
.try_collect() .try_collect()
.await?; .await?;
+3 -24
View File
@@ -30,31 +30,10 @@ pub(super) async fn show_config(&self) -> Result {
#[admin_command] #[admin_command]
pub(super) async fn reload_config(&self, path: Option<PathBuf>) -> Result { pub(super) async fn reload_config(&self, path: Option<PathBuf>) -> Result {
// The path argument is only what's optionally passed via the admin command, let path = path.as_deref().into_iter();
// so we need to merge it with the existing paths if any were given at startup. self.services.config.reload(path)?;
let mut paths = Vec::new();
// Add previously saved paths to the argument list self.write_str("Successfully reconfigured.").await
self.services
.config
.config_paths
.clone()
.unwrap_or_default()
.iter()
.for_each(|p| paths.push(p.to_owned()));
// If a path is given, and it's not already in the list,
// add it last, so that it overrides earlier files
if let Some(p) = path {
if !paths.contains(&p) {
paths.push(p);
}
}
self.services.config.reload(&paths)?;
self.write_str(&format!("Successfully reconfigured from paths: {paths:?}"))
.await
} }
#[admin_command] #[admin_command]
+4 -30
View File
@@ -59,7 +59,7 @@ pub(crate) async fn get_context_route(
.rooms .rooms
.timeline .timeline
.get_pdu(event_id) .get_pdu(event_id)
.map_err(|_| err!(Request(NotFound("Event not found.")))); .map_err(|_| err!(Request(NotFound("Base event not found."))));
let visible = services let visible = services
.rooms .rooms
@@ -70,7 +70,7 @@ pub(crate) async fn get_context_route(
let (base_id, base_pdu, visible) = try_join3(base_id, base_pdu, visible).await?; let (base_id, base_pdu, visible) = try_join3(base_id, base_pdu, visible).await?;
if base_pdu.room_id_or_hash() != *room_id || base_pdu.event_id != *event_id { if base_pdu.room_id_or_hash() != *room_id || base_pdu.event_id != *event_id {
return Err!(Request(NotFound("Event not found."))); return Err!(Request(NotFound("Base event not found.")));
} }
if !visible { if !visible {
@@ -82,25 +82,11 @@ pub(crate) async fn get_context_route(
let base_event = ignored_filter(&services, (base_count, base_pdu), sender_user); let base_event = ignored_filter(&services, (base_count, base_pdu), sender_user);
// PDUs are used to get seen user IDs and then returned in response.
let events_before = services let events_before = services
.rooms .rooms
.timeline .timeline
.pdus_rev(room_id, Some(base_count)) .pdus_rev(Some(sender_user), room_id, Some(base_count))
.ignore_err() .ignore_err()
.then(async |mut pdu| {
pdu.1.set_unsigned(Some(sender_user));
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
pdu
})
.ready_filter_map(|item| event_filter(item, filter)) .ready_filter_map(|item| event_filter(item, filter))
.wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| ignored_filter(&services, item, sender_user))
.wide_filter_map(|item| visibility_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user))
@@ -110,20 +96,8 @@ pub(crate) async fn get_context_route(
let events_after = services let events_after = services
.rooms .rooms
.timeline .timeline
.pdus(room_id, Some(base_count)) .pdus(Some(sender_user), room_id, Some(base_count))
.ignore_err() .ignore_err()
.then(async |mut pdu| {
pdu.1.set_unsigned(Some(sender_user));
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
pdu
})
.ready_filter_map(|item| event_filter(item, filter)) .ready_filter_map(|item| event_filter(item, filter))
.wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| ignored_filter(&services, item, sender_user))
.wide_filter_map(|item| visibility_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user))
+3 -15
View File
@@ -1,7 +1,7 @@
use axum::extract::State; use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduwuit::{ use conduwuit::{
Err, Result, at, debug_warn, Err, Result, at,
matrix::{ matrix::{
event::{Event, Matches}, event::{Event, Matches},
pdu::PduCount, pdu::PduCount,
@@ -122,14 +122,14 @@ pub(crate) async fn get_message_events_route(
| Direction::Forward => services | Direction::Forward => services
.rooms .rooms
.timeline .timeline
.pdus(room_id, Some(from)) .pdus(Some(sender_user), room_id, Some(from))
.ignore_err() .ignore_err()
.boxed(), .boxed(),
| Direction::Backward => services | Direction::Backward => services
.rooms .rooms
.timeline .timeline
.pdus_rev(room_id, Some(from)) .pdus_rev(Some(sender_user), room_id, Some(from))
.ignore_err() .ignore_err()
.boxed(), .boxed(),
}; };
@@ -140,18 +140,6 @@ pub(crate) async fn get_message_events_route(
.wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| ignored_filter(&services, item, sender_user))
.wide_filter_map(|item| visibility_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user))
.take(limit) .take(limit)
.then(async |mut pdu| {
pdu.1.set_unsigned(Some(sender_user));
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
pdu
})
.collect() .collect()
.await; .await;
+38 -26
View File
@@ -1,6 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Err, Result, at, debug_warn, Result, at,
matrix::{Event, event::RelationTypeEqual, pdu::PduCount}, matrix::{Event, event::RelationTypeEqual, pdu::PduCount},
utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt}, utils::{IterStream, ReadyExt, result::FlatOk, stream::WidebandExt},
}; };
@@ -109,16 +109,6 @@ async fn paginate_relations_with_filter(
recurse: bool, recurse: bool,
dir: Direction, dir: Direction,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
if !services
.rooms
.state_accessor
.user_can_see_event(sender_user, room_id, target)
.await
{
debug_warn!(req_evt = ?target, ?room_id, "Event relations requested by {sender_user} but is not allowed to see it, returning 404");
return Err!(Request(NotFound("Event not found.")));
}
let start: PduCount = from let start: PduCount = from
.map(str::parse) .map(str::parse)
.transpose()? .transpose()?
@@ -139,6 +129,11 @@ async fn paginate_relations_with_filter(
// Spec (v1.10) recommends depth of at least 3 // Spec (v1.10) recommends depth of at least 3
let depth: u8 = if recurse { 3 } else { 1 }; let depth: u8 = if recurse { 3 } else { 1 };
// Check if this is a thread request
let is_thread = filter_rel_type
.as_ref()
.is_some_and(|rel| *rel == RelationType::Thread);
let events: Vec<_> = services let events: Vec<_> = services
.rooms .rooms
.pdu_metadata .pdu_metadata
@@ -157,24 +152,40 @@ async fn paginate_relations_with_filter(
}) })
.stream() .stream()
.ready_take_while(|(count, _)| Some(*count) != to) .ready_take_while(|(count, _)| Some(*count) != to)
.take(limit)
.wide_filter_map(|item| visibility_filter(services, sender_user, item)) .wide_filter_map(|item| visibility_filter(services, sender_user, item))
.then(async |mut pdu| { .take(limit)
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations to relation: {e}");
}
pdu
})
.collect() .collect()
.await; .await;
// For threads, check if we should include the root event
let mut root_event = None;
if is_thread && dir == Direction::Backward {
// Check if we've reached the beginning of the thread
// (fewer events than requested means we've exhausted the thread)
if events.len() < limit {
// Try to get the thread root event
if let Ok(root_pdu) = services.rooms.timeline.get_pdu(target).await {
// Check visibility
if services
.rooms
.state_accessor
.user_can_see_event(sender_user, room_id, target)
.await
{
// Store the root event to add to the response
root_event = Some(root_pdu);
}
}
}
}
// Determine if there are more events to fetch // Determine if there are more events to fetch
let has_more = events.len() >= limit; let has_more = if root_event.is_some() {
false // We've included the root, no more events
} else {
// Check if we got a full page of results (might be more)
events.len() >= limit
};
let next_batch = if has_more { let next_batch = if has_more {
match dir { match dir {
@@ -186,10 +197,11 @@ async fn paginate_relations_with_filter(
None None
}; };
let chunk: Vec<_> = events // Build the response chunk with thread root if needed
let chunk: Vec<_> = root_event
.into_iter() .into_iter()
.map(at!(1))
.map(Event::into_format) .map(Event::into_format)
.chain(events.into_iter().map(at!(1)).map(Event::into_format))
.collect(); .collect();
Ok(get_relating_events::v1::Response { Ok(get_relating_events::v1::Response {
+2 -11
View File
@@ -1,5 +1,5 @@
use axum::extract::State; use axum::extract::State;
use conduwuit::{Err, Event, Result, debug_warn, err}; use conduwuit::{Err, Event, Result, err};
use futures::{FutureExt, TryFutureExt, future::try_join}; use futures::{FutureExt, TryFutureExt, future::try_join};
use ruma::api::client::room::get_room_event; use ruma::api::client::room::get_room_event;
@@ -33,16 +33,7 @@ pub(crate) async fn get_room_event_route(
return Err!(Request(Forbidden("You don't have permission to view this event."))); return Err!(Request(Forbidden("You don't have permission to view this event.")));
} }
if let Err(e) = services event.add_age().ok();
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(body.sender_user(), &mut event)
.await
{
debug_warn!("Failed to add bundled aggregations to event: {e}");
}
event.set_unsigned(body.sender_user.as_deref());
Ok(get_room_event::v3::Response { event: event.into_format() }) Ok(get_room_event::v3::Response { event: event.into_format() })
} }
+2 -18
View File
@@ -1,6 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Err, Event, Result, at, debug_warn, Err, Event, Result, at,
utils::{BoolExt, stream::TryTools}, utils::{BoolExt, stream::TryTools},
}; };
use futures::{FutureExt, TryStreamExt, future::try_join4}; use futures::{FutureExt, TryStreamExt, future::try_join4};
@@ -40,28 +40,12 @@ pub(crate) async fn room_initial_sync_route(
.map_ok(Event::into_format) .map_ok(Event::into_format)
.try_collect::<Vec<_>>(); .try_collect::<Vec<_>>();
// Events are returned in body
let limit = LIMIT_MAX; let limit = LIMIT_MAX;
let events = services let events = services
.rooms .rooms
.timeline .timeline
.pdus_rev(room_id, None) .pdus_rev(None, room_id, None)
.try_take(limit) .try_take(limit)
.and_then(async |mut pdu| {
pdu.1.set_unsigned(body.sender_user.as_deref());
if let Some(sender_user) = body.sender_user.as_deref() {
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
}
Ok(pdu)
})
.try_collect::<Vec<_>>(); .try_collect::<Vec<_>>();
let (membership, visibility, state, events) = let (membership, visibility, state, events) =
+3 -19
View File
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Err, Result, at, debug_warn, is_true, Err, Result, at, is_true,
matrix::Event, matrix::Event,
result::FlatOk, result::FlatOk,
utils::{IterStream, stream::ReadyExt}, utils::{IterStream, stream::ReadyExt},
@@ -50,7 +50,7 @@ pub(crate) async fn search_events_route(
Ok(Response { Ok(Response {
search_categories: ResultCategories { search_categories: ResultCategories {
room_events: Box::pin(room_events_result) room_events: room_events_result
.await .await
.unwrap_or_else(|| Ok(ResultRoomEvents::default()))?, .unwrap_or_else(|| Ok(ResultRoomEvents::default()))?,
}, },
@@ -110,12 +110,7 @@ async fn category_room_events(
limit, limit,
}; };
let (count, results) = services let (count, results) = services.rooms.search.search_pdus(&query).await.ok()?;
.rooms
.search
.search_pdus(&query, sender_user)
.await
.ok()?;
results results
.collect::<Vec<_>>() .collect::<Vec<_>>()
@@ -149,17 +144,6 @@ async fn category_room_events(
.map(at!(2)) .map(at!(2))
.flatten() .flatten()
.stream() .stream()
.then(|mut pdu| async {
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu)
.await
{
debug_warn!("Failed to add bundled aggregations to search result: {e}");
}
pdu
})
.map(Event::into_format) .map(Event::into_format)
.map(|result| SearchResult { .map(|result| SearchResult {
rank: None, rank: None,
+1 -1
View File
@@ -158,7 +158,7 @@ pub(crate) async fn get_state_events_for_key_route(
"content": event.content(), "content": event.content(),
"event_id": event.event_id(), "event_id": event.event_id(),
"origin_server_ts": event.origin_server_ts(), "origin_server_ts": event.origin_server_ts(),
"room_id": event.room_id_or_hash(), "room_id": event.room_id(),
"sender": event.sender(), "sender": event.sender(),
"state_key": event.state_key(), "state_key": event.state_key(),
"type": event.kind(), "type": event.kind(),
+12 -34
View File
@@ -4,7 +4,7 @@ mod v5;
use std::collections::VecDeque; use std::collections::VecDeque;
use conduwuit::{ use conduwuit::{
Event, PduCount, Result, debug_warn, err, Event, PduCount, Result, err,
matrix::pdu::PduEvent, matrix::pdu::PduEvent,
ref_at, trace, ref_at, trace,
utils::stream::{BroadbandExt, ReadyExt, TryIgnore}, utils::stream::{BroadbandExt, ReadyExt, TryIgnore},
@@ -53,7 +53,7 @@ async fn load_timeline(
let last_timeline_count = services let last_timeline_count = services
.rooms .rooms
.timeline .timeline
.last_timeline_count(room_id) .last_timeline_count(Some(sender_user), room_id)
.await .await
.map_err(|err| { .map_err(|err| {
err!(Database(warn!("Failed to fetch end of room timeline: {}", err))) err!(Database(warn!("Failed to fetch end of room timeline: {}", err)))
@@ -71,24 +71,13 @@ async fn load_timeline(
services services
.rooms .rooms
.timeline .timeline
.pdus_rev(room_id, ending_count.map(|count| count.saturating_add(1))) .pdus_rev(
Some(sender_user),
room_id,
ending_count.map(|count| count.saturating_add(1)),
)
.ignore_err() .ignore_err()
.ready_take_while(move |&(pducount, _)| pducount > starting_count) .ready_take_while(move |&(pducount, _)| pducount > starting_count)
.map(move |mut pdu| {
pdu.1.set_unsigned(Some(sender_user));
pdu
})
.then(async move |mut pdu| {
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
pdu
})
.boxed() .boxed()
}, },
| None => { | None => {
@@ -97,23 +86,12 @@ async fn load_timeline(
services services
.rooms .rooms
.timeline .timeline
.pdus_rev(room_id, ending_count.map(|count| count.saturating_add(1))) .pdus_rev(
Some(sender_user),
room_id,
ending_count.map(|count| count.saturating_add(1)),
)
.ignore_err() .ignore_err()
.map(move |mut pdu| {
pdu.1.set_unsigned(Some(sender_user));
pdu
})
.then(async move |mut pdu| {
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu.1)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
pdu
})
.boxed() .boxed()
}, },
}; };
+1 -1
View File
@@ -127,7 +127,7 @@ pub(super) async fn build_state_incremental<'a>(
let last_pdu_of_last_sync = services let last_pdu_of_last_sync = services
.rooms .rooms
.timeline .timeline
.pdus_rev(room_id, Some(last_sync_end_count.saturating_add(1))) .pdus_rev(Some(sender_user), room_id, Some(last_sync_end_count.saturating_add(1)))
.boxed() .boxed()
.next() .next()
.await .await
+1 -12
View File
@@ -1,6 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Result, at, debug_warn, Result, at,
matrix::{ matrix::{
Event, Event,
pdu::{PduCount, PduEvent}, pdu::{PduCount, PduEvent},
@@ -45,17 +45,6 @@ pub(crate) async fn get_threads_route(
.await .await
.then_some((count, pdu)) .then_some((count, pdu))
}) })
.then(|(count, mut pdu)| async move {
if let Err(e) = services
.rooms
.pdu_metadata
.add_bundled_aggregations_to_pdu(body.sender_user(), &mut pdu)
.await
{
debug_warn!("Failed to add bundled aggregations to thread: {e}");
}
(count, pdu)
})
.collect() .collect()
.await; .await;
+1 -11
View File
@@ -3,7 +3,6 @@ use std::cmp;
use axum::extract::State; use axum::extract::State;
use conduwuit::{ use conduwuit::{
Event, PduCount, Result, Event, PduCount, Result,
result::LogErr,
utils::{IterStream, ReadyExt, stream::TryTools}, utils::{IterStream, ReadyExt, stream::TryTools},
}; };
use futures::{FutureExt, StreamExt, TryStreamExt}; use futures::{FutureExt, StreamExt, TryStreamExt};
@@ -63,7 +62,7 @@ pub(crate) async fn get_backfill_route(
pdus: services pdus: services
.rooms .rooms
.timeline .timeline
.pdus_rev(&body.room_id, Some(from.saturating_add(1))) .pdus_rev(None, &body.room_id, Some(from.saturating_add(1)))
.try_take(limit) .try_take(limit)
.try_filter_map(|(_, pdu)| async move { .try_filter_map(|(_, pdu)| async move {
Ok(services Ok(services
@@ -73,15 +72,6 @@ pub(crate) async fn get_backfill_route(
.await .await
.then_some(pdu)) .then_some(pdu))
}) })
.and_then(async |mut pdu| {
// Strip the transaction ID, as that is private
pdu.remove_transaction_id().log_err().ok();
// Add age, as this is specified
pdu.add_age().log_err().ok();
// It's not clear if we should strip or add any more data, leave as is.
// In particular: Redaction?
Ok(pdu)
})
.try_filter_map(|pdu| async move { .try_filter_map(|pdu| async move {
Ok(services Ok(services
.rooms .rooms
-41
View File
@@ -10,7 +10,6 @@ use conduwuit::{
use ruma::{ use ruma::{
CanonicalJsonValue, OwnedUserId, UserId, CanonicalJsonValue, OwnedUserId, UserId,
api::{client::error::ErrorKind, federation::membership::create_invite}, api::{client::error::ErrorKind, federation::membership::create_invite},
events::room::member::{MembershipState, RoomMemberEventContent},
serde::JsonObject, serde::JsonObject,
}; };
@@ -61,46 +60,6 @@ pub(crate) async fn create_invite_route(
let mut signed_event = utils::to_canonical_object(&body.event) let mut signed_event = utils::to_canonical_object(&body.event)
.map_err(|_| err!(Request(InvalidParam("Invite event is invalid."))))?; .map_err(|_| err!(Request(InvalidParam("Invite event is invalid."))))?;
// Ensure this is a membership event
if signed_event
.get("type")
.expect("event must have a type")
.as_str()
.expect("type must be a string")
!= "m.room.member"
{
return Err!(Request(BadJson(
"Not allowed to send non-membership event to invite endpoint."
)));
}
let content: RoomMemberEventContent = serde_json::from_value(
signed_event
.get("content")
.ok_or_else(|| err!(Request(BadJson("Event missing content property"))))?
.clone()
.into(),
)
.map_err(|e| err!(Request(BadJson(warn!("Event content is empty or invalid: {e}")))))?;
// Ensure this is an invite membership event
if content.membership != MembershipState::Invite {
return Err!(Request(BadJson(
"Not allowed to send a non-invite membership event to invite endpoint."
)));
}
// Ensure the sending user isn't a lying bozo
let sender_server = signed_event
.get("sender")
.try_into()
.map(UserId::server_name)
.map_err(|e| err!(Request(InvalidParam("Invalid sender property: {e}"))))?;
if sender_server != body.origin() {
return Err!(Request(Forbidden("Sender's server does not match the origin server.",)));
}
// Ensure the target user belongs to this server
let recipient_user: OwnedUserId = signed_event let recipient_user: OwnedUserId = signed_event
.get("state_key") .get("state_key")
.try_into() .try_into()
+11 -22
View File
@@ -6,7 +6,7 @@ pub mod proxy;
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::PathBuf, path::{Path, PathBuf},
}; };
use conduwuit_macros::config_example_generator; use conduwuit_macros::config_example_generator;
@@ -53,13 +53,9 @@ use crate::{Result, err, error::Error, utils::sys};
### For more information, see: ### For more information, see:
### https://continuwuity.org/configuration.html ### https://continuwuity.org/configuration.html
"#, "#,
ignore = "config_paths catchall well_known tls blurhashing allow_invalid_tls_certificates_yes_i_know_what_the_fuck_i_am_doing_with_this_and_i_know_this_is_insecure" ignore = "catchall well_known tls blurhashing allow_invalid_tls_certificates_yes_i_know_what_the_fuck_i_am_doing_with_this_and_i_know_this_is_insecure"
)] )]
pub struct Config { pub struct Config {
// Paths to config file(s). Not supposed to be set manually in the config file,
// only updated dynamically from the --config option given at runtime.
pub config_paths: Option<Vec<PathBuf>>,
/// The server_name is the pretty name of this server. It is used as a /// The server_name is the pretty name of this server. It is used as a
/// suffix for user and room IDs/aliases. /// suffix for user and room IDs/aliases.
/// ///
@@ -707,13 +703,10 @@ pub struct Config {
pub allow_unstable_room_versions: bool, pub allow_unstable_room_versions: bool,
/// Default room version continuwuity will create rooms with. /// Default room version continuwuity will create rooms with.
/// Note that this has to be a string since the room version is a string
/// rather than an integer. Forgetting the quotes will make the server fail
/// to start!
/// ///
/// Per spec, room version "11" is the default. /// Per spec, room version 11 is the default.
/// ///
/// default: "11" /// default: 11
#[serde(default = "default_default_room_version")] #[serde(default = "default_default_room_version")]
pub default_room_version: RoomVersionId, pub default_room_version: RoomVersionId,
@@ -1231,12 +1224,6 @@ pub struct Config {
#[serde(default)] #[serde(default)]
pub rocksdb_repair: bool, pub rocksdb_repair: bool,
#[serde(default)]
pub rocksdb_read_only: bool,
#[serde(default)]
pub rocksdb_secondary: bool,
/// Enables idle CPU priority for compaction thread. This is not enabled by /// Enables idle CPU priority for compaction thread. This is not enabled by
/// default to prevent compaction from falling too far behind on busy /// default to prevent compaction from falling too far behind on busy
/// systems. /// systems.
@@ -2230,24 +2217,26 @@ const DEPRECATED_KEYS: &[&str; 9] = &[
impl Config { impl Config {
/// Pre-initialize config /// Pre-initialize config
pub fn load(paths: &[PathBuf]) -> Result<Figment> { pub fn load<'a, I>(paths: I) -> Result<Figment>
where
I: Iterator<Item = &'a Path>,
{
let envs = [ let envs = [
Env::var("CONDUIT_CONFIG"), Env::var("CONDUIT_CONFIG"),
Env::var("CONDUWUIT_CONFIG"), Env::var("CONDUWUIT_CONFIG"),
Env::var("CONTINUWUITY_CONFIG"), Env::var("CONTINUWUITY_CONFIG"),
]; ];
let mut config = envs
let config = envs
.into_iter() .into_iter()
.flatten() .flatten()
.map(Toml::file) .map(Toml::file)
.chain(paths.iter().cloned().map(Toml::file)) .chain(paths.map(Toml::file))
.fold(Figment::new(), |config, file| config.merge(file.nested())) .fold(Figment::new(), |config, file| config.merge(file.nested()))
.merge(Env::prefixed("CONDUIT_").global().split("__")) .merge(Env::prefixed("CONDUIT_").global().split("__"))
.merge(Env::prefixed("CONDUWUIT_").global().split("__")) .merge(Env::prefixed("CONDUWUIT_").global().split("__"))
.merge(Env::prefixed("CONTINUWUITY_").global().split("__")); .merge(Env::prefixed("CONTINUWUITY_").global().split("__"));
config = config.join(("config_paths", paths));
Ok(config) Ok(config)
} }
+2 -2
View File
@@ -56,7 +56,7 @@ impl<'a, E: Event> From<Ref<'a, E>> for Raw<AnyTimelineEvent> {
"content": content, "content": content,
"event_id": event.event_id(), "event_id": event.event_id(),
"origin_server_ts": event.origin_server_ts(), "origin_server_ts": event.origin_server_ts(),
"room_id": event.room_id_or_hash(), "room_id": event.room_id(),
"sender": event.sender(), "sender": event.sender(),
"type": event.kind(), "type": event.kind(),
}); });
@@ -117,7 +117,7 @@ impl<'a, E: Event> From<Ref<'a, E>> for Raw<AnyStateEvent> {
"content": event.content(), "content": event.content(),
"event_id": event.event_id(), "event_id": event.event_id(),
"origin_server_ts": event.origin_server_ts(), "origin_server_ts": event.origin_server_ts(),
"room_id": event.room_id_or_hash(), "room_id": event.room_id(),
"sender": event.sender(), "sender": event.sender(),
"state_key": event.state_key(), "state_key": event.state_key(),
"type": event.kind(), "type": event.kind(),
+2 -15
View File
@@ -1,23 +1,10 @@
use std::{borrow::Borrow, collections::BTreeMap}; use std::collections::BTreeMap;
use ruma::MilliSecondsSinceUnixEpoch; use ruma::MilliSecondsSinceUnixEpoch;
use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue, to_raw_value}; use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue, to_raw_value};
use super::Pdu; use super::Pdu;
use crate::{Result, err, implement, result::LogErr}; use crate::{Result, err, implement};
/// Set the `unsigned` field of the PDU using only information in the PDU.
/// Some unsigned data is already set within the database (eg. prev events,
/// threads). Once this is done, other data must be calculated from the database
/// (eg. relations) This is for server-to-client events.
/// Backfill handles this itself.
#[implement(Pdu)]
pub fn set_unsigned(&mut self, user_id: Option<&ruma::UserId>) {
if Some(self.sender.borrow()) != user_id {
self.remove_transaction_id().log_err().ok();
}
self.add_age().log_err().ok();
}
#[implement(Pdu)] #[implement(Pdu)]
pub fn remove_transaction_id(&mut self) -> Result { pub fn remove_transaction_id(&mut self) -> Result {
+2 -12
View File
@@ -19,7 +19,7 @@ use std::{
use conduwuit::{Err, Result, debug, info, warn}; use conduwuit::{Err, Result, debug, info, warn};
use rocksdb::{ use rocksdb::{
AsColumnFamilyRef, BoundColumnFamily, DBCommon, DBWithThreadMode, MultiThreaded, AsColumnFamilyRef, BoundColumnFamily, DBCommon, MultiThreaded, OptimisticTransactionDB,
WaitForCompactOptions, WaitForCompactOptions,
}; };
@@ -33,13 +33,11 @@ pub struct Engine {
pub(crate) db: Db, pub(crate) db: Db,
pub(crate) pool: Arc<Pool>, pub(crate) pool: Arc<Pool>,
pub(crate) ctx: Arc<Context>, pub(crate) ctx: Arc<Context>,
pub(super) read_only: bool,
pub(super) secondary: bool,
pub(crate) checksums: bool, pub(crate) checksums: bool,
corks: AtomicU32, corks: AtomicU32,
} }
pub(crate) type Db = DBWithThreadMode<MultiThreaded>; pub(crate) type Db = OptimisticTransactionDB<MultiThreaded>;
impl Engine { impl Engine {
#[tracing::instrument( #[tracing::instrument(
@@ -129,14 +127,6 @@ impl Engine {
sequence sequence
} }
#[inline]
#[must_use]
pub fn is_read_only(&self) -> bool { self.secondary || self.read_only }
#[inline]
#[must_use]
pub fn is_secondary(&self) -> bool { self.secondary }
} }
impl Drop for Engine { impl Drop for Engine {
+1 -2
View File
@@ -12,9 +12,8 @@ pub fn backup(&self) -> Result {
let mut engine = self.backup_engine()?; let mut engine = self.backup_engine()?;
let config = &self.ctx.server.config; let config = &self.ctx.server.config;
if config.database_backups_to_keep > 0 { if config.database_backups_to_keep > 0 {
let flush = !self.is_read_only();
engine engine
.create_new_backup_flush(&self.db, flush) .create_new_backup_flush(&self.db, true)
.map_err(map_err)?; .map_err(map_err)?;
let engine_info = engine.get_backup_info(); let engine_info = engine.get_backup_info();
+11 -6
View File
@@ -1,7 +1,7 @@
use std::fmt::Write; use std::fmt::Write;
use conduwuit::{Result, implement}; use conduwuit::{Result, implement};
use rocksdb::perf::get_memory_usage_stats; use rocksdb::perf::MemoryUsageBuilder;
use super::Engine; use super::Engine;
use crate::or_else; use crate::or_else;
@@ -9,16 +9,21 @@ use crate::or_else;
#[implement(Engine)] #[implement(Engine)]
pub fn memory_usage(&self) -> Result<String> { pub fn memory_usage(&self) -> Result<String> {
let mut res = String::new(); let mut res = String::new();
let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()]))
.or_else(or_else)?; let mut builder = MemoryUsageBuilder::new().or_else(or_else)?;
builder.add_db(&self.db);
builder.add_cache(&self.ctx.row_cache.lock());
let usage = builder.build().or_else(or_else)?;
let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0; let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0;
writeln!( writeln!(
res, res,
"Memory buffers: {:.2} MiB\nPending write: {:.2} MiB\nTable readers: {:.2} MiB\nRow \ "Memory buffers: {:.2} MiB\nPending write: {:.2} MiB\nTable readers: {:.2} MiB\nRow \
cache: {:.2} MiB", cache: {:.2} MiB",
mibs(stats.mem_table_total), mibs(usage.approximate_mem_table_total()),
mibs(stats.mem_table_unflushed), mibs(usage.approximate_mem_table_unflushed()),
mibs(stats.mem_table_readers_total), mibs(usage.approximate_mem_table_readers_total()),
mibs(u64::try_from(self.ctx.row_cache.lock().get_usage())?), mibs(u64::try_from(self.ctx.row_cache.lock().get_usage())?),
)?; )?;
+1 -10
View File
@@ -35,14 +35,7 @@ pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<S
} }
debug!("Opening database..."); debug!("Opening database...");
let db = if config.rocksdb_read_only { let db = Db::open_cf_descriptors(&db_opts, path, cfds).or_else(or_else)?;
Db::open_cf_descriptors_read_only(&db_opts, path, cfds, false)
} else if config.rocksdb_secondary {
Db::open_cf_descriptors_as_secondary(&db_opts, path, path, cfds)
} else {
Db::open_cf_descriptors(&db_opts, path, cfds)
}
.or_else(or_else)?;
info!( info!(
columns = num_cfds, columns = num_cfds,
@@ -55,8 +48,6 @@ pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<S
db, db,
pool: ctx.pool.clone(), pool: ctx.pool.clone(),
ctx: ctx.clone(), ctx: ctx.clone(),
read_only: config.rocksdb_read_only,
secondary: config.rocksdb_secondary,
checksums: config.rocksdb_checksums, checksums: config.rocksdb_checksums,
corks: AtomicU32::new(0), corks: AtomicU32::new(0),
})) }))
+1 -1
View File
@@ -219,7 +219,7 @@ where
K: AsRef<[u8]> + Sized + Debug + 'a, K: AsRef<[u8]> + Sized + Debug + 'a,
V: AsRef<[u8]> + Sized + 'a, V: AsRef<[u8]> + Sized + 'a,
{ {
let mut batch = WriteBatchWithTransaction::<false>::default(); let mut batch = WriteBatchWithTransaction::<true>::default();
for (key, val) in iter { for (key, val) in iter {
batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); batch.put_cf(&self.cf(), key.as_ref(), val.as_ref());
} }
-8
View File
@@ -77,14 +77,6 @@ impl Database {
#[inline] #[inline]
pub fn keys(&self) -> impl Iterator<Item = &MapsKey> + Send + '_ { self.maps.keys() } pub fn keys(&self) -> impl Iterator<Item = &MapsKey> + Send + '_ { self.maps.keys() }
#[inline]
#[must_use]
pub fn is_read_only(&self) -> bool { self.db.is_read_only() }
#[inline]
#[must_use]
pub fn is_secondary(&self) -> bool { self.db.is_secondary() }
} }
impl Index<&str> for Database { impl Index<&str> for Database {
+8 -3
View File
@@ -1,4 +1,4 @@
use std::sync::Arc; use std::{path::PathBuf, sync::Arc};
use conduwuit_core::{ use conduwuit_core::{
Error, Result, Error, Result,
@@ -38,9 +38,14 @@ impl Server {
) -> Result<Arc<Self>, Error> { ) -> Result<Arc<Self>, Error> {
let _runtime_guard = runtime.map(runtime::Handle::enter); let _runtime_guard = runtime.map(runtime::Handle::enter);
let config_paths = args.config.clone().unwrap_or_default(); let config_paths = args
.config
.as_deref()
.into_iter()
.flat_map(<[_]>::iter)
.map(PathBuf::as_path);
let config = Config::load(&config_paths) let config = Config::load(config_paths)
.and_then(|raw| update(raw, args)) .and_then(|raw| update(raw, args))
.and_then(|raw| Config::new(&raw))?; .and_then(|raw| Config::new(&raw))?;
+6 -4
View File
@@ -1,4 +1,4 @@
use std::{ops::Deref, path::PathBuf, sync::Arc}; use std::{iter, ops::Deref, path::Path, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{ use conduwuit::{
@@ -51,8 +51,7 @@ fn handle_reload(&self) -> Result {
]) ])
.expect("failed to notify systemd of reloading state"); .expect("failed to notify systemd of reloading state");
let config_paths = self.server.config.config_paths.clone().unwrap_or_default(); self.reload(iter::empty())?;
self.reload(&config_paths)?;
#[cfg(all(feature = "systemd", target_os = "linux"))] #[cfg(all(feature = "systemd", target_os = "linux"))]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready]) sd_notify::notify(false, &[sd_notify::NotifyState::Ready])
@@ -63,7 +62,10 @@ fn handle_reload(&self) -> Result {
} }
#[implement(Service)] #[implement(Service)]
pub fn reload(&self, paths: &[PathBuf]) -> Result<Arc<Config>> { pub fn reload<'a, I>(&self, paths: I) -> Result<Arc<Config>>
where
I: Iterator<Item = &'a Path>,
{
let old = self.server.config.clone(); let old = self.server.config.clone();
let new = Config::load(paths).and_then(|raw| Config::new(&raw))?; let new = Config::load(paths).and_then(|raw| Config::new(&raw))?;
-4
View File
@@ -37,10 +37,6 @@ impl crate::Service for Service {
} }
async fn worker(self: Arc<Self>) -> Result { async fn worker(self: Arc<Self>) -> Result {
if self.services.globals.is_read_only() {
return Ok(());
}
if self.services.config.ldap.enable { if self.services.config.ldap.enable {
warn!("emergency password feature not available with LDAP enabled."); warn!("emergency password feature not available with LDAP enabled.");
return Ok(()); return Ok(());
-3
View File
@@ -171,7 +171,4 @@ impl Service {
pub fn server_is_ours(&self, server_name: &ServerName) -> bool { pub fn server_is_ours(&self, server_name: &ServerName) -> bool {
server_name == self.server_name() server_name == self.server_name()
} }
#[inline]
pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() }
} }
+20 -175
View File
@@ -3,21 +3,14 @@
//! This module implements a check against a room-specific policy server, as //! This module implements a check against a room-specific policy server, as
//! described in the relevant Matrix spec proposal (see: https://github.com/matrix-org/matrix-spec-proposals/pull/4284). //! described in the relevant Matrix spec proposal (see: https://github.com/matrix-org/matrix-spec-proposals/pull/4284).
use std::{collections::BTreeMap, time::Duration}; use std::time::Duration;
use conduwuit::{ use conduwuit::{Err, Event, PduEvent, Result, debug, debug_info, implement, trace, warn};
Err, Event, PduEvent, Result, debug, debug_error, debug_info, debug_warn, implement, trace,
warn,
};
use ruma::{ use ruma::{
CanonicalJsonObject, CanonicalJsonValue, KeyId, RoomId, ServerName, SigningKeyId, CanonicalJsonObject, RoomId, ServerName,
api::federation::room::{ api::federation::room::policy::v1::Request as PolicyRequest,
policy_check::unstable::Request as PolicyCheckRequest,
policy_sign::unstable::Request as PolicySignRequest,
},
events::{StateEventType, room::policy::RoomPolicyEventContent}, events::{StateEventType, room::policy::RoomPolicyEventContent},
}; };
use serde_json::value::RawValue;
/// Asks a remote policy server if the event is allowed. /// Asks a remote policy server if the event is allowed.
/// ///
@@ -31,18 +24,25 @@ use serde_json::value::RawValue;
/// contacted for whatever reason, Err(e) is returned, which generally is a /// contacted for whatever reason, Err(e) is returned, which generally is a
/// fail-open operation. /// fail-open operation.
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(skip(self, pdu, pdu_json, room_id))] #[tracing::instrument(skip_all, level = "debug")]
pub async fn ask_policy_server( pub async fn ask_policy_server(
&self, &self,
pdu: &PduEvent, pdu: &PduEvent,
pdu_json: &mut CanonicalJsonObject, pdu_json: &CanonicalJsonObject,
room_id: &RoomId, room_id: &RoomId,
incoming: bool,
) -> Result<bool> { ) -> Result<bool> {
if !self.services.server.config.enable_msc4284_policy_servers { if !self.services.server.config.enable_msc4284_policy_servers {
trace!("policy server checking is disabled");
return Ok(true); // don't ever contact policy servers return Ok(true); // don't ever contact policy servers
} }
if self.services.server.config.policy_server_check_own_events
&& pdu.origin.is_some()
&& self
.services
.server
.is_ours(pdu.origin.as_ref().unwrap().as_str())
{
return Ok(true); // don't contact policy servers for locally generated events
}
if *pdu.event_type() == StateEventType::RoomPolicy.into() { if *pdu.event_type() == StateEventType::RoomPolicy.into() {
debug!( debug!(
@@ -52,29 +52,16 @@ pub async fn ask_policy_server(
); );
return Ok(true); return Ok(true);
} }
let Ok(policyserver) = self let Ok(policyserver) = self
.services .services
.state_accessor .state_accessor
.room_state_get_content(room_id, &StateEventType::RoomPolicy, "") .room_state_get_content(room_id, &StateEventType::RoomPolicy, "")
.await .await
.inspect_err(|e| debug_error!("failed to load room policy server state event: {e}"))
.map(|c: RoomPolicyEventContent| c) .map(|c: RoomPolicyEventContent| c)
else { else {
debug!("room has no policy server configured");
return Ok(true); return Ok(true);
}; };
if self.services.server.config.policy_server_check_own_events
&& !incoming
&& policyserver.public_key.is_none()
{
// don't contact policy servers for locally generated events, but only when the
// policy server does not require signatures
trace!("won't contact policy server for locally generated event");
return Ok(true);
}
let via = match policyserver.via { let via = match policyserver.via {
| Some(ref via) => ServerName::parse(via)?, | Some(ref via) => ServerName::parse(via)?,
| None => { | None => {
@@ -88,6 +75,7 @@ pub async fn ask_policy_server(
} }
if !self.services.state_cache.server_in_room(via, room_id).await { if !self.services.state_cache.server_in_room(via, room_id).await {
debug!( debug!(
room_id = %room_id,
via = %via, via = %via,
"Policy server is not in the room, skipping spam check" "Policy server is not in the room, skipping spam check"
); );
@@ -98,43 +86,17 @@ pub async fn ask_policy_server(
.sending .sending
.convert_to_outgoing_federation_event(pdu_json.clone()) .convert_to_outgoing_federation_event(pdu_json.clone())
.await; .await;
if policyserver.public_key.is_some() {
if !incoming {
debug_info!(
via = %via,
outgoing = ?pdu_json,
"Getting policy server signature on event"
);
return self
.fetch_policy_server_signature(pdu, pdu_json, via, outgoing, room_id)
.await;
}
// for incoming events, is it signed by <via> with the key
// "ed25519:policy_server"?
if let Some(CanonicalJsonValue::Object(sigs)) = pdu_json.get("signatures") {
if let Some(CanonicalJsonValue::Object(server_sigs)) = sigs.get(via.as_str()) {
let wanted_key_id: &KeyId<ruma::SigningKeyAlgorithm, ruma::Base64PublicKey> =
SigningKeyId::parse("ed25519:policy_server")?;
if let Some(CanonicalJsonValue::String(_sig_value)) =
server_sigs.get(wanted_key_id.as_str())
{
// TODO: verify signature
}
}
}
debug!(
"Event is not local and has no policy server signature, performing legacy spam check"
);
}
debug_info!( debug_info!(
room_id = %room_id,
via = %via, via = %via,
"Checking event for spam with policy server via legacy check" outgoing = ?pdu_json,
"Checking event for spam with policy server"
); );
let response = tokio::time::timeout( let response = tokio::time::timeout(
Duration::from_secs(self.services.server.config.policy_server_request_timeout), Duration::from_secs(self.services.server.config.policy_server_request_timeout),
self.services self.services
.sending .sending
.send_federation_request(via, PolicyCheckRequest { .send_federation_request(via, PolicyRequest {
event_id: pdu.event_id().to_owned(), event_id: pdu.event_id().to_owned(),
pdu: Some(outgoing), pdu: Some(outgoing),
}), }),
@@ -180,120 +142,3 @@ pub async fn ask_policy_server(
Ok(true) Ok(true)
} }
/// Asks a remote policy server for a signature on this event.
/// If the policy server signs this event, the original data is mutated.
#[implement(super::Service)]
#[tracing::instrument(skip_all, fields(event_id=%pdu.event_id(), via=%via))]
pub async fn fetch_policy_server_signature(
&self,
pdu: &PduEvent,
pdu_json: &mut CanonicalJsonObject,
via: &ServerName,
outgoing: Box<RawValue>,
room_id: &RoomId,
) -> Result<bool> {
debug!("Requesting policy server signature");
let response = tokio::time::timeout(
Duration::from_secs(self.services.server.config.policy_server_request_timeout),
self.services
.sending
.send_federation_request(via, PolicySignRequest { pdu: outgoing }),
)
.await;
let response = match response {
| Ok(Ok(response)) => {
debug!("Response from policy server: {:?}", response);
response
},
| Ok(Err(e)) => {
warn!(
via = %via,
event_id = %pdu.event_id(),
room_id = %room_id,
"Failed to contact policy server: {e}"
);
// Network or policy server errors are treated as non-fatal: event is allowed by
// default.
return Err(e);
},
| Err(elapsed) => {
warn!(
%via,
event_id = %pdu.event_id(),
%room_id,
%elapsed,
"Policy server request timed out after 10 seconds"
);
return Err!("Request to policy server timed out");
},
};
if response.signatures.is_none() {
debug!("Policy server refused to sign event");
return Ok(false);
}
let sigs: ruma::Signatures<ruma::OwnedServerName, ruma::ServerSigningKeyVersion> =
response.signatures.unwrap();
if !sigs.contains_key(via) {
debug_warn!(
"Policy server returned signatures, but did not include the expected server name \
'{}': {:?}",
via,
sigs
);
return Ok(false);
}
let keypairs = sigs.get(via).unwrap();
let wanted_key_id = KeyId::parse("ed25519:policy_server")?;
if !keypairs.contains_key(wanted_key_id) {
debug_warn!(
"Policy server returned signature, but did not use the key ID \
'ed25519:policy_server'."
);
return Ok(false);
}
let signatures_entry = pdu_json
.entry("signatures".to_owned())
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()));
if let CanonicalJsonValue::Object(signatures_map) = signatures_entry {
let sig_value = keypairs.get(wanted_key_id).unwrap().to_owned();
match signatures_map.get_mut(via.as_str()) {
| Some(CanonicalJsonValue::Object(inner_map)) => {
trace!("inserting PS signature: {}", sig_value);
inner_map.insert(
"ed25519:policy_server".to_owned(),
CanonicalJsonValue::String(sig_value),
);
},
| Some(_) => {
debug_warn!(
"Existing `signatures[{}]` field is not an object; cannot insert policy \
signature",
via
);
return Ok(false);
},
| None => {
let mut inner = BTreeMap::new();
inner.insert(
"ed25519:policy_server".to_owned(),
CanonicalJsonValue::String(sig_value.clone()),
);
trace!(
"created new signatures object for {via} with the signature {}",
sig_value
);
signatures_map.insert(via.as_str().to_owned(), CanonicalJsonValue::Object(inner));
},
}
} else {
debug_warn!(
"Existing `signatures` field is not an object; cannot insert policy signature"
);
return Ok(false);
}
Ok(true)
}
@@ -256,12 +256,7 @@ where
if incoming_pdu.state_key.is_none() { if incoming_pdu.state_key.is_none() {
debug!(event_id = %incoming_pdu.event_id, "Checking policy server for event"); debug!(event_id = %incoming_pdu.event_id, "Checking policy server for event");
match self match self
.ask_policy_server( .ask_policy_server(&incoming_pdu, &incoming_pdu.to_canonical_object(), room_id)
&incoming_pdu,
&mut incoming_pdu.to_canonical_object(),
room_id,
true,
)
.await .await
{ {
| Ok(false) => { | Ok(false) => {
@@ -1,746 +0,0 @@
use conduwuit::{Event, PduEvent, Result, err};
use ruma::{
UserId,
api::Direction,
events::relation::{BundledMessageLikeRelations, BundledReference, ReferenceChunk},
};
use crate::rooms::timeline::PdusIterItem;
const MAX_BUNDLED_RELATIONS: usize = 50;
impl super::Service {
/// Gets bundled aggregations for an event according to the Matrix
/// specification.
/// - m.replace relations are bundled to include the most recent replacement
/// event.
/// - m.reference relations are bundled to include a chunk of event IDs.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn get_bundled_aggregations(
&self,
user_id: &UserId,
pdu: &PduEvent,
) -> Result<Option<BundledMessageLikeRelations<Box<serde_json::value::RawValue>>>> {
// Events that can never get bundled aggregations
if pdu.state_key().is_some() || Self::is_replacement_event(pdu) {
return Ok(None);
}
let relations = self
.get_relations(
user_id,
&pdu.room_id_or_hash(),
pdu.event_id(),
conduwuit::PduCount::max(),
MAX_BUNDLED_RELATIONS,
0,
Direction::Backward,
)
.await;
// The relations database code still handles the basic unsigned data
// We don't want to recursively fetch relations
if relations.is_empty() {
return Ok(None);
}
// Partition relations by type
let (replace_events, reference_events): (Vec<_>, Vec<_>) = relations
.iter()
.filter_map(|relation| {
let pdu = &relation.1;
let content = pdu.get_content_as_value();
content
.get("m.relates_to")
.and_then(|relates_to| relates_to.get("rel_type"))
.and_then(|rel_type| rel_type.as_str())
.and_then(|rel_type_str| match rel_type_str {
| "m.replace" => Some(RelationType::Replace(relation)),
| "m.reference" => Some(RelationType::Reference(relation)),
| _ => None, /* Ignore other relation types (threads are in DB but not
* handled here) */
})
})
.fold((Vec::new(), Vec::new()), |(mut replaces, mut references), rel_type| {
match rel_type {
| RelationType::Replace(r) => replaces.push(r),
| RelationType::Reference(r) => references.push(r),
}
(replaces, references)
});
// If no relations to bundle, return None
if replace_events.is_empty() && reference_events.is_empty() {
return Ok(None);
}
let mut bundled = BundledMessageLikeRelations::<Box<serde_json::value::RawValue>>::new();
// Handle m.replace relations - find the most recent valid one (lazy load
// original event)
if !replace_events.is_empty() {
if let Some(replacement) = self
.find_most_recent_valid_replacement(user_id, pdu, &replace_events)
.await?
{
bundled.replace = Some(Self::serialize_replacement(replacement)?);
}
}
// Handle m.reference relations - collect event IDs
if !reference_events.is_empty() {
let reference_chunk: Vec<_> = reference_events
.into_iter()
.map(|relation| BundledReference::new(relation.1.event_id().to_owned()))
.collect();
if !reference_chunk.is_empty() {
bundled.reference = Some(Box::new(ReferenceChunk::new(reference_chunk)));
}
}
// TODO: Handle other relation types (m.annotation, etc.) when specified
Ok(Some(bundled))
}
/// Serialize a replacement event to the bundled format
fn serialize_replacement(pdu: &PduEvent) -> Result<Box<Box<serde_json::value::RawValue>>> {
let replacement_json = serde_json::to_string(pdu)
.map_err(|e| err!(Database("Failed to serialize replacement event: {e}")))?;
let raw_value = serde_json::value::RawValue::from_string(replacement_json)
.map_err(|e| err!(Database("Failed to create RawValue: {e}")))?;
Ok(Box::new(raw_value))
}
/// Find the most recent valid replacement event based on origin_server_ts
/// and lexicographic event_id ordering
async fn find_most_recent_valid_replacement<'a>(
&self,
user_id: &UserId,
original_event: &PduEvent,
replacement_events: &[&'a PdusIterItem],
) -> Result<Option<&'a PduEvent>> {
// Filter valid replacements and find the maximum in a single pass
let mut result: Option<&PduEvent> = None;
for relation in replacement_events {
let pdu = &relation.1;
// Validate replacement
if !Self::is_valid_replacement_event(original_event, pdu).await? {
continue;
}
let next = match result {
| None => Some(pdu),
| Some(current) => {
// Compare by origin_server_ts first, then event_id lexicographically
match pdu.origin_server_ts().cmp(&current.origin_server_ts()) {
| std::cmp::Ordering::Greater => Some(pdu),
| std::cmp::Ordering::Equal if pdu.event_id() > current.event_id() =>
Some(pdu),
| _ => None,
}
},
};
if let Some(pdu) = next
&& self
.services
.state_accessor
.user_can_see_event(user_id, &pdu.room_id_or_hash(), pdu.event_id())
.await
{
result = Some(pdu);
}
}
Ok(result)
}
/// Adds bundled aggregations to a PDU's unsigned field
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub async fn add_bundled_aggregations_to_pdu(
&self,
user_id: &UserId,
pdu: &mut PduEvent,
) -> Result<()> {
if pdu.is_redacted() {
return Ok(());
}
let bundled_aggregations = self.get_bundled_aggregations(user_id, pdu).await?;
if let Some(aggregations) = bundled_aggregations {
let aggregations_json = serde_json::to_value(aggregations)
.map_err(|e| err!(Database("Failed to serialize bundled aggregations: {e}")))?;
Self::add_bundled_aggregations_to_unsigned(pdu, aggregations_json)?;
}
Ok(())
}
/// Helper method to add bundled aggregations to a PDU's unsigned field
fn add_bundled_aggregations_to_unsigned(
pdu: &mut PduEvent,
aggregations_json: serde_json::Value,
) -> Result<()> {
use serde_json::{
Map, Value as JsonValue,
value::{RawValue as RawJsonValue, to_raw_value},
};
let mut unsigned: Map<String, JsonValue> = pdu
.unsigned
.as_deref()
.map(RawJsonValue::get)
.map_or_else(|| Ok(Map::new()), serde_json::from_str)
.map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?;
let relations = unsigned
.entry("m.relations")
.or_insert_with(|| JsonValue::Object(Map::new()))
.as_object_mut()
.ok_or_else(|| err!(Database("m.relations is not an object")))?;
if let JsonValue::Object(aggregations_map) = aggregations_json {
relations.extend(aggregations_map);
}
pdu.unsigned = Some(to_raw_value(&unsigned)?);
Ok(())
}
/// Validates that an event is acceptable as a replacement for another event
/// See C/S spec "Validity of replacement events"
#[tracing::instrument(level = "debug")]
async fn is_valid_replacement_event(
original_event: &PduEvent,
replacement_event: &PduEvent,
) -> Result<bool> {
Ok(
// 1. Same room_id
original_event.room_id() == replacement_event.room_id()
// 2. Same sender
&& original_event.sender() == replacement_event.sender()
// 3. Same type
&& original_event.event_type() == replacement_event.event_type()
// 4. Neither event should have a state_key property
&& original_event.state_key().is_none()
&& replacement_event.state_key().is_none()
// 5. Original event must not have rel_type of m.replace
&& !Self::is_replacement_event(original_event)
// 6. Replacement event must have m.new_content property (skip for encrypted)
&& Self::has_new_content_or_encrypted(replacement_event),
)
}
/// Check if an event is itself a replacement
#[inline]
fn is_replacement_event(event: &PduEvent) -> bool {
event
.get_content_as_value()
.get("m.relates_to")
.and_then(|relates_to| relates_to.get("rel_type"))
.and_then(|rel_type| rel_type.as_str())
.is_some_and(|rel_type| rel_type == "m.replace")
}
/// Check if event has m.new_content or is encrypted (where m.new_content
/// would be in the encrypted payload)
#[inline]
fn has_new_content_or_encrypted(event: &PduEvent) -> bool {
event.event_type() == &ruma::events::TimelineEventType::RoomEncrypted
|| event.get_content_as_value().get("m.new_content").is_some()
}
}
/// Helper enum for partitioning relations
enum RelationType<'a> {
Replace(&'a PdusIterItem),
Reference(&'a PdusIterItem),
}
#[cfg(test)]
mod tests {
use conduwuit_core::pdu::{EventHash, PduEvent};
use ruma::{UInt, events::TimelineEventType, owned_event_id, owned_room_id, owned_user_id};
use serde_json::{Value as JsonValue, json, value::to_raw_value};
fn create_test_pdu(unsigned_content: Option<JsonValue>) -> PduEvent {
PduEvent {
event_id: owned_event_id!("$test:example.com"),
room_id: Some(owned_room_id!("!test:example.com")),
sender: owned_user_id!("@test:example.com"),
origin_server_ts: UInt::try_from(1_234_567_890_u64).unwrap(),
kind: TimelineEventType::RoomMessage,
content: to_raw_value(&json!({"msgtype": "m.text", "body": "test"})).unwrap(),
state_key: None,
prev_events: vec![],
depth: UInt::from(1_u32),
auth_events: vec![],
redacts: None,
unsigned: unsigned_content.map(|content| to_raw_value(&content).unwrap()),
hashes: EventHash { sha256: "test_hash".to_owned() },
signatures: None,
origin: None,
}
}
fn create_bundled_aggregations() -> JsonValue {
json!({
"m.replace": {
"event_id": "$replace:example.com",
"origin_server_ts": 1_234_567_890,
"sender": "@replacer:example.com"
},
"m.reference": {
"count": 5,
"chunk": [
"$ref1:example.com",
"$ref2:example.com"
]
}
})
}
#[test]
fn test_add_bundled_aggregations_to_unsigned_no_existing_unsigned() {
let mut pdu = create_test_pdu(None);
let aggregations = create_bundled_aggregations();
let result = super::super::Service::add_bundled_aggregations_to_unsigned(
&mut pdu,
aggregations.clone(),
);
assert!(result.is_ok(), "Should succeed when no unsigned field exists");
assert!(pdu.unsigned.is_some(), "Unsigned field should be created");
let unsigned_str = pdu.unsigned.as_ref().unwrap().get();
let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap();
assert!(unsigned.get("m.relations").is_some(), "m.relations should exist");
assert_eq!(
unsigned["m.relations"], aggregations,
"Relations should match the aggregations"
);
}
#[test]
fn test_add_bundled_aggregations_to_unsigned_overwrite_same_relation_type() {
let existing_unsigned = json!({
"m.relations": {
"m.replace": {
"event_id": "$old_replace:example.com",
"origin_server_ts": 1_111_111_111,
"sender": "@old_replacer:example.com"
}
}
});
let mut pdu = create_test_pdu(Some(existing_unsigned));
let new_aggregations = create_bundled_aggregations();
let result = super::super::Service::add_bundled_aggregations_to_unsigned(
&mut pdu,
new_aggregations.clone(),
);
assert!(result.is_ok(), "Should succeed when overwriting same relation type");
let unsigned_str = pdu.unsigned.as_ref().unwrap().get();
let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap();
let relations = &unsigned["m.relations"];
assert_eq!(
relations["m.replace"], new_aggregations["m.replace"],
"m.replace should be updated"
);
assert_eq!(
relations["m.replace"]["event_id"], "$replace:example.com",
"Should have new event_id"
);
assert!(relations.get("m.reference").is_some(), "New m.reference should be added");
}
#[test]
fn test_add_bundled_aggregations_to_unsigned_preserve_other_unsigned_fields() {
// Test case: Other unsigned fields should be preserved
let existing_unsigned = json!({
"age": 98765,
"prev_content": {"msgtype": "m.text", "body": "old message"},
"redacted_because": {"event_id": "$redaction:example.com"},
"m.relations": {
"m.annotation": {"count": 1}
}
});
let mut pdu = create_test_pdu(Some(existing_unsigned));
let new_aggregations = json!({
"m.replace": {"event_id": "$new:example.com"}
});
let result = super::super::Service::add_bundled_aggregations_to_unsigned(
&mut pdu,
new_aggregations,
);
assert!(result.is_ok(), "Should succeed while preserving other fields");
let unsigned_str = pdu.unsigned.as_ref().unwrap().get();
let unsigned: JsonValue = serde_json::from_str(unsigned_str).unwrap();
// Verify all existing fields are preserved
assert_eq!(unsigned["age"], 98765, "age should be preserved");
assert!(unsigned.get("prev_content").is_some(), "prev_content should be preserved");
assert!(
unsigned.get("redacted_because").is_some(),
"redacted_because should be preserved"
);
// Verify relations were merged correctly
let relations = &unsigned["m.relations"];
assert!(
relations.get("m.annotation").is_some(),
"Existing m.annotation should be preserved"
);
assert!(relations.get("m.replace").is_some(), "New m.replace should be added");
}
#[test]
fn test_add_bundled_aggregations_to_unsigned_invalid_existing_unsigned() {
// Test case: Invalid JSON in existing unsigned should result in error
let mut pdu = create_test_pdu(None);
// Manually set invalid unsigned data
pdu.unsigned = Some(to_raw_value(&"invalid json").unwrap());
let aggregations = create_bundled_aggregations();
let result =
super::super::Service::add_bundled_aggregations_to_unsigned(&mut pdu, aggregations);
assert!(result.is_err(), "fails when existing unsigned is invalid");
// Should we ignore the error and overwrite anyway?
}
// Test helper function to create test PDU events
fn create_test_event(
event_id: &str,
room_id: &str,
sender: &str,
event_type: TimelineEventType,
content: &JsonValue,
state_key: Option<&str>,
) -> PduEvent {
PduEvent {
event_id: event_id.try_into().unwrap(),
room_id: Some(room_id.try_into().unwrap()),
sender: sender.try_into().unwrap(),
origin_server_ts: UInt::try_from(1_234_567_890_u64).unwrap(),
kind: event_type,
content: to_raw_value(&content).unwrap(),
state_key: state_key.map(Into::into),
prev_events: vec![],
depth: UInt::from(1_u32),
auth_events: vec![],
redacts: None,
unsigned: None,
hashes: EventHash { sha256: "test_hash".to_owned() },
signatures: None,
origin: None,
}
}
/// Test that a valid replacement event passes validation
#[tokio::test]
async fn test_valid_replacement_event() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({"msgtype": "m.text", "body": "original message"}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({
"msgtype": "m.text",
"body": "* edited message",
"m.new_content": {
"msgtype": "m.text",
"body": "edited message"
},
"m.relates_to": {
"rel_type": "m.replace",
"event_id": "$original:example.com"
}
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(result.unwrap(), "Valid replacement event should be accepted");
}
/// Test replacement event with different room ID is rejected
#[tokio::test]
async fn test_replacement_event_different_room() {
let original = create_test_event(
"$original:example.com",
"!room1:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({"msgtype": "m.text", "body": "original message"}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room2:example.com", // Different room
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({
"msgtype": "m.text",
"body": "* edited message",
"m.new_content": {
"msgtype": "m.text",
"body": "edited message"
}
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(!result.unwrap(), "Different room ID should be rejected");
}
/// Test replacement event with different sender is rejected
#[tokio::test]
async fn test_replacement_event_different_sender() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user1:example.com",
TimelineEventType::RoomMessage,
&json!({"msgtype": "m.text", "body": "original message"}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user2:example.com", // Different sender
TimelineEventType::RoomMessage,
&json!({
"msgtype": "m.text",
"body": "* edited message",
"m.new_content": {
"msgtype": "m.text",
"body": "edited message"
}
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(!result.unwrap(), "Different sender should be rejected");
}
/// Test replacement event with different type is rejected
#[tokio::test]
async fn test_replacement_event_different_type() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({"msgtype": "m.text", "body": "original message"}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomTopic, // Different event type
&json!({
"topic": "new topic",
"m.new_content": {
"topic": "new topic"
}
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(!result.unwrap(), "Different event type should be rejected");
}
/// Test replacement event with state key is rejected
#[tokio::test]
async fn test_replacement_event_with_state_key() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomName,
&json!({"name": "room name"}),
Some(""), // Has state key
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomName,
&json!({
"name": "new room name",
"m.new_content": {
"name": "new room name"
}
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(!result.unwrap(), "Event with state key should be rejected");
}
/// Test replacement of an event that is already a replacement is rejected
#[tokio::test]
async fn test_replacement_event_original_is_replacement() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({
"msgtype": "m.text",
"body": "* edited message",
"m.relates_to": {
"rel_type": "m.replace", // Original is already a replacement
"event_id": "$some_other:example.com"
}
}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({
"msgtype": "m.text",
"body": "* edited again",
"m.new_content": {
"msgtype": "m.text",
"body": "edited again"
}
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(!result.unwrap(), "Replacement of replacement should be rejected");
}
/// Test replacement event missing m.new_content is rejected
#[tokio::test]
async fn test_replacement_event_missing_new_content() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({"msgtype": "m.text", "body": "original message"}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomMessage,
&json!({
"msgtype": "m.text",
"body": "* edited message"
// Missing m.new_content
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(!result.unwrap(), "Missing m.new_content should be rejected");
}
/// Test encrypted replacement event without m.new_content is accepted
#[tokio::test]
async fn test_replacement_event_encrypted_missing_new_content_is_valid() {
let original = create_test_event(
"$original:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomEncrypted,
&json!({
"algorithm": "m.megolm.v1.aes-sha2",
"ciphertext": "encrypted_payload_base64",
"sender_key": "sender_key",
"session_id": "session_id"
}),
None,
);
let replacement = create_test_event(
"$replacement:example.com",
"!room:example.com",
"@user:example.com",
TimelineEventType::RoomEncrypted,
&json!({
"algorithm": "m.megolm.v1.aes-sha2",
"ciphertext": "encrypted_replacement_payload_base64",
"sender_key": "sender_key",
"session_id": "session_id",
"m.relates_to": {
"rel_type": "m.replace",
"event_id": "$original:example.com"
}
// No m.new_content in cleartext - this is valid for encrypted events
}),
None,
);
let result =
super::super::Service::is_valid_replacement_event(&original, &replacement).await;
assert!(result.is_ok(), "Validation should succeed");
assert!(
result.unwrap(),
"Encrypted replacement without cleartext m.new_content should be accepted"
);
}
}
+7 -5
View File
@@ -3,6 +3,7 @@ use std::{mem::size_of, sync::Arc};
use conduwuit::{ use conduwuit::{
arrayvec::ArrayVec, arrayvec::ArrayVec,
matrix::{Event, PduCount}, matrix::{Event, PduCount},
result::LogErr,
utils::{ utils::{
ReadyExt, ReadyExt,
stream::{TryIgnore, WidebandExt}, stream::{TryIgnore, WidebandExt},
@@ -14,11 +15,10 @@ use futures::{Stream, StreamExt};
use ruma::{EventId, RoomId, UserId, api::Direction}; use ruma::{EventId, RoomId, UserId, api::Direction};
use crate::{ use crate::{
Dep, Dep, rooms,
rooms::{ rooms::{
self,
short::{ShortEventId, ShortRoomId}, short::{ShortEventId, ShortRoomId},
timeline::{PduId, PdusIterItem, RawPduId}, timeline::{PduId, RawPduId},
}, },
}; };
@@ -60,7 +60,7 @@ impl Data {
target: ShortEventId, target: ShortEventId,
from: PduCount, from: PduCount,
dir: Direction, dir: Direction,
) -> impl Stream<Item = PdusIterItem> + Send + 'a { ) -> impl Stream<Item = (PduCount, impl Event)> + Send + 'a {
// Query from exact position then filter excludes it (saturating_inc could skip // Query from exact position then filter excludes it (saturating_inc could skip
// events at min/max boundaries) // events at min/max boundaries)
let from_unsigned = from.into_unsigned(); let from_unsigned = from.into_unsigned();
@@ -92,7 +92,9 @@ impl Data {
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
pdu.as_mut_pdu().set_unsigned(Some(user_id)); if pdu.sender() != user_id {
pdu.as_mut_pdu().remove_transaction_id().log_err().ok();
}
Some((shorteventid, pdu)) Some((shorteventid, pdu))
}) })
+6 -10
View File
@@ -1,16 +1,15 @@
mod bundled_aggregations;
mod data; mod data;
use std::sync::Arc; use std::sync::Arc;
use conduwuit::{Result, matrix::PduCount}; use conduwuit::{
Result,
matrix::{Event, PduCount},
};
use futures::{StreamExt, future::try_join}; use futures::{StreamExt, future::try_join};
use ruma::{EventId, RoomId, UserId, api::Direction}; use ruma::{EventId, RoomId, UserId, api::Direction};
use self::data::Data; use self::data::Data;
use crate::{ use crate::{Dep, rooms};
Dep,
rooms::{self, timeline::PdusIterItem},
};
pub struct Service { pub struct Service {
services: Services, services: Services,
@@ -20,7 +19,6 @@ pub struct Service {
struct Services { struct Services {
short: Dep<rooms::short::Service>, short: Dep<rooms::short::Service>,
timeline: Dep<rooms::timeline::Service>, timeline: Dep<rooms::timeline::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
} }
impl crate::Service for Service { impl crate::Service for Service {
@@ -29,8 +27,6 @@ impl crate::Service for Service {
services: Services { services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"), short: args.depend::<rooms::short::Service>("rooms::short"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
}, },
db: Data::new(&args), db: Data::new(&args),
})) }))
@@ -60,7 +56,7 @@ impl Service {
limit: usize, limit: usize,
max_depth: u8, max_depth: u8,
dir: Direction, dir: Direction,
) -> Vec<PdusIterItem> { ) -> Vec<(PduCount, impl Event)> {
let room_id = self.services.short.get_shortroomid(room_id); let room_id = self.services.short.get_shortroomid(room_id);
let target = self.services.timeline.get_pdu_count(target); let target = self.services.timeline.get_pdu_count(target);
+4 -23
View File
@@ -1,9 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use conduwuit::{ use conduwuit::{
PduCount, PduEvent, Result, PduCount, Result,
arrayvec::ArrayVec, arrayvec::ArrayVec,
debug_warn, implement, implement,
matrix::event::{Event, Matches}, matrix::event::{Event, Matches},
utils::{ utils::{
ArrayVecExt, IterStream, ReadyExt, set, ArrayVecExt, IterStream, ReadyExt, set,
@@ -35,7 +35,6 @@ struct Services {
short: Dep<rooms::short::Service>, short: Dep<rooms::short::Service>,
state_accessor: Dep<rooms::state_accessor::Service>, state_accessor: Dep<rooms::state_accessor::Service>,
timeline: Dep<rooms::timeline::Service>, timeline: Dep<rooms::timeline::Service>,
pdu_metadata: Dep<rooms::pdu_metadata::Service>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -62,7 +61,6 @@ impl crate::Service for Service {
state_accessor: args state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), .depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
}, },
})) }))
} }
@@ -106,8 +104,7 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b
pub async fn search_pdus<'a>( pub async fn search_pdus<'a>(
&'a self, &'a self,
query: &'a RoomQuery<'a>, query: &'a RoomQuery<'a>,
sender_user: &'a UserId, ) -> Result<(usize, impl Stream<Item = impl Event + use<>> + Send + 'a)> {
) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> {
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
let filter = &query.criteria.filter; let filter = &query.criteria.filter;
@@ -132,23 +129,7 @@ pub async fn search_pdus<'a>(
.then_some(pdu) .then_some(pdu)
}) })
.skip(query.skip) .skip(query.skip)
.take(query.limit) .take(query.limit);
.map(move |mut pdu| {
pdu.set_unsigned(query.user_id);
pdu
})
.then(async move |mut pdu| {
if let Err(e) = self
.services
.pdu_metadata
.add_bundled_aggregations_to_pdu(sender_user, &mut pdu)
.await
{
debug_warn!("Failed to add bundled aggregations: {e}");
}
pdu
});
Ok((count, pdus)) Ok((count, pdus))
} }
-2
View File
@@ -264,8 +264,6 @@ fn get_space_child_events<'a>(
if content.via.is_empty() { if content.via.is_empty() {
return None; return None;
} }
} else {
return None;
} }
if RoomId::parse(&state_key).is_err() { if RoomId::parse(&state_key).is_err() {
+3 -1
View File
@@ -163,7 +163,9 @@ impl Service {
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
let pdu_id: PduId = pdu_id.into(); let pdu_id: PduId = pdu_id.into();
pdu.as_mut_pdu().set_unsigned(Some(user_id)); if pdu.sender() != user_id {
pdu.as_mut_pdu().remove_transaction_id().ok();
}
Some((pdu_id.shorteventid, pdu)) Some((pdu_id.shorteventid, pdu))
}); });
-4
View File
@@ -347,10 +347,6 @@ where
| _ => {}, | _ => {},
} }
// CONCERN: If we receive events with a relation out-of-order, we never write
// their relation / thread. We need some kind of way to trigger when we receive
// this event, and potentially a way to rebuild the table entirely.
if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() { if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() {
if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await {
self.services self.services
+1 -1
View File
@@ -308,7 +308,7 @@ pub async fn create_hash_and_sign_event(
match self match self
.services .services
.event_handler .event_handler
.ask_policy_server(&pdu, &mut pdu_json, pdu.room_id().expect("has room ID"), false) .ask_policy_server(&pdu, &pdu_json, pdu.room_id().expect("has room ID"))
.await .await
{ {
| Ok(true) => {}, | Ok(true) => {},
+27 -11
View File
@@ -1,13 +1,13 @@
use std::sync::Arc; use std::{borrow::Borrow, sync::Arc};
use conduwuit::{ use conduwuit::{
Err, PduCount, PduEvent, Result, at, err, Err, PduCount, PduEvent, Result, at, err,
result::NotFound, result::{LogErr, NotFound},
utils::{self, stream::TryReadyExt}, utils::{self, stream::TryReadyExt},
}; };
use database::{Database, Deserialized, Json, KeyVal, Map}; use database::{Database, Deserialized, Json, KeyVal, Map};
use futures::{FutureExt, Stream, TryFutureExt, TryStreamExt, future::select_ok, pin_mut}; use futures::{FutureExt, Stream, TryFutureExt, TryStreamExt, future::select_ok, pin_mut};
use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, api::Direction}; use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, api::Direction};
use super::{PduId, RawPduId}; use super::{PduId, RawPduId};
use crate::{Dep, rooms, rooms::short::ShortRoomId}; use crate::{Dep, rooms, rooms::short::ShortRoomId};
@@ -45,8 +45,12 @@ impl Data {
} }
#[inline] #[inline]
pub(super) async fn last_timeline_count(&self, room_id: &RoomId) -> Result<PduCount> { pub(super) async fn last_timeline_count(
let pdus_rev = self.pdus_rev(room_id, PduCount::max()); &self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
let pdus_rev = self.pdus_rev(sender_user, room_id, PduCount::max());
pin_mut!(pdus_rev); pin_mut!(pdus_rev);
let last_count = pdus_rev let last_count = pdus_rev
@@ -60,8 +64,12 @@ impl Data {
} }
#[inline] #[inline]
pub(super) async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<PduEvent> { pub(super) async fn latest_pdu_in_room(
let pdus_rev = self.pdus_rev(room_id, PduCount::max()); &self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduEvent> {
let pdus_rev = self.pdus_rev(sender_user, room_id, PduCount::max());
pin_mut!(pdus_rev); pin_mut!(pdus_rev);
pdus_rev pdus_rev
@@ -213,6 +221,7 @@ impl Data {
/// order. /// order.
pub(super) fn pdus_rev<'a>( pub(super) fn pdus_rev<'a>(
&'a self, &'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId, room_id: &'a RoomId,
until: PduCount, until: PduCount,
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a { ) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
@@ -222,13 +231,14 @@ impl Data {
self.pduid_pdu self.pduid_pdu
.rev_raw_stream_from(&current) .rev_raw_stream_from(&current)
.ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix))) .ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix)))
.ready_and_then(Self::from_json_slice) .ready_and_then(move |item| Self::each_pdu(item, user_id))
}) })
.try_flatten_stream() .try_flatten_stream()
} }
pub(super) fn pdus<'a>( pub(super) fn pdus<'a>(
&'a self, &'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId, room_id: &'a RoomId,
from: PduCount, from: PduCount,
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a { ) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
@@ -238,15 +248,21 @@ impl Data {
self.pduid_pdu self.pduid_pdu
.raw_stream_from(&current) .raw_stream_from(&current)
.ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix))) .ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix)))
.ready_and_then(Self::from_json_slice) .ready_and_then(move |item| Self::each_pdu(item, user_id))
}) })
.try_flatten_stream() .try_flatten_stream()
} }
fn from_json_slice((pdu_id, pdu): KeyVal<'_>) -> Result<PdusIterItem> { fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> Result<PdusIterItem> {
let pdu_id: RawPduId = pdu_id.into(); let pdu_id: RawPduId = pdu_id.into();
let pdu = serde_json::from_slice::<PduEvent>(pdu)?; let mut pdu = serde_json::from_slice::<PduEvent>(pdu)?;
if Some(pdu.sender.borrow()) != user_id {
pdu.remove_transaction_id().log_err().ok();
}
pdu.add_age().log_err().ok();
Ok((pdu_id.pdu_count(), pdu)) Ok((pdu_id.pdu_count(), pdu))
} }
+16 -8
View File
@@ -20,7 +20,7 @@ use conduwuit_core::{
}; };
use futures::{Future, Stream, TryStreamExt, pin_mut}; use futures::{Future, Stream, TryStreamExt, pin_mut};
use ruma::{ use ruma::{
CanonicalJsonObject, EventId, OwnedEventId, OwnedRoomId, RoomId, CanonicalJsonObject, EventId, OwnedEventId, OwnedRoomId, RoomId, UserId,
events::room::encrypted::Relation, events::room::encrypted::Relation,
}; };
use serde::Deserialize; use serde::Deserialize;
@@ -138,7 +138,7 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn first_item_in_room(&self, room_id: &RoomId) -> Result<(PduCount, impl Event)> { pub async fn first_item_in_room(&self, room_id: &RoomId) -> Result<(PduCount, impl Event)> {
let pdus = self.pdus(room_id, None); let pdus = self.pdus(None, room_id, None);
pin_mut!(pdus); pin_mut!(pdus);
pdus.try_next() pdus.try_next()
@@ -148,12 +148,16 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<impl Event> { pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<impl Event> {
self.db.latest_pdu_in_room(room_id).await self.db.latest_pdu_in_room(None, room_id).await
} }
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn last_timeline_count(&self, room_id: &RoomId) -> Result<PduCount> { pub async fn last_timeline_count(
self.db.last_timeline_count(room_id).await &self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
self.db.last_timeline_count(sender_user, room_id).await
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
@@ -231,29 +235,33 @@ impl Service {
#[inline] #[inline]
pub fn all_pdus<'a>( pub fn all_pdus<'a>(
&'a self, &'a self,
user_id: &'a UserId,
room_id: &'a RoomId, room_id: &'a RoomId,
) -> impl Stream<Item = PdusIterItem> + Send + 'a { ) -> impl Stream<Item = PdusIterItem> + Send + 'a {
self.pdus(room_id, None).ignore_err() self.pdus(Some(user_id), room_id, None).ignore_err()
} }
/// Reverse iteration starting after `until`. /// Reverse iteration starting after `until`.
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub fn pdus_rev<'a>( pub fn pdus_rev<'a>(
&'a self, &'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId, room_id: &'a RoomId,
until: Option<PduCount>, until: Option<PduCount>,
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a { ) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
self.db self.db
.pdus_rev(room_id, until.unwrap_or_else(PduCount::max)) .pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max))
} }
/// Forward iteration starting after `from`. /// Forward iteration starting after `from`.
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub fn pdus<'a>( pub fn pdus<'a>(
&'a self, &'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId, room_id: &'a RoomId,
from: Option<PduCount>, from: Option<PduCount>,
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a { ) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
self.db.pdus(room_id, from.unwrap_or_else(PduCount::min)) self.db
.pdus(user_id, room_id, from.unwrap_or_else(PduCount::min))
} }
} }
+1 -1
View File
@@ -781,7 +781,7 @@ impl Service {
for pdu in pdus { for pdu in pdus {
// Redacted events are not notification targets (we don't send push for them) // Redacted events are not notification targets (we don't send push for them)
if pdu.is_redacted() { if pdu.contains_unsigned_property("redacted_because", serde_json::Value::is_string) {
continue; continue;
} }
+2 -2
View File
@@ -130,7 +130,7 @@ impl Services {
// reset dormant online/away statuses to offline, and set the server user as // reset dormant online/away statuses to offline, and set the server user as
// online // online
if self.server.config.allow_local_presence && !self.db.is_read_only() { if self.server.config.allow_local_presence {
self.presence.unset_all_presence().await; self.presence.unset_all_presence().await;
_ = self _ = self
.presence .presence
@@ -146,7 +146,7 @@ impl Services {
info!("Shutting down services..."); info!("Shutting down services...");
// set the server user as offline // set the server user as offline
if self.server.config.allow_local_presence && !self.db.is_read_only() { if self.server.config.allow_local_presence {
_ = self _ = self
.presence .presence
.ping_presence(&self.globals.server_user, &ruma::presence::PresenceState::Offline) .ping_presence(&self.globals.server_user, &ruma::presence::PresenceState::Offline)