Compare commits

...

23 Commits

Author SHA1 Message Date
strawberry e524590860 disable more unnecessary features in various build outputs
Signed-off-by: strawberry <strawberry@puppygock.gay>
2025-01-10 22:45:48 -05:00
strawberry d5217566d9 ci: require docker publishing to pass tests
Signed-off-by: strawberry <strawberry@puppygock.gay>
2025-01-10 21:34:20 -05:00
strawberry 099f98978b gate libloading to conduwuit_mods feature and cfg only
Signed-off-by: strawberry <strawberry@puppygock.gay>
2025-01-10 21:34:20 -05:00
strawberry 43e70fe7c1 gate sd_notify to linux target_os only
Signed-off-by: strawberry <strawberry@puppygock.gay>
2025-01-10 21:34:20 -05:00
morguldir 721659f22a Add initial MSC4186 (Simplified Sliding Sync) implementation
Signed-off-by: morguldir <morguldir@protonmail.com>
2025-01-10 21:34:20 -05:00
morguldir 099c9fb22f syncv3: use a function for repeated pattern of fetching sticky params 2025-01-10 21:34:20 -05:00
Jason Volk 66231676f1 gracefully ignore unknown columns; add dropped flag in descriptor
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 22:29:05 +00:00
Jason Volk 16fa2eca87 add conf item for write buffer size
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 07:03:34 +00:00
Jason Volk 6a0f9add0c refactor database engine/options; add column descriptors
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 07:03:34 +00:00
Jason Volk 02f19cf951 tweak tracing spans; inlines
db deserializer tracing instrument cover

Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 07:03:34 +00:00
Jason Volk 685b127f99 simplify iterator state constructor arguments
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 07:03:34 +00:00
Jason Volk cc1889d135 Add default-enabled feature-gates for url_preview and media_thumbnail
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 07:03:34 +00:00
Jason Volk 0238f27605 prevent example-config generating in test builds
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 07:03:34 +00:00
Jason Volk 5dae086197 exclude config item from doctest
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-10 06:53:15 +00:00
Jason Volk 44e6b1af3c fixes for tests to be run in release-mode
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 19:56:42 +00:00
Jason Volk 94c8683836 improve db pool topology configuration
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 18:14:23 +00:00
Jason Volk d36167ab64 partially revert 9a9c071e82; use std threads for db pool.
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 18:14:23 +00:00
Jason Volk 925061b92d flatten timeline pdus iterations; increase concurrency
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 18:14:23 +00:00
Jason Volk 27328cbc01 additional futures extension utils
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 18:14:23 +00:00
Jason Volk a3f9432da8 eliminate the state-res mutex hazard
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 18:14:23 +00:00
Jason Volk 82168b972a fix heroes calculation regression
Signed-off-by: Jason Volk <jason@zemos.net>
2025-01-09 18:14:23 +00:00
Neil Svedberg 7526ba9d6f Add header to console
When the console is launched, it now prints this message:

    conduwuit VERSION admin console
    "help" for help, ^D to exit the console, ^\ to stop the server
2025-01-09 10:18:49 -05:00
Jade Ellis 8c74e35e76 automatically retry returning data in syncv3 (#652)
* automatically retry returning data in syncv3

* reference service

* clippy fixes
2025-01-03 22:15:48 -05:00
102 changed files with 3677 additions and 1710 deletions
+1 -1
View File
@@ -733,7 +733,7 @@ jobs:
docker:
name: Docker publish
runs-on: ubuntu-24.04
needs: [build, variables]
needs: [build, variables, tests]
permissions:
packages: write
contents: read
Generated
-1
View File
@@ -723,7 +723,6 @@ dependencies = [
"hardened_malloc-rs",
"http",
"http-body-util",
"image",
"ipaddress",
"itertools 0.13.0",
"libc",
+29 -15
View File
@@ -100,20 +100,6 @@
#
#database_backups_to_keep = 1
# Set this to any float value in megabytes for conduwuit to tell the
# database engine that this much memory is available for database-related
# caches.
#
# May be useful if you have significant memory to spare to increase
# performance.
#
# Similar to the individual LRU caches, this is scaled up with your CPU
# core count.
#
# This defaults to 128.0 + (64.0 * CPU core count).
#
#db_cache_capacity_mb = varies by system
# Text which will be added to the end of the user's displayname upon
# registration with a space before the text. In Conduit, this was the
# lightning bolt emoji.
@@ -149,6 +135,34 @@
#
#cache_capacity_modifier = 1.0
# Set this to any float value in megabytes for conduwuit to tell the
# database engine that this much memory is available for database read
# caches.
#
# May be useful if you have significant memory to spare to increase
# performance.
#
# Similar to the individual LRU caches, this is scaled up with your CPU
# core count.
#
# This defaults to 128.0 + (64.0 * CPU core count).
#
#db_cache_capacity_mb = varies by system
# Set this to any float value in megabytes for conduwuit to tell the
# database engine that this much memory is available for database write
# caches.
#
# May be useful if you have significant memory to spare to increase
# performance.
#
# Similar to the individual LRU caches, this is scaled up with your CPU
# core count.
#
# This defaults to 48.0 + (4.0 * CPU core count).
#
#db_write_buffer_capacity_mb = varies by system
# This item is undocumented. Please contribute documentation for it.
#
#pdu_cache_capacity = varies by system
@@ -1452,7 +1466,7 @@
# responsiveness for many users at the cost of throughput for each.
#
# Setting this value to 0.0 causes the stream width to be fixed at the
# value of stream_width_default. The default is 1.0 to match the
# value of stream_width_default. The default scale is 1.0 to match the
# capabilities detected for the system.
#
#stream_width_scale = 1.0
+108 -48
View File
@@ -191,27 +191,59 @@
in
{
packages = {
default = scopeHost.main;
default-debug = scopeHost.main.override {
profile = "dev";
# debug build users expect full logs
disable_release_max_log_level = true;
};
default-test = scopeHost.main.override {
profile = "test";
disable_release_max_log_level = true;
};
all-features = scopeHost.main.override {
all_features = true;
default = scopeHost.main.override {
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
default-debug = scopeHost.main.override {
profile = "dev";
# debug build users expect full logs
disable_release_max_log_level = true;
disable_features = [
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
# just a test profile used for things like CI and complement
default-test = scopeHost.main.override {
profile = "test";
disable_release_max_log_level = true;
disable_features = [
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
all-features = scopeHost.main.override {
all_features = true;
disable_features = [
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
all-features-debug = scopeHost.main.override {
@@ -220,10 +252,12 @@
# debug build users expect full logs
disable_release_max_log_level = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
hmalloc = scopeHost.main.override { features = ["hardened_malloc"]; };
@@ -233,14 +267,16 @@
main = scopeHost.main.override {
all_features = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
};
@@ -251,10 +287,12 @@
# debug build users expect full logs
disable_release_max_log_level = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
};
@@ -313,6 +351,14 @@
value = scopeCrossStatic.main.override {
profile = "test";
disable_release_max_log_level = true;
disable_features = [
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
}
@@ -322,14 +368,16 @@
value = scopeCrossStatic.main.override {
all_features = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
}
@@ -341,14 +389,16 @@
value = scopeCrossStatic.main.override {
all_features = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
x86_64_haswell_target_optimised = (if (crossSystem == "x86_64-linux-gnu" || crossSystem == "x86_64-linux-musl") then true else false);
};
@@ -363,10 +413,12 @@
# debug build users expect full logs
disable_release_max_log_level = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
}
@@ -415,14 +467,16 @@
main = scopeCrossStatic.main.override {
all_features = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
};
@@ -436,14 +490,16 @@
main = scopeCrossStatic.main.override {
all_features = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
x86_64_haswell_target_optimised = (if (crossSystem == "x86_64-linux-gnu" || crossSystem == "x86_64-linux-musl") then true else false);
};
@@ -460,10 +516,12 @@
# debug build users expect full logs
disable_release_max_log_level = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# dont include experimental features
"experimental"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
};
@@ -502,14 +560,16 @@
main = prev.main.override {
all_features = true;
disable_features = [
# this is non-functional on nix for some reason
"hardened_malloc"
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
];
};
}));
+10
View File
@@ -20,6 +20,8 @@ let
disable_features = [
# no reason to use jemalloc for complement, just has compatibility/build issues
"jemalloc"
"jemalloc_stats"
"jemalloc_prof"
# console/CLI stuff isn't used or relevant for complement
"console"
"tokio_console"
@@ -32,6 +34,14 @@ let
"hardened_malloc"
# dont include experimental features
"experimental"
# compression isn't needed for complement
"brotli_compression"
"gzip_compression"
"zstd_compression"
# complement doesn't need hot reloading
"conduwuit_mods"
# complement doesn't have URL preview media tests
"url_preview"
];
};
+13 -1
View File
@@ -15,7 +15,19 @@
# Options (keep sorted)
, all_features ? false
, default_features ? true
, disable_features ? []
# default list of disabled features
, disable_features ? [
# dont include experimental features
"experimental"
# jemalloc profiling/stats features are expensive and shouldn't
# be expected on non-debug builds.
"jemalloc_prof"
"jemalloc_stats"
# this is non-functional on nix for some reason
"hardened_malloc"
# conduwuit_mods is a development-only hot reload feature
"conduwuit_mods"
]
, disable_release_max_log_level ? false
, features ? []
, profile ? "release"
+2 -2
View File
@@ -911,8 +911,8 @@ pub(super) async fn database_stats(
let map_name = map.as_ref().map_or(EMPTY, String::as_str);
let mut out = String::new();
for (name, map) in self.services.db.iter() {
if !map_name.is_empty() && *map_name != *name {
for (&name, map) in self.services.db.iter() {
if !map_name.is_empty() && map_name != name {
continue;
}
+61
View File
@@ -0,0 +1,61 @@
use clap::Subcommand;
use conduwuit::{utils::stream::TryTools, PduCount, Result};
use futures::TryStreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomOrAliasId};
use crate::{admin_command, admin_command_dispatch};
#[admin_command_dispatch]
#[derive(Debug, Subcommand)]
/// Query tables from database
pub(crate) enum RoomTimelineCommand {
Pdus {
room_id: OwnedRoomOrAliasId,
from: Option<String>,
#[arg(short, long)]
limit: Option<usize>,
},
Last {
room_id: OwnedRoomOrAliasId,
},
}
#[admin_command]
pub(super) async fn last(&self, room_id: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> {
let room_id = self.services.rooms.alias.resolve(&room_id).await?;
let result = self
.services
.rooms
.timeline
.last_timeline_count(None, &room_id)
.await?;
Ok(RoomMessageEventContent::notice_markdown(format!("{result:#?}")))
}
#[admin_command]
pub(super) async fn pdus(
&self,
room_id: OwnedRoomOrAliasId,
from: Option<String>,
limit: Option<usize>,
) -> Result<RoomMessageEventContent> {
let room_id = self.services.rooms.alias.resolve(&room_id).await?;
let from: Option<PduCount> = from.as_deref().map(str::parse).transpose()?;
let result: Vec<_> = self
.services
.rooms
.timeline
.pdus_rev(None, &room_id, from)
.try_take(limit.unwrap_or(3))
.try_collect()
.await?;
Ok(RoomMessageEventContent::notice_markdown(format!("{result:#?}")))
}
+31 -36
View File
@@ -1,14 +1,12 @@
use std::iter::once;
use axum::extract::State;
use conduwuit::{
at, err, ref_at,
utils::{
future::TryExtExt,
stream::{BroadbandExt, ReadyExt, WidebandExt},
stream::{BroadbandExt, ReadyExt, TryIgnore, WidebandExt},
IterStream,
},
Err, Result,
Err, PduEvent, Result,
};
use futures::{join, try_join, FutureExt, StreamExt, TryFutureExt};
use ruma::{
@@ -59,13 +57,13 @@ pub(crate) async fn get_context_route(
false
};
let base_token = services
let base_id = services
.rooms
.timeline
.get_pdu_count(&body.event_id)
.get_pdu_id(&body.event_id)
.map_err(|_| err!(Request(NotFound("Event not found."))));
let base_event = services
let base_pdu = services
.rooms
.timeline
.get_pdu(&body.event_id)
@@ -77,48 +75,44 @@ pub(crate) async fn get_context_route(
.user_can_see_event(sender_user, &body.room_id, &body.event_id)
.map(Ok);
let (base_token, base_event, visible) = try_join!(base_token, base_event, visible)?;
let (base_id, base_pdu, visible) = try_join!(base_id, base_pdu, visible)?;
if base_event.room_id != body.room_id || base_event.event_id != body.event_id {
if base_pdu.room_id != body.room_id || base_pdu.event_id != body.event_id {
return Err!(Request(NotFound("Base event not found.")));
}
if !visible
|| ignored_filter(&services, (base_token, base_event.clone()), sender_user)
.await
.is_none()
{
if !visible {
return Err!(Request(Forbidden("You don't have permission to view this event.")));
}
let events_before =
services
.rooms
.timeline
.pdus_rev(Some(sender_user), room_id, Some(base_token));
let base_count = base_id.pdu_count();
let base_event = ignored_filter(&services, (base_count, base_pdu), sender_user);
let events_before = services
.rooms
.timeline
.pdus_rev(Some(sender_user), room_id, Some(base_count))
.ignore_err()
.ready_filter_map(|item| event_filter(item, filter))
.wide_filter_map(|item| ignored_filter(&services, item, sender_user))
.wide_filter_map(|item| visibility_filter(&services, item, sender_user))
.take(limit / 2)
.collect();
let events_after = services
.rooms
.timeline
.pdus(Some(sender_user), room_id, Some(base_token));
let (events_before, events_after) = try_join!(events_before, events_after)?;
let events_before = events_before
.pdus(Some(sender_user), room_id, Some(base_count))
.ignore_err()
.ready_filter_map(|item| event_filter(item, filter))
.wide_filter_map(|item| ignored_filter(&services, item, sender_user))
.wide_filter_map(|item| visibility_filter(&services, item, sender_user))
.take(limit / 2)
.collect();
let events_after = events_after
.ready_filter_map(|item| event_filter(item, filter))
.wide_filter_map(|item| ignored_filter(&services, item, sender_user))
.wide_filter_map(|item| visibility_filter(&services, item, sender_user))
.take(limit / 2)
.collect();
let (events_before, events_after): (Vec<_>, Vec<_>) = join!(events_before, events_after);
let (base_event, events_before, events_after): (_, Vec<_>, Vec<_>) =
join!(base_event, events_before, events_after);
let state_at = events_after
.last()
@@ -134,7 +128,8 @@ pub(crate) async fn get_context_route(
.map_err(|e| err!(Database("State not found: {e}")))
.await?;
let lazy = once(&(base_token, base_event.clone()))
let lazy = base_event
.iter()
.chain(events_before.iter())
.chain(events_after.iter())
.stream()
@@ -175,19 +170,19 @@ pub(crate) async fn get_context_route(
.await;
Ok(get_context::v3::Response {
event: Some(base_event.to_room_event()),
event: base_event.map(at!(1)).as_ref().map(PduEvent::to_room_event),
start: events_before
.last()
.map(at!(0))
.or(Some(base_token))
.or(Some(base_count))
.as_ref()
.map(ToString::to_string),
end: events_after
.last()
.map(at!(0))
.or(Some(base_token))
.or(Some(base_count))
.as_ref()
.map(ToString::to_string),
+2
View File
@@ -1314,6 +1314,7 @@ async fn join_room_by_id_helper_local(
.rooms
.event_handler
.handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true)
.boxed()
.await?;
} else {
return Err(error);
@@ -1491,6 +1492,7 @@ pub(crate) async fn invite_helper(
.rooms
.event_handler
.handle_incoming_pdu(&origin, room_id, &event_id, value, true)
.boxed()
.await?
.ok_or_else(|| {
err!(Request(InvalidParam("Could not accept incoming PDU as timeline event.")))
+3 -3
View File
@@ -5,7 +5,7 @@ use conduwuit::{
at, is_equal_to,
utils::{
result::{FlatOk, LogErr},
stream::{BroadbandExt, WidebandExt},
stream::{BroadbandExt, TryIgnore, WidebandExt},
IterStream, ReadyExt,
},
Event, PduCount, Result,
@@ -107,14 +107,14 @@ pub(crate) async fn get_message_events_route(
.rooms
.timeline
.pdus(Some(sender_user), room_id, Some(from))
.await?
.ignore_err()
.boxed(),
| Direction::Backward => services
.rooms
.timeline
.pdus_rev(Some(sender_user), room_id, Some(from))
.await?
.ignore_err()
.boxed(),
};
+9 -6
View File
@@ -1,6 +1,10 @@
use axum::extract::State;
use conduwuit::{at, utils::BoolExt, Err, Result};
use futures::StreamExt;
use conduwuit::{
at,
utils::{stream::TryTools, BoolExt},
Err, Result,
};
use futures::TryStreamExt;
use ruma::api::client::room::initial_sync::v3::{PaginationChunk, Request, Response};
use crate::Ruma;
@@ -27,10 +31,9 @@ pub(crate) async fn room_initial_sync_route(
.rooms
.timeline
.pdus_rev(None, room_id, None)
.await?
.take(limit)
.collect()
.await;
.try_take(limit)
.try_collect()
.await?;
let state: Vec<_> = services
.rooms
+55 -14
View File
@@ -1,16 +1,31 @@
mod v3;
mod v4;
mod v5;
use conduwuit::{
utils::stream::{BroadbandExt, ReadyExt},
utils::{
stream::{BroadbandExt, ReadyExt, TryIgnore},
IterStream,
},
PduCount,
};
use futures::StreamExt;
use ruma::{RoomId, UserId};
use futures::{pin_mut, StreamExt};
use ruma::{
directory::RoomTypeFilter,
events::TimelineEventType::{
self, Beacon, CallInvite, PollStart, RoomEncrypted, RoomMessage, Sticker,
},
RoomId, UserId,
};
pub(crate) use self::{v3::sync_events_route, v4::sync_events_v4_route};
pub(crate) use self::{
v3::sync_events_route, v4::sync_events_v4_route, v5::sync_events_v5_route,
};
use crate::{service::Services, Error, PduEvent, Result};
pub(crate) const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] =
&[CallInvite, PollStart, Beacon, RoomEncrypted, RoomMessage, Sticker];
async fn load_timeline(
services: &Services,
sender_user: &UserId,
@@ -29,23 +44,19 @@ async fn load_timeline(
return Ok((Vec::new(), false));
}
let mut non_timeline_pdus = services
let non_timeline_pdus = services
.rooms
.timeline
.pdus_rev(Some(sender_user), room_id, None)
.await?
.ignore_err()
.ready_skip_while(|&(pducount, _)| pducount > next_batch.unwrap_or_else(PduCount::max))
.ready_take_while(|&(pducount, _)| pducount > roomsincecount);
// Take the last events for the timeline
let timeline_pdus: Vec<_> = non_timeline_pdus
.by_ref()
.take(limit)
.collect::<Vec<_>>()
.await
.into_iter()
.rev()
.collect();
pin_mut!(non_timeline_pdus);
let timeline_pdus: Vec<_> = non_timeline_pdus.by_ref().take(limit).collect().await;
let timeline_pdus: Vec<_> = timeline_pdus.into_iter().rev().collect();
// They /sync response doesn't always return all messages, so we say the output
// is limited unless there are events in non_timeline_pdus
@@ -73,3 +84,33 @@ async fn share_encrypted_room(
})
.await
}
pub(crate) async fn filter_rooms<'a>(
services: &Services,
rooms: &[&'a RoomId],
filter: &[RoomTypeFilter],
negate: bool,
) -> Vec<&'a RoomId> {
rooms
.iter()
.stream()
.filter_map(|r| async move {
let room_type = services.rooms.state_accessor.get_room_type(r).await;
if room_type.as_ref().is_err_and(|e| !e.is_not_found()) {
return None;
}
let room_type_filter = RoomTypeFilter::from(room_type.ok());
let include = if negate {
!filter.contains(&room_type_filter)
} else {
filter.is_empty() || filter.contains(&room_type_filter)
};
include.then_some(r)
})
.collect()
.await
}
+32 -20
View File
@@ -124,6 +124,33 @@ pub(crate) async fn sync_events_route(
// Setup watchers, so if there's no response, we can wait for them
let watcher = services.sync.watch(sender_user, sender_device);
let response = build_sync_events(&services, &body).await?;
if body.body.full_state
|| !(response.rooms.is_empty()
&& response.presence.is_empty()
&& response.account_data.is_empty()
&& response.device_lists.is_empty()
&& response.to_device.is_empty())
{
return Ok(response);
}
// Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives
let default = Duration::from_secs(30);
let duration = cmp::min(body.body.timeout.unwrap_or(default), default);
_ = tokio::time::timeout(duration, watcher).await;
// Retry returning data
build_sync_events(&services, &body).await
}
pub(crate) async fn build_sync_events(
services: &Services,
body: &Ruma<sync_events::v3::Request>,
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
let (sender_user, sender_device) = body.sender();
let next_batch = services.globals.current_count()?;
let next_batch_string = next_batch.to_string();
@@ -163,7 +190,7 @@ pub(crate) async fn sync_events_route(
.map(ToOwned::to_owned)
.broad_filter_map(|room_id| {
load_joined_room(
&services,
services,
sender_user,
sender_device,
room_id.clone(),
@@ -196,7 +223,7 @@ pub(crate) async fn sync_events_route(
.rooms_left(sender_user)
.broad_filter_map(|(room_id, _)| {
handle_left_room(
&services,
services,
since,
room_id.clone(),
sender_user,
@@ -242,7 +269,7 @@ pub(crate) async fn sync_events_route(
let presence_updates: OptionFuture<_> = services
.globals
.allow_local_presence()
.then(|| process_presence_updates(&services, since, sender_user))
.then(|| process_presence_updates(services, since, sender_user))
.into();
let account_data = services
@@ -292,7 +319,7 @@ pub(crate) async fn sync_events_route(
.stream()
.broad_filter_map(|user_id| async move {
let no_shared_encrypted_room =
!share_encrypted_room(&services, sender_user, &user_id, None).await;
!share_encrypted_room(services, sender_user, &user_id, None).await;
no_shared_encrypted_room.then_some(user_id)
})
.ready_fold(HashSet::new(), |mut device_list_left, user_id| {
@@ -327,21 +354,6 @@ pub(crate) async fn sync_events_route(
to_device: ToDevice { events: to_device_events },
};
// TODO: Retry the endpoint instead of returning
if !full_state
&& response.rooms.is_empty()
&& response.presence.is_empty()
&& response.account_data.is_empty()
&& response.device_lists.is_empty()
&& response.to_device.is_empty()
{
// Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives
let default = Duration::from_secs(30);
let duration = cmp::min(body.body.timeout.unwrap_or(default), default);
_ = tokio::time::timeout(duration, watcher).await;
}
Ok(response)
}
@@ -1223,7 +1235,7 @@ async fn calculate_counts(
let (joined_member_count, invited_member_count) =
join(joined_member_count, invited_member_count).await;
let small_room = joined_member_count.saturating_add(invited_member_count) > 5;
let small_room = joined_member_count.saturating_add(invited_member_count) <= 5;
let heroes: OptionFuture<_> = small_room
.then(|| calculate_heroes(services, room_id, sender_user))
+33 -58
View File
@@ -23,24 +23,23 @@ use ruma::{
DeviceLists, UnreadNotificationsCount,
},
},
directory::RoomTypeFilter,
events::{
room::member::{MembershipState, RoomMemberEventContent},
AnyRawAccountDataEvent, AnySyncEphemeralRoomEvent, StateEventType,
TimelineEventType::{self, *},
TimelineEventType::*,
},
serde::Raw,
uint, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, UInt,
uint, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UInt,
};
use service::{rooms::read_receipt::pack_receipts, Services};
use service::rooms::read_receipt::pack_receipts;
use super::{load_timeline, share_encrypted_room};
use crate::{client::ignored_filter, Ruma};
use crate::{
client::{filter_rooms, ignored_filter, sync::v5::TodoRooms, DEFAULT_BUMP_TYPES},
Ruma,
};
const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync";
const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] =
&[CallInvite, PollStart, Beacon, RoomEncrypted, RoomMessage, Sticker];
pub(crate) const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync";
/// POST `/_matrix/client/unstable/org.matrix.msc3575/sync`
///
@@ -113,12 +112,17 @@ pub(crate) async fn sync_events_v4_route(
.collect()
.await;
let all_rooms = all_joined_rooms
let all_invited_rooms: Vec<&RoomId> = all_invited_rooms.iter().map(AsRef::as_ref).collect();
let all_rooms: Vec<&RoomId> = all_joined_rooms
.iter()
.chain(all_invited_rooms.iter())
.map(Clone::clone)
.map(AsRef::as_ref)
.chain(all_invited_rooms.iter().map(AsRef::as_ref))
.collect();
let all_joined_rooms = all_joined_rooms.iter().map(AsRef::as_ref).collect();
let all_invited_rooms = all_invited_rooms.iter().map(AsRef::as_ref).collect();
if body.extensions.to_device.enabled.unwrap_or(false) {
services
.users
@@ -171,6 +175,7 @@ pub(crate) async fn sync_events_v4_route(
);
for room_id in &all_joined_rooms {
let room_id: &&RoomId = room_id;
let Ok(current_shortstatehash) =
services.rooms.state.get_room_shortstatehash(room_id).await
else {
@@ -323,7 +328,7 @@ pub(crate) async fn sync_events_v4_route(
}
let mut lists = BTreeMap::new();
let mut todo_rooms = BTreeMap::new(); // and required state
let mut todo_rooms: TodoRooms = BTreeMap::new(); // and required state
for (list_id, list) in &body.lists {
let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) {
@@ -344,7 +349,7 @@ pub(crate) async fn sync_events_v4_route(
| None => active_rooms,
};
let mut new_known_rooms = BTreeSet::new();
let mut new_known_rooms: BTreeSet<OwnedRoomId> = BTreeSet::new();
let ranges = list.ranges.clone();
lists.insert(list_id.clone(), sync_events::v4::SyncList {
@@ -366,9 +371,9 @@ pub(crate) async fn sync_events_v4_route(
Vec::new()
};
new_known_rooms.extend(room_ids.iter().cloned());
new_known_rooms.extend(room_ids.clone().into_iter().map(ToOwned::to_owned));
for room_id in &room_ids {
let todo_room = todo_rooms.entry(room_id.clone()).or_insert((
let todo_room = todo_rooms.entry((*room_id).to_owned()).or_insert((
BTreeSet::new(),
0_usize,
u64::MAX,
@@ -390,7 +395,7 @@ pub(crate) async fn sync_events_v4_route(
todo_room.2 = todo_room.2.min(
known_rooms
.get(list_id.as_str())
.and_then(|k| k.get(room_id))
.and_then(|k| k.get(*room_id))
.copied()
.unwrap_or(0),
);
@@ -399,7 +404,7 @@ pub(crate) async fn sync_events_v4_route(
op: SlidingOp::Sync,
range: Some(r),
index: None,
room_ids,
room_ids: room_ids.into_iter().map(ToOwned::to_owned).collect(),
room_id: None,
}
})
@@ -409,8 +414,8 @@ pub(crate) async fn sync_events_v4_route(
if let Some(conn_id) = &body.conn_id {
services.sync.update_sync_known_rooms(
sender_user.clone(),
sender_device.clone(),
sender_user,
&sender_device,
conn_id.clone(),
list_id.clone(),
new_known_rooms,
@@ -455,8 +460,8 @@ pub(crate) async fn sync_events_v4_route(
if let Some(conn_id) = &body.conn_id {
services.sync.update_sync_known_rooms(
sender_user.clone(),
sender_device.clone(),
sender_user,
&sender_device,
conn_id.clone(),
"subscriptions".to_owned(),
known_subscription_rooms,
@@ -480,7 +485,8 @@ pub(crate) async fn sync_events_v4_route(
let mut timestamp: Option<_> = None;
let mut invite_state = None;
let (timeline_pdus, limited);
if all_invited_rooms.contains(room_id) {
let new_room_id: &RoomId = (*room_id).as_ref();
if all_invited_rooms.contains(&new_room_id) {
// TODO: figure out a timestamp we can use for remote invites
invite_state = services
.rooms
@@ -510,7 +516,7 @@ pub(crate) async fn sync_events_v4_route(
}
account_data.rooms.insert(
room_id.clone(),
room_id.to_owned(),
services
.account_data
.changes_since(Some(room_id), sender_user, *roomsince)
@@ -740,10 +746,9 @@ pub(crate) async fn sync_events_v4_route(
});
}
if rooms
.iter()
.all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty())
{
if rooms.iter().all(|(id, r)| {
r.timeline.is_empty() && r.required_state.is_empty() && !receipts.rooms.contains_key(id)
}) {
// Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives
let default = Duration::from_secs(30);
@@ -789,33 +794,3 @@ pub(crate) async fn sync_events_v4_route(
delta_token: None,
})
}
async fn filter_rooms(
services: &Services,
rooms: &[OwnedRoomId],
filter: &[RoomTypeFilter],
negate: bool,
) -> Vec<OwnedRoomId> {
rooms
.iter()
.stream()
.filter_map(|r| async move {
let room_type = services.rooms.state_accessor.get_room_type(r).await;
if room_type.as_ref().is_err_and(|e| !e.is_not_found()) {
return None;
}
let room_type_filter = RoomTypeFilter::from(room_type.ok());
let include = if negate {
!filter.contains(&room_type_filter)
} else {
filter.is_empty() || filter.contains(&room_type_filter)
};
include.then_some(r.to_owned())
})
.collect()
.await
}
+871
View File
@@ -0,0 +1,871 @@
use std::{
cmp::{self, Ordering},
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
time::Duration,
};
use axum::extract::State;
use conduwuit::{
debug, error, extract_variant, trace,
utils::{
math::{ruma_from_usize, usize_from_ruma},
BoolExt, IterStream, ReadyExt, TryFutureExtExt,
},
warn, Error, Result,
};
use futures::{FutureExt, StreamExt, TryFutureExt};
use ruma::{
api::client::{
error::ErrorKind,
sync::sync_events::{self, DeviceLists, UnreadNotificationsCount},
},
events::{
room::member::{MembershipState, RoomMemberEventContent},
AnyRawAccountDataEvent, AnySyncEphemeralRoomEvent, StateEventType, TimelineEventType,
},
serde::Raw,
state_res::TypeStateKey,
uint, DeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId,
};
use service::{rooms::read_receipt::pack_receipts, PduCount};
use super::{filter_rooms, share_encrypted_room};
use crate::{
client::{ignored_filter, sync::load_timeline, DEFAULT_BUMP_TYPES},
Ruma,
};
type SyncInfo<'a> = (&'a UserId, &'a DeviceId, u64, &'a sync_events::v5::Request);
pub(crate) async fn sync_events_v5_route(
State(services): State<crate::State>,
body: Ruma<sync_events::v5::Request>,
) -> Result<sync_events::v5::Response> {
debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted");
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let mut body = body.body;
// Setup watchers, so if there's no response, we can wait for them
let watcher = services.sync.watch(sender_user, sender_device);
let next_batch = services.globals.next_count()?;
let conn_id = body.conn_id.clone();
let globalsince = body
.pos
.as_ref()
.and_then(|string| string.parse().ok())
.unwrap_or(0);
if globalsince != 0
&& !services.sync.snake_connection_cached(
sender_user.clone(),
sender_device.clone(),
conn_id.clone(),
) {
debug!("Restarting sync stream because it was gone from the database");
return Err(Error::Request(
ErrorKind::UnknownPos,
"Connection data lost since last time".into(),
http::StatusCode::BAD_REQUEST,
));
}
// Client / User requested an initial sync
if globalsince == 0 {
services.sync.forget_snake_sync_connection(
sender_user.clone(),
sender_device.clone(),
conn_id.clone(),
);
}
// Get sticky parameters from cache
let known_rooms = services.sync.update_snake_sync_request_with_cache(
sender_user.clone(),
sender_device.clone(),
&mut body,
);
let all_joined_rooms: Vec<_> = services
.rooms
.state_cache
.rooms_joined(sender_user)
.map(ToOwned::to_owned)
.collect()
.await;
let all_invited_rooms: Vec<_> = services
.rooms
.state_cache
.rooms_invited(sender_user)
.map(|r| r.0)
.collect()
.await;
let all_rooms: Vec<&RoomId> = all_joined_rooms
.iter()
.map(AsRef::as_ref)
.chain(all_invited_rooms.iter().map(AsRef::as_ref))
.collect();
let all_joined_rooms = all_joined_rooms.iter().map(AsRef::as_ref).collect();
let all_invited_rooms = all_invited_rooms.iter().map(AsRef::as_ref).collect();
let pos = next_batch.clone().to_string();
let mut todo_rooms: TodoRooms = BTreeMap::new();
let sync_info: SyncInfo<'_> = (sender_user, sender_device, globalsince, &body);
let mut response = sync_events::v5::Response {
txn_id: body.txn_id.clone(),
pos,
lists: BTreeMap::new(),
rooms: BTreeMap::new(),
extensions: sync_events::v5::response::Extensions {
account_data: collect_account_data(services, sync_info).await,
e2ee: collect_e2ee(services, sync_info, &all_joined_rooms).await?,
to_device: collect_to_device(services, sync_info, next_batch).await,
receipts: collect_receipts(services).await,
typing: sync_events::v5::response::Typing::default(),
},
};
{
let _test2 = handle_lists(
services,
sync_info,
&all_invited_rooms,
&all_joined_rooms,
&all_rooms,
&mut todo_rooms,
&known_rooms,
&mut response,
)
.await;
}
{
fetch_subscriptions(services, sync_info, &known_rooms, &mut todo_rooms).await;
};
response.rooms = process_rooms(
services,
sender_user,
next_batch,
&all_invited_rooms,
&todo_rooms,
&mut response,
&body,
)
.await?;
if response.rooms.iter().all(|(id, r)| {
r.timeline.is_empty()
&& r.required_state.is_empty()
&& !response.extensions.receipts.rooms.contains_key(id)
}) && response
.extensions
.to_device
.clone()
.is_none_or(|to| to.events.is_empty())
{
// Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives
let default = Duration::from_secs(30);
let duration = cmp::min(body.timeout.unwrap_or(default), default);
_ = tokio::time::timeout(duration, watcher).await;
}
trace!(
rooms=?response.rooms.len(),
account_data=?response.extensions.account_data.rooms.len(),
receipts=?response.extensions.receipts.rooms.len(),
"responding to request with"
);
Ok(response)
}
type KnownRooms = BTreeMap<String, BTreeMap<OwnedRoomId, u64>>;
pub(crate) type TodoRooms = BTreeMap<OwnedRoomId, (BTreeSet<TypeStateKey>, usize, u64)>;
async fn fetch_subscriptions(
services: crate::State,
(sender_user, sender_device, globalsince, body): SyncInfo<'_>,
known_rooms: &KnownRooms,
todo_rooms: &mut TodoRooms,
) {
let mut known_subscription_rooms = BTreeSet::new();
for (room_id, room) in &body.room_subscriptions {
if !services.rooms.metadata.exists(room_id).await {
continue;
}
let todo_room =
todo_rooms
.entry(room_id.clone())
.or_insert((BTreeSet::new(), 0_usize, u64::MAX));
let limit: UInt = room.timeline_limit;
todo_room.0.extend(room.required_state.iter().cloned());
todo_room.1 = todo_room.1.max(usize_from_ruma(limit));
// 0 means unknown because it got out of date
todo_room.2 = todo_room.2.min(
known_rooms
.get("subscriptions")
.and_then(|k| k.get(room_id))
.copied()
.unwrap_or(0),
);
known_subscription_rooms.insert(room_id.clone());
}
// where this went (protomsc says it was removed)
//for r in body.unsubscribe_rooms {
// known_subscription_rooms.remove(&r);
// body.room_subscriptions.remove(&r);
//}
if let Some(conn_id) = &body.conn_id {
services.sync.update_snake_sync_known_rooms(
sender_user,
sender_device,
conn_id.clone(),
"subscriptions".to_owned(),
known_subscription_rooms,
globalsince,
);
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_lists<'a>(
services: crate::State,
(sender_user, sender_device, globalsince, body): SyncInfo<'_>,
all_invited_rooms: &Vec<&'a RoomId>,
all_joined_rooms: &Vec<&'a RoomId>,
all_rooms: &Vec<&'a RoomId>,
todo_rooms: &'a mut TodoRooms,
known_rooms: &'a KnownRooms,
response: &'_ mut sync_events::v5::Response,
) -> KnownRooms {
for (list_id, list) in &body.lists {
let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) {
| Some(true) => all_invited_rooms,
| Some(false) => all_joined_rooms,
| None => all_rooms,
};
let active_rooms = match list.filters.clone().map(|f| f.not_room_types) {
| Some(filter) if filter.is_empty() => active_rooms,
| Some(value) => &filter_rooms(&services, active_rooms, &value, true).await,
| None => active_rooms,
};
let mut new_known_rooms: BTreeSet<OwnedRoomId> = BTreeSet::new();
let ranges = list.ranges.clone();
for mut range in ranges {
range.0 = uint!(0);
range.1 = range
.1
.clamp(range.0, UInt::try_from(active_rooms.len()).unwrap_or(UInt::MAX));
let room_ids =
active_rooms[usize_from_ruma(range.0)..usize_from_ruma(range.1)].to_vec();
let new_rooms: BTreeSet<OwnedRoomId> =
room_ids.clone().into_iter().map(From::from).collect();
new_known_rooms.extend(new_rooms);
//new_known_rooms.extend(room_ids..cloned());
for room_id in room_ids {
let todo_room = todo_rooms.entry(room_id.to_owned()).or_insert((
BTreeSet::new(),
0_usize,
u64::MAX,
));
let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100);
todo_room
.0
.extend(list.room_details.required_state.iter().cloned());
todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date
todo_room.2 = todo_room.2.min(
known_rooms
.get(list_id.as_str())
.and_then(|k| k.get(room_id))
.copied()
.unwrap_or(0),
);
}
}
response
.lists
.insert(list_id.clone(), sync_events::v5::response::List {
count: ruma_from_usize(active_rooms.len()),
});
if let Some(conn_id) = &body.conn_id {
services.sync.update_snake_sync_known_rooms(
sender_user,
sender_device,
conn_id.clone(),
list_id.clone(),
new_known_rooms,
globalsince,
);
}
}
BTreeMap::default()
}
async fn process_rooms(
services: crate::State,
sender_user: &UserId,
next_batch: u64,
all_invited_rooms: &[&RoomId],
todo_rooms: &TodoRooms,
response: &mut sync_events::v5::Response,
body: &sync_events::v5::Request,
) -> Result<BTreeMap<OwnedRoomId, sync_events::v5::response::Room>> {
let mut rooms = BTreeMap::new();
for (room_id, (required_state_request, timeline_limit, roomsince)) in todo_rooms {
let roomsincecount = PduCount::Normal(*roomsince);
let mut timestamp: Option<_> = None;
let mut invite_state = None;
let (timeline_pdus, limited);
let new_room_id: &RoomId = (*room_id).as_ref();
if all_invited_rooms.contains(&new_room_id) {
// TODO: figure out a timestamp we can use for remote invites
invite_state = services
.rooms
.state_cache
.invite_state(sender_user, room_id)
.await
.ok();
(timeline_pdus, limited) = (Vec::new(), true);
} else {
(timeline_pdus, limited) = match load_timeline(
&services,
sender_user,
room_id,
roomsincecount,
Some(PduCount::from(next_batch)),
*timeline_limit,
)
.await
{
| Ok(value) => value,
| Err(err) => {
warn!("Encountered missing timeline in {}, error {}", room_id, err);
continue;
},
};
}
if body.extensions.to_device.enabled == Some(true) {
response.extensions.account_data.rooms.insert(
room_id.to_owned(),
services
.account_data
.changes_since(Some(room_id), sender_user, *roomsince)
.ready_filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room))
.collect()
.await,
);
}
let last_privateread_update = services
.rooms
.read_receipt
.last_privateread_update(sender_user, room_id)
.await > *roomsince;
let private_read_event = if last_privateread_update {
services
.rooms
.read_receipt
.private_read_get(room_id, sender_user)
.await
.ok()
} else {
None
};
let mut receipts: Vec<Raw<AnySyncEphemeralRoomEvent>> = services
.rooms
.read_receipt
.readreceipts_since(room_id, *roomsince)
.filter_map(|(read_user, _ts, v)| async move {
services
.users
.user_is_ignored(read_user, sender_user)
.await
.or_some(v)
})
.collect()
.await;
if let Some(private_read_event) = private_read_event {
receipts.push(private_read_event);
}
let receipt_size = receipts.len();
if receipt_size > 0 {
response
.extensions
.receipts
.rooms
.insert(room_id.clone(), pack_receipts(Box::new(receipts.into_iter())));
}
if roomsince != &0
&& timeline_pdus.is_empty()
&& response
.extensions
.account_data
.rooms
.get(room_id)
.is_none_or(Vec::is_empty)
&& receipt_size == 0
{
continue;
}
let prev_batch = timeline_pdus
.first()
.map_or(Ok::<_, Error>(None), |(pdu_count, _)| {
Ok(Some(match pdu_count {
| PduCount::Backfilled(_) => {
error!("timeline in backfill state?!");
"0".to_owned()
},
| PduCount::Normal(c) => c.to_string(),
}))
})?
.or_else(|| {
if roomsince != &0 {
Some(roomsince.to_string())
} else {
None
}
});
let room_events: Vec<_> = timeline_pdus
.iter()
.stream()
.filter_map(|item| ignored_filter(&services, item.clone(), sender_user))
.map(|(_, pdu)| pdu.to_sync_room_event())
.collect()
.await;
for (_, pdu) in timeline_pdus {
let ts = pdu.origin_server_ts;
if DEFAULT_BUMP_TYPES.binary_search(&pdu.kind).is_ok()
&& timestamp.is_none_or(|time| time <= ts)
{
timestamp = Some(ts);
}
}
let required_state = required_state_request
.iter()
.stream()
.filter_map(|state| async move {
services
.rooms
.state_accessor
.room_state_get(room_id, &state.0, &state.1)
.await
.map(|s| s.to_sync_state_event())
.ok()
})
.collect()
.await;
// Heroes
let heroes: Vec<_> = services
.rooms
.state_cache
.room_members(room_id)
.ready_filter(|member| *member != sender_user)
.filter_map(|user_id| {
services
.rooms
.state_accessor
.get_member(room_id, user_id)
.map_ok(|memberevent| sync_events::v5::response::Hero {
user_id: user_id.into(),
name: memberevent.displayname,
avatar: memberevent.avatar_url,
})
.ok()
})
.take(5)
.collect()
.await;
let name = match heroes.len().cmp(&(1_usize)) {
| Ordering::Greater => {
let firsts = heroes[1..]
.iter()
.map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string()))
.collect::<Vec<_>>()
.join(", ");
let last = heroes[0]
.name
.clone()
.unwrap_or_else(|| heroes[0].user_id.to_string());
Some(format!("{firsts} and {last}"))
},
| Ordering::Equal => Some(
heroes[0]
.name
.clone()
.unwrap_or_else(|| heroes[0].user_id.to_string()),
),
| Ordering::Less => None,
};
let heroes_avatar = if heroes.len() == 1 {
heroes[0].avatar.clone()
} else {
None
};
rooms.insert(room_id.clone(), sync_events::v5::response::Room {
name: services
.rooms
.state_accessor
.get_name(room_id)
.await
.ok()
.or(name),
avatar: if let Some(heroes_avatar) = heroes_avatar {
ruma::JsOption::Some(heroes_avatar)
} else {
match services.rooms.state_accessor.get_avatar(room_id).await {
| ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url),
| ruma::JsOption::Null => ruma::JsOption::Null,
| ruma::JsOption::Undefined => ruma::JsOption::Undefined,
}
},
initial: Some(roomsince == &0),
is_dm: None,
invite_state,
unread_notifications: UnreadNotificationsCount {
highlight_count: Some(
services
.rooms
.user
.highlight_count(sender_user, room_id)
.await
.try_into()
.expect("notification count can't go that high"),
),
notification_count: Some(
services
.rooms
.user
.notification_count(sender_user, room_id)
.await
.try_into()
.expect("notification count can't go that high"),
),
},
timeline: room_events,
required_state,
prev_batch,
limited,
joined_count: Some(
services
.rooms
.state_cache
.room_joined_count(room_id)
.await
.unwrap_or(0)
.try_into()
.unwrap_or_else(|_| uint!(0)),
),
invited_count: Some(
services
.rooms
.state_cache
.room_invited_count(room_id)
.await
.unwrap_or(0)
.try_into()
.unwrap_or_else(|_| uint!(0)),
),
num_live: None, // Count events in timeline greater than global sync counter
bump_stamp: timestamp,
heroes: Some(heroes),
});
}
Ok(rooms)
}
async fn collect_account_data(
services: crate::State,
(sender_user, _, globalsince, body): (&UserId, &DeviceId, u64, &sync_events::v5::Request),
) -> sync_events::v5::response::AccountData {
let mut account_data = sync_events::v5::response::AccountData {
global: Vec::new(),
rooms: BTreeMap::new(),
};
if !body.extensions.account_data.enabled.unwrap_or(false) {
return sync_events::v5::response::AccountData::default();
}
account_data.global = services
.account_data
.changes_since(None, sender_user, globalsince)
.ready_filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global))
.collect()
.await;
if let Some(rooms) = &body.extensions.account_data.rooms {
for room in rooms {
account_data.rooms.insert(
room.clone(),
services
.account_data
.changes_since(Some(room), sender_user, globalsince)
.ready_filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room))
.collect()
.await,
);
}
}
account_data
}
async fn collect_e2ee<'a>(
services: crate::State,
(sender_user, sender_device, globalsince, body): (
&UserId,
&DeviceId,
u64,
&sync_events::v5::Request,
),
all_joined_rooms: &'a Vec<&'a RoomId>,
) -> Result<sync_events::v5::response::E2EE> {
if !body.extensions.e2ee.enabled.unwrap_or(false) {
return Ok(sync_events::v5::response::E2EE::default());
}
let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in
let mut device_list_changes = HashSet::new();
let mut device_list_left = HashSet::new();
// Look for device list updates of this account
device_list_changes.extend(
services
.users
.keys_changed(sender_user, globalsince, None)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
for room_id in all_joined_rooms {
let Ok(current_shortstatehash) =
services.rooms.state.get_room_shortstatehash(room_id).await
else {
error!("Room {room_id} has no state");
continue;
};
let since_shortstatehash = services
.rooms
.user
.get_token_shortstatehash(room_id, globalsince)
.await
.ok();
let encrypted_room = services
.rooms
.state_accessor
.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")
.await
.is_ok();
if let Some(since_shortstatehash) = since_shortstatehash {
// Skip if there are only timeline changes
if since_shortstatehash == current_shortstatehash {
continue;
}
let since_encryption = services
.rooms
.state_accessor
.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")
.await;
let since_sender_member: Option<RoomMemberEventContent> = services
.rooms
.state_accessor
.state_get_content(
since_shortstatehash,
&StateEventType::RoomMember,
sender_user.as_str(),
)
.ok()
.await;
let joined_since_last_sync = since_sender_member
.as_ref()
.is_none_or(|member| member.membership != MembershipState::Join);
let new_encrypted_room = encrypted_room && since_encryption.is_err();
if encrypted_room {
let current_state_ids: HashMap<_, OwnedEventId> = services
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let since_state_ids = services
.rooms
.state_accessor
.state_full_ids(since_shortstatehash)
.await?;
for (key, id) in current_state_ids {
if since_state_ids.get(&key) != Some(&id) {
let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else {
error!("Pdu in state not found: {id}");
continue;
};
if pdu.kind == TimelineEventType::RoomMember {
if let Some(state_key) = &pdu.state_key {
let user_id =
OwnedUserId::parse(state_key.clone()).map_err(|_| {
Error::bad_database("Invalid UserId in member PDU.")
})?;
if user_id == *sender_user {
continue;
}
let content: RoomMemberEventContent = pdu.get_content()?;
match content.membership {
| MembershipState::Join => {
// A new user joined an encrypted room
if !share_encrypted_room(
&services,
sender_user,
&user_id,
Some(room_id),
)
.await
{
device_list_changes.insert(user_id);
}
},
| MembershipState::Leave => {
// Write down users that have left encrypted rooms we
// are in
left_encrypted_users.insert(user_id);
},
| _ => {},
}
}
}
}
}
if joined_since_last_sync || new_encrypted_room {
// If the user is in a new encrypted room, give them all joined users
device_list_changes.extend(
services
.rooms
.state_cache
.room_members(room_id)
// Don't send key updates from the sender to the sender
.ready_filter(|user_id| sender_user != *user_id)
// Only send keys if the sender doesn't share an encrypted room with the target
// already
.filter_map(|user_id| {
share_encrypted_room(&services, sender_user, user_id, Some(room_id))
.map(|res| res.or_some(user_id.to_owned()))
})
.collect::<Vec<_>>()
.await,
);
}
}
}
// Look for device list updates in this room
device_list_changes.extend(
services
.users
.room_keys_changed(room_id, globalsince, None)
.map(|(user_id, _)| user_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
}
for user_id in left_encrypted_users {
let dont_share_encrypted_room =
!share_encrypted_room(&services, sender_user, &user_id, None).await;
// If the user doesn't share an encrypted room with the target anymore, we need
// to tell them
if dont_share_encrypted_room {
device_list_left.insert(user_id);
}
}
Ok(sync_events::v5::response::E2EE {
device_lists: DeviceLists {
changed: device_list_changes.into_iter().collect(),
left: device_list_left.into_iter().collect(),
},
device_one_time_keys_count: services
.users
.count_one_time_keys(sender_user, sender_device)
.await,
device_unused_fallback_key_types: None,
})
}
async fn collect_to_device(
services: crate::State,
(sender_user, sender_device, globalsince, body): SyncInfo<'_>,
next_batch: u64,
) -> Option<sync_events::v5::response::ToDevice> {
if !body.extensions.to_device.enabled.unwrap_or(false) {
return None;
}
services
.users
.remove_to_device_events(sender_user, sender_device, globalsince)
.await;
Some(sync_events::v5::response::ToDevice {
next_batch: next_batch.to_string(),
events: services
.users
.get_to_device_events(sender_user, sender_device)
.collect()
.await,
})
}
async fn collect_receipts(_services: crate::State) -> sync_events::v5::response::Receipts {
sync_events::v5::response::Receipts { rooms: BTreeMap::new() }
// TODO: get explicitly requested read receipts
}
+1
View File
@@ -52,6 +52,7 @@ pub(crate) async fn get_supported_versions_route(
("org.matrix.msc4180".to_owned(), true), /* stable flag for 3916 (https://github.com/matrix-org/matrix-spec-proposals/pull/4180) */
("uk.tcpip.msc4133".to_owned(), true), /* Extending User Profile API with Key:Value Pairs (https://github.com/matrix-org/matrix-spec-proposals/pull/4133) */
("us.cloke.msc4175".to_owned(), true), /* Profile field for user time zone (https://github.com/matrix-org/matrix-spec-proposals/pull/4175) */
("org.matrix.simplified_msc3575".to_owned(), true), /* Simplified Sliding sync (https://github.com/matrix-org/matrix-spec-proposals/pull/4186) */
]),
};
+1
View File
@@ -144,6 +144,7 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
)
.ruma_route(&client::sync_events_route)
.ruma_route(&client::sync_events_v4_route)
.ruma_route(&client::sync_events_v5_route)
.ruma_route(&client::get_context_route)
.ruma_route(&client::get_message_events_route)
.ruma_route(&client::search_events_route)
+17 -13
View File
@@ -2,10 +2,10 @@ use std::cmp;
use axum::extract::State;
use conduwuit::{
utils::{IterStream, ReadyExt},
utils::{stream::TryTools, IterStream, ReadyExt},
PduCount, Result,
};
use futures::{FutureExt, StreamExt};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{api::federation::backfill::get_backfill, uint, MilliSecondsSinceUnixEpoch};
use super::AccessCheck;
@@ -57,26 +57,30 @@ pub(crate) async fn get_backfill_route(
.rooms
.timeline
.pdus_rev(None, &body.room_id, Some(from.saturating_add(1)))
.await?
.take(limit)
.filter_map(|(_, pdu)| async move {
services
.try_take(limit)
.try_filter_map(|(_, pdu)| async move {
Ok(services
.rooms
.state_accessor
.server_can_see_event(body.origin(), &pdu.room_id, &pdu.event_id)
.await
.then_some(pdu)
.then_some(pdu))
})
.filter_map(|pdu| async move {
services
.try_filter_map(|pdu| async move {
Ok(services
.rooms
.timeline
.get_pdu_json(&pdu.event_id)
.await
.ok()
.ok())
})
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect()
.await,
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?,
})
}
+1
View File
@@ -135,6 +135,7 @@ async fn handle_pdus(
.rooms
.event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, true)
.boxed()
.await
.map(|_| ());
+4 -1
View File
@@ -50,6 +50,9 @@ zstd_compression = [
]
perf_measurements = []
sentry_telemetry = []
conduwuit_mods = [
"dep:libloading"
]
[dependencies]
argon2.workspace = true
@@ -71,11 +74,11 @@ figment.workspace = true
futures.workspace = true
http-body-util.workspace = true
http.workspace = true
image.workspace = true
ipaddress.workspace = true
itertools.workspace = true
libc.workspace = true
libloading.workspace = true
libloading.optional = true
log.workspace = true
num-traits.workspace = true
rand.workspace = true
+36 -17
View File
@@ -147,22 +147,6 @@ pub struct Config {
#[serde(default = "default_database_backups_to_keep")]
pub database_backups_to_keep: i16,
/// Set this to any float value in megabytes for conduwuit to tell the
/// database engine that this much memory is available for database-related
/// caches.
///
/// May be useful if you have significant memory to spare to increase
/// performance.
///
/// Similar to the individual LRU caches, this is scaled up with your CPU
/// core count.
///
/// This defaults to 128.0 + (64.0 * CPU core count).
///
/// default: varies by system
#[serde(default = "default_db_cache_capacity_mb")]
pub db_cache_capacity_mb: f64,
/// Text which will be added to the end of the user's displayname upon
/// registration with a space before the text. In Conduit, this was the
/// lightning bolt emoji.
@@ -205,6 +189,38 @@ pub struct Config {
)]
pub cache_capacity_modifier: f64,
/// Set this to any float value in megabytes for conduwuit to tell the
/// database engine that this much memory is available for database read
/// caches.
///
/// May be useful if you have significant memory to spare to increase
/// performance.
///
/// Similar to the individual LRU caches, this is scaled up with your CPU
/// core count.
///
/// This defaults to 128.0 + (64.0 * CPU core count).
///
/// default: varies by system
#[serde(default = "default_db_cache_capacity_mb")]
pub db_cache_capacity_mb: f64,
/// Set this to any float value in megabytes for conduwuit to tell the
/// database engine that this much memory is available for database write
/// caches.
///
/// May be useful if you have significant memory to spare to increase
/// performance.
///
/// Similar to the individual LRU caches, this is scaled up with your CPU
/// core count.
///
/// This defaults to 48.0 + (4.0 * CPU core count).
///
/// default: varies by system
#[serde(default = "default_db_write_buffer_capacity_mb")]
pub db_write_buffer_capacity_mb: f64,
/// default: varies by system
#[serde(default = "default_pdu_cache_capacity")]
pub pdu_cache_capacity: u32,
@@ -621,6 +637,7 @@ pub struct Config {
#[serde(default = "default_tracing_flame_output_path")]
pub tracing_flame_output_path: String,
#[cfg(not(doctest))]
/// Examples:
///
/// - No proxy (default):
@@ -1646,7 +1663,7 @@ pub struct Config {
/// responsiveness for many users at the cost of throughput for each.
///
/// Setting this value to 0.0 causes the stream width to be fixed at the
/// value of stream_width_default. The default is 1.0 to match the
/// value of stream_width_default. The default scale is 1.0 to match the
/// capabilities detected for the system.
///
/// default: 1.0
@@ -2232,6 +2249,8 @@ fn default_unix_socket_perms() -> u32 { 660 }
fn default_database_backups_to_keep() -> i16 { 1 }
fn default_db_write_buffer_capacity_mb() -> f64 { 48.0 + parallelism_scaled_f64(4.0) }
fn default_db_cache_capacity_mb() -> f64 { 128.0 + parallelism_scaled_f64(64.0) }
fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(10_000).saturating_add(100_000) }
-2
View File
@@ -48,8 +48,6 @@ pub enum Error {
Http(#[from] http::Error),
#[error(transparent)]
HttpHeader(#[from] http::header::InvalidHeaderValue),
#[error("Image error: {0}")]
Image(#[from] image::error::ImageError),
#[error("Join error: {0}")]
JoinError(#[from] tokio::task::JoinError),
#[error(transparent)]
+1 -1
View File
@@ -25,7 +25,7 @@ pub use crate as conduwuit_core;
rustc_flags_capture! {}
#[cfg(not(conduwuit_mods))]
#[cfg(any(not(conduwuit_mods), not(feature = "conduwuit_mods")))]
pub mod mods {
#[macro_export]
macro_rules! mod_ctor {
+1 -1
View File
@@ -1,4 +1,4 @@
#![cfg(conduwuit_mods)]
#![cfg(all(conduwuit_mods, feature = "conduwuit_mods"))]
pub(crate) use libloading::os::unix::{Library, Symbol};
+1
View File
@@ -9,6 +9,7 @@ mod raw_id;
mod redact;
mod relation;
mod strip;
#[cfg(test)]
mod tests;
mod unsigned;
-2
View File
@@ -1,5 +1,3 @@
#![cfg(test)]
use super::Count;
#[test]
+1 -1
View File
@@ -59,7 +59,7 @@ impl Server {
}
pub fn reload(&self) -> Result<()> {
if cfg!(not(conduwuit_mods)) {
if cfg!(any(not(conduwuit_mods), not(feature = "conduwuit_mods"))) {
return Err!("Reloading not enabled");
}
+34
View File
@@ -0,0 +1,34 @@
//! Extended external extensions to futures::FutureExt
use std::marker::Unpin;
use futures::{future, future::Select, Future};
/// This interface is not necessarily complete; feel free to add as-needed.
pub trait ExtExt<T>
where
Self: Future<Output = T> + Send,
{
fn until<A, B, F>(self, f: F) -> Select<A, B>
where
Self: Sized,
F: FnOnce() -> B,
A: Future<Output = T> + From<Self> + Send + Unpin,
B: Future<Output = ()> + Send + Unpin;
}
impl<T, Fut> ExtExt<T> for Fut
where
Fut: Future<Output = T> + Send,
{
#[inline]
fn until<A, B, F>(self, f: F) -> Select<A, B>
where
Self: Sized,
F: FnOnce() -> B,
A: Future<Output = T> + From<Self> + Send + Unpin,
B: Future<Output = ()> + Send + Unpin,
{
future::select(self.into(), f())
}
}
+2
View File
@@ -1,5 +1,7 @@
mod ext_ext;
mod option_ext;
mod try_ext_ext;
pub use ext_ext::ExtExt;
pub use option_ext::OptionExt;
pub use try_ext_ext::TryExtExt;
+22 -1
View File
@@ -4,8 +4,11 @@
// caller only ever caring about result status while discarding all contents.
#![allow(clippy::wrong_self_convention)]
use std::marker::Unpin;
use futures::{
future::{MapOkOrElse, UnwrapOrElse},
future,
future::{MapOkOrElse, TrySelect, UnwrapOrElse},
TryFuture, TryFutureExt,
};
@@ -46,6 +49,13 @@ where
where
Self: Sized;
fn try_until<A, B, F>(self, f: F) -> TrySelect<A, B>
where
Self: Sized,
F: FnOnce() -> B,
A: TryFuture<Ok = Self::Ok> + From<Self> + Send + Unpin,
B: TryFuture<Ok = (), Error = Self::Error> + Send + Unpin;
fn unwrap_or(
self,
default: Self::Ok,
@@ -110,6 +120,17 @@ where
self.map_ok_or(None, Some)
}
#[inline]
fn try_until<A, B, F>(self, f: F) -> TrySelect<A, B>
where
Self: Sized,
F: FnOnce() -> B,
A: TryFuture<Ok = Self::Ok> + From<Self> + Send + Unpin,
B: TryFuture<Ok = (), Error = Self::Error> + Send + Unpin,
{
future::try_select(self.into(), f())
}
#[inline]
fn unwrap_or(
self,
+1
View File
@@ -16,6 +16,7 @@ pub mod set;
pub mod stream;
pub mod string;
pub mod sys;
#[cfg(test)]
mod tests;
pub mod time;
+7 -6
View File
@@ -19,8 +19,8 @@ type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
impl<Key, Val> MutexMap<Key, Val>
where
Key: Send + Hash + Eq + Clone,
Val: Send + Default,
Key: Clone + Eq + Hash + Send,
Val: Default + Send,
{
#[must_use]
pub fn new() -> Self {
@@ -29,10 +29,10 @@ where
}
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(level = "trace", skip(self))]
pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val>
where
K: ?Sized + Send + Sync + Debug,
K: Debug + Send + ?Sized + Sync,
Key: for<'a> From<&'a K>,
{
let val = self
@@ -61,13 +61,14 @@ where
impl<Key, Val> Default for MutexMap<Key, Val>
where
Key: Send + Hash + Eq + Clone,
Val: Send + Default,
Key: Clone + Eq + Hash + Send,
Val: Default + Send,
{
fn default() -> Self { Self::new() }
}
impl<Key, Val> Drop for Guard<Key, Val> {
#[tracing::instrument(name = "unlock", level = "trace", skip_all)]
fn drop(&mut self) {
if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
self.map.lock().expect("locked").retain(|_, val| {
+2
View File
@@ -8,6 +8,7 @@ mod ready;
mod tools;
mod try_broadband;
mod try_ready;
mod try_tools;
mod wideband;
pub use band::{
@@ -23,4 +24,5 @@ pub use ready::ReadyExt;
pub use tools::Tools;
pub use try_broadband::TryBroadbandExt;
pub use try_ready::TryReadyExt;
pub use try_tools::TryTools;
pub use wideband::WidebandExt;
+19 -1
View File
@@ -3,7 +3,7 @@
use futures::{
future::{ready, Ready},
stream::{AndThen, TryFilterMap, TryFold, TryForEach, TryStream, TryStreamExt},
stream::{AndThen, TryFilterMap, TryFold, TryForEach, TryStream, TryStreamExt, TryTakeWhile},
};
use crate::Result;
@@ -56,6 +56,13 @@ where
) -> TryForEach<Self, Ready<Result<(), E>>, impl FnMut(S::Ok) -> Ready<Result<(), E>>>
where
F: FnMut(S::Ok) -> Result<(), E>;
fn ready_try_take_while<F>(
self,
f: F,
) -> TryTakeWhile<Self, Ready<Result<bool, E>>, impl FnMut(&S::Ok) -> Ready<Result<bool, E>>>
where
F: Fn(&S::Ok) -> Result<bool, E>;
}
impl<T, E, S> TryReadyExt<T, E, S> for S
@@ -122,4 +129,15 @@ where
{
self.try_for_each(move |t| ready(f(t)))
}
#[inline]
fn ready_try_take_while<F>(
self,
f: F,
) -> TryTakeWhile<Self, Ready<Result<bool, E>>, impl FnMut(&S::Ok) -> Ready<Result<bool, E>>>
where
F: Fn(&S::Ok) -> Result<bool, E>,
{
self.try_take_while(move |t| ready(f(t)))
}
}
+44
View File
@@ -0,0 +1,44 @@
//! TryStreamTools for futures::TryStream
#![allow(clippy::type_complexity)]
use futures::{future, future::Ready, stream::TryTakeWhile, TryStream, TryStreamExt};
use crate::Result;
/// TryStreamTools
pub trait TryTools<T, E, S>
where
S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
Self: TryStream + Send + Sized,
{
fn try_take(
self,
n: usize,
) -> TryTakeWhile<
Self,
Ready<Result<bool, S::Error>>,
impl FnMut(&S::Ok) -> Ready<Result<bool, S::Error>>,
>;
}
impl<T, E, S> TryTools<T, E, S> for S
where
S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
Self: TryStream + Send + Sized,
{
#[inline]
fn try_take(
self,
mut n: usize,
) -> TryTakeWhile<
Self,
Ready<Result<bool, S::Error>>,
impl FnMut(&S::Ok) -> Ready<Result<bool, S::Error>>,
> {
self.try_take_while(move |_| {
let res = future::ok(n > 0);
n = n.saturating_sub(1);
res
})
}
}
-1
View File
@@ -1,4 +1,3 @@
#![cfg(test)]
#![allow(clippy::disallowed_methods)]
use crate::utils;
-55
View File
@@ -1,55 +0,0 @@
use std::{ops::Index, sync::Arc};
use conduwuit::{err, Result, Server};
use crate::{
maps,
maps::{Maps, MapsKey, MapsVal},
Engine, Map,
};
pub struct Database {
pub db: Arc<Engine>,
maps: Maps,
}
impl Database {
/// Load an existing database or create a new one.
pub async fn open(server: &Arc<Server>) -> Result<Arc<Self>> {
let db = Engine::open(server).await?;
Ok(Arc::new(Self { db: db.clone(), maps: maps::open(&db)? }))
}
#[inline]
pub fn get(&self, name: &str) -> Result<&Arc<Map>> {
self.maps
.get(name)
.ok_or_else(|| err!(Request(NotFound("column not found"))))
}
#[inline]
pub fn iter(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + Send + '_ {
self.maps.iter()
}
#[inline]
pub fn keys(&self) -> impl Iterator<Item = &MapsKey> + Send + '_ { self.maps.keys() }
#[inline]
#[must_use]
pub fn is_read_only(&self) -> bool { self.db.is_read_only() }
#[inline]
#[must_use]
pub fn is_secondary(&self) -> bool { self.db.is_secondary() }
}
impl Index<&str> for Database {
type Output = Arc<Map>;
fn index(&self, name: &str) -> &Self::Output {
self.maps
.get(name)
.expect("column in database does not exist")
}
}
+55
View File
@@ -9,6 +9,15 @@ use serde::{
use crate::util::unhandled;
/// Deserialize into T from buffer.
#[cfg_attr(
unabridged,
tracing::instrument(
name = "deserialize",
level = "trace",
skip_all,
fields(len = %buf.len()),
)
)]
pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result<T>
where
T: Deserialize<'a>,
@@ -132,6 +141,17 @@ impl<'de> Deserializer<'de> {
/// Increment the position pointer.
#[inline]
#[cfg_attr(
unabridged,
tracing::instrument(
level = "trace",
skip(self),
fields(
len = self.buf.len(),
rem = self.remaining().unwrap_or_default().saturating_sub(n),
),
)
)]
fn inc_pos(&mut self, n: usize) {
self.pos = self.pos.saturating_add(n);
debug_assert!(self.pos <= self.buf.len(), "pos out of range");
@@ -149,6 +169,7 @@ impl<'de> Deserializer<'de> {
impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error;
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
@@ -157,6 +178,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_seq(self)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, visitor)))]
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
@@ -165,6 +187,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_seq(self)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, visitor)))]
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
@@ -178,6 +201,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_seq(self)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
@@ -187,6 +211,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
d.deserialize_map(visitor).map_err(Into::into)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, visitor)))]
fn deserialize_struct<V>(
self,
name: &'static str,
@@ -202,6 +227,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
.map_err(Into::into)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, visitor)))]
fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
@@ -215,6 +241,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_unit()
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, visitor)))]
fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
@@ -225,6 +252,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
}
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, _visitor)))]
fn deserialize_enum<V>(
self,
_name: &'static str,
@@ -237,26 +265,32 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
unhandled!("deserialize Enum not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_option<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize Option not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_bool<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize bool not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_i8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize i8 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_i16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize i16 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_i32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize i32 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
const BYTES: usize = size_of::<i64>();
@@ -268,6 +302,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_i64(i64::from_be_bytes(bytes))
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_u8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!(
"deserialize u8 not implemented; try dereferencing the Handle for [u8] access \
@@ -275,14 +310,17 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_u16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize u16 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_u32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize u32 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
const BYTES: usize = size_of::<u64>();
@@ -294,53 +332,67 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_u64(u64::from_be_bytes(bytes))
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize f32 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize f64 not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_char<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize char not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_next();
let out = deserialize_str(input)?;
visitor.visit_borrowed_str(out)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_next();
let out = string::string_from_bytes(input)?;
visitor.visit_string(out)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_trail();
visitor.visit_borrowed_bytes(input)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_byte_buf<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize Byte Buf not implemented")
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize Unit not implemented")
}
// this only used for $serde_json::private::RawValue at this time; see MapAccess
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = "$serde_json::private::RawValue";
visitor.visit_borrowed_str(input)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
fn deserialize_ignored_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unhandled!("deserialize Ignored Any not implemented")
}
#[cfg_attr(
unabridged,
tracing::instrument(level = "trace", skip_all, fields(?self.buf))
)]
fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
debug_assert_eq!(
conduwuit::debug::type_name::<V>(),
@@ -363,6 +415,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> {
type Error = Error;
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, seed)))]
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: DeserializeSeed<'de>,
@@ -381,6 +434,7 @@ impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> {
impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> {
type Error = Error;
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, seed)))]
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
where
K: DeserializeSeed<'de>,
@@ -388,6 +442,7 @@ impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> {
seed.deserialize(&mut **self).map(Some)
}
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip(self, seed)))]
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: DeserializeSeed<'de>,
+34 -296
View File
@@ -1,284 +1,64 @@
mod backup;
mod cf_opts;
pub(crate) mod context;
mod db_opts;
pub(crate) mod descriptor;
mod files;
mod logger;
mod memory_usage;
mod open;
mod repair;
use std::{
collections::{BTreeSet, HashMap},
ffi::CStr,
fmt::Write,
path::PathBuf,
sync::{atomic::AtomicU32, Arc, Mutex, RwLock},
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};
use conduwuit::{
debug, error, info, utils::time::rfc2822_from_seconds, warn, Err, Result, Server,
};
use rocksdb::{
backup::{BackupEngine, BackupEngineOptions},
perf::get_memory_usage_stats,
AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon,
DBWithThreadMode, Env, LogLevel, MultiThreaded, Options,
};
use conduwuit::{debug, info, warn, Err, Result};
use rocksdb::{AsColumnFamilyRef, BoundColumnFamily, DBCommon, DBWithThreadMode, MultiThreaded};
use crate::{
opts::{cf_options, db_options},
or_else,
pool::Pool,
result,
util::map_err,
};
use crate::{pool::Pool, result, Context};
pub struct Engine {
pub(crate) server: Arc<Server>,
row_cache: Cache,
col_cache: RwLock<HashMap<String, Cache>>,
opts: Options,
env: Env,
cfs: Mutex<BTreeSet<String>>,
pub(crate) db: Db,
corks: AtomicU32,
pub(super) read_only: bool,
pub(super) secondary: bool,
corks: AtomicU32,
pub(crate) db: Db,
pub(crate) pool: Arc<Pool>,
pub(crate) ctx: Arc<Context>,
}
pub(crate) type Db = DBWithThreadMode<MultiThreaded>;
impl Engine {
#[tracing::instrument(skip_all)]
pub(crate) async fn open(server: &Arc<Server>) -> Result<Arc<Self>> {
let config = &server.config;
let cache_capacity_bytes = config.db_cache_capacity_mb * 1024.0 * 1024.0;
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
let row_cache_capacity_bytes = (cache_capacity_bytes * 0.50) as usize;
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
let col_cache_capacity_bytes = (cache_capacity_bytes * 0.50) as usize;
let mut col_cache = HashMap::new();
col_cache.insert("primary".to_owned(), Cache::new_lru_cache(col_cache_capacity_bytes));
let mut db_env = Env::new().or_else(or_else)?;
let row_cache = Cache::new_lru_cache(row_cache_capacity_bytes);
let db_opts = db_options(
config,
&mut db_env,
&row_cache,
col_cache.get("primary").expect("primary cache exists"),
)?;
let load_time = std::time::Instant::now();
if config.rocksdb_repair {
repair(&db_opts, &config.database_path)?;
}
debug!("Listing column families in database");
let cfs = Db::list_cf(&db_opts, &config.database_path)
.unwrap_or_default()
.into_iter()
.collect::<BTreeSet<_>>();
debug!("Opening {} column family descriptors in database", cfs.len());
let cfopts = cfs
.iter()
.map(|name| cf_options(config, name, db_opts.clone(), &mut col_cache))
.collect::<Result<Vec<_>>>()?;
let cfds = cfs
.iter()
.zip(cfopts.into_iter())
.map(|(name, opts)| ColumnFamilyDescriptor::new(name, opts))
.collect::<Vec<_>>();
let path = &config.database_path;
debug!("Opening database...");
let res = if config.rocksdb_read_only {
Db::open_cf_descriptors_read_only(&db_opts, path, cfds, false)
} else if config.rocksdb_secondary {
Db::open_cf_descriptors_as_secondary(&db_opts, path, path, cfds)
} else {
Db::open_cf_descriptors(&db_opts, path, cfds)
};
let db = res.or_else(or_else)?;
info!(
columns = cfs.len(),
sequence = %db.latest_sequence_number(),
time = ?load_time.elapsed(),
"Opened database."
);
Ok(Arc::new(Self {
server: server.clone(),
row_cache,
col_cache: RwLock::new(col_cache),
opts: db_opts,
env: db_env,
cfs: Mutex::new(cfs),
db,
corks: AtomicU32::new(0),
read_only: config.rocksdb_read_only,
secondary: config.rocksdb_secondary,
pool: Pool::new(server).await?,
}))
}
#[tracing::instrument(skip(self), level = "trace")]
pub(crate) fn open_cf(&self, name: &str) -> Result<Arc<BoundColumnFamily<'_>>> {
let mut cfs = self.cfs.lock().expect("locked");
if !cfs.contains(name) {
debug!("Creating new column family in database: {name}");
let mut col_cache = self.col_cache.write().expect("locked");
let opts = cf_options(&self.server.config, name, self.opts.clone(), &mut col_cache)?;
if let Err(e) = self.db.create_cf(name, &opts) {
error!(?name, "Failed to create new column family: {e}");
return or_else(e);
}
cfs.insert(name.to_owned());
}
Ok(self.cf(name))
}
pub(crate) fn cf(&self, name: &str) -> Arc<BoundColumnFamily<'_>> {
self.db
.cf_handle(name)
.expect("column was created and exists")
.expect("column must be described prior to database open")
}
pub async fn shutdown_pool(&self) { self.pool.shutdown().await; }
pub fn flush(&self) -> Result<()> { result(DBCommon::flush_wal(&self.db, false)) }
pub fn sync(&self) -> Result<()> { result(DBCommon::flush_wal(&self.db, true)) }
#[inline]
pub fn corked(&self) -> bool { self.corks.load(std::sync::atomic::Ordering::Relaxed) > 0 }
pub(crate) fn cork(&self) { self.corks.fetch_add(1, Ordering::Relaxed); }
pub(crate) fn cork(&self) {
self.corks
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
#[inline]
pub(crate) fn uncork(&self) { self.corks.fetch_sub(1, Ordering::Relaxed); }
pub(crate) fn uncork(&self) {
self.corks
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn memory_usage(&self) -> Result<String> {
let mut res = String::new();
let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&self.row_cache]))
.or_else(or_else)?;
let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0;
writeln!(
res,
"Memory buffers: {:.2} MiB\nPending write: {:.2} MiB\nTable readers: {:.2} MiB\nRow \
cache: {:.2} MiB",
mibs(stats.mem_table_total),
mibs(stats.mem_table_unflushed),
mibs(stats.mem_table_readers_total),
mibs(u64::try_from(self.row_cache.get_usage())?),
)?;
for (name, cache) in &*self.col_cache.read().expect("locked") {
writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?;
}
Ok(res)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn cleanup(&self) -> Result<()> {
debug!("Running flush_opt");
let flushoptions = rocksdb::FlushOptions::default();
result(DBCommon::flush_opt(&self.db, &flushoptions))
}
#[inline]
pub fn corked(&self) -> bool { self.corks.load(Ordering::Relaxed) > 0 }
#[tracing::instrument(skip(self))]
pub fn backup(&self) -> Result {
let config = &self.server.config;
let path = config.database_backup_path.as_ref();
if path.is_none() || path.is_some_and(|path| path.as_os_str().is_empty()) {
return Ok(());
}
pub fn sync(&self) -> Result { result(DBCommon::flush_wal(&self.db, true)) }
let options = BackupEngineOptions::new(path.expect("valid database backup path"))
.map_err(map_err)?;
let mut engine = BackupEngine::open(&options, &self.env).map_err(map_err)?;
if config.database_backups_to_keep > 0 {
let flush = !self.is_read_only();
engine
.create_new_backup_flush(&self.db, flush)
.map_err(map_err)?;
#[tracing::instrument(skip(self), level = "debug")]
pub fn flush(&self) -> Result { result(DBCommon::flush_wal(&self.db, false)) }
let engine_info = engine.get_backup_info();
let info = &engine_info.last().expect("backup engine info is not empty");
info!(
"Created database backup #{} using {} bytes in {} files",
info.backup_id, info.size, info.num_files,
);
}
if config.database_backups_to_keep >= 0 {
let keep = u32::try_from(config.database_backups_to_keep)?;
if let Err(e) = engine.purge_old_backups(keep.try_into()?) {
error!("Failed to purge old backup: {e:?}");
}
}
Ok(())
}
pub fn backup_list(&self) -> Result<String> {
let config = &self.server.config;
let path = config.database_backup_path.as_ref();
if path.is_none() || path.is_some_and(|path| path.as_os_str().is_empty()) {
return Ok("Configure database_backup_path to enable backups, or the path \
specified is not valid"
.to_owned());
}
let mut res = String::new();
let options = BackupEngineOptions::new(path.expect("valid database backup path"))
.or_else(or_else)?;
let engine = BackupEngine::open(&options, &self.env).or_else(or_else)?;
for info in engine.get_backup_info() {
writeln!(
res,
"#{} {}: {} bytes, {} files",
info.backup_id,
rfc2822_from_seconds(info.timestamp),
info.size,
info.num_files,
)?;
}
Ok(res)
}
pub fn file_list(&self) -> Result<String> {
match self.db.live_files() {
| Err(e) => Ok(String::from(e)),
| Ok(files) => {
let mut res = String::new();
writeln!(res, "| lev | sst | keys | dels | size | column |")?;
writeln!(res, "| ---: | :--- | ---: | ---: | ---: | :--- |")?;
for file in files {
writeln!(
res,
"| {} | {:<13} | {:7}+ | {:4}- | {:9} | {} |",
file.level,
file.name,
file.num_entries,
file.num_deletions,
file.size,
file.column_family_name,
)?;
}
Ok(res)
},
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn sort(&self) -> Result {
let flushoptions = rocksdb::FlushOptions::default();
result(DBCommon::flush_opt(&self.db, &flushoptions))
}
/// Query for database property by null-terminated name which is expected to
@@ -308,56 +88,14 @@ impl Engine {
pub fn is_secondary(&self) -> bool { self.secondary }
}
pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> {
warn!("Starting database repair. This may take a long time...");
match Db::repair(db_opts, path) {
| Ok(()) => info!("Database repair successful."),
| Err(e) => return Err!("Repair failed: {e:?}"),
}
Ok(())
}
#[tracing::instrument(
parent = None,
name = "rocksdb",
level = "trace"
skip(msg),
)]
pub(crate) fn handle_log(level: LogLevel, msg: &str) {
let msg = msg.trim();
if msg.starts_with("Options") {
return;
}
match level {
| LogLevel::Header | LogLevel::Debug => debug!("{msg}"),
| LogLevel::Error | LogLevel::Fatal => error!("{msg}"),
| LogLevel::Info => debug!("{msg}"),
| LogLevel::Warn => warn!("{msg}"),
};
}
impl Drop for Engine {
#[cold]
fn drop(&mut self) {
const BLOCKING: bool = true;
debug!("Closing frontend pool");
self.pool.close();
debug!("Waiting for background tasks to finish...");
self.db.cancel_all_background_work(BLOCKING);
debug!("Shutting down background threads");
self.env.set_high_priority_background_threads(0);
self.env.set_low_priority_background_threads(0);
self.env.set_bottom_priority_background_threads(0);
self.env.set_background_threads(0);
debug!("Joining background threads...");
self.env.join_all_threads();
info!(
sequence = %self.db.latest_sequence_number(),
"Closing database..."
+73
View File
@@ -0,0 +1,73 @@
use std::fmt::Write;
use conduwuit::{error, implement, info, utils::time::rfc2822_from_seconds, warn, Result};
use rocksdb::backup::{BackupEngine, BackupEngineOptions};
use super::Engine;
use crate::{or_else, util::map_err};
#[implement(Engine)]
#[tracing::instrument(skip(self))]
pub fn backup(&self) -> Result {
let server = &self.ctx.server;
let config = &server.config;
let path = config.database_backup_path.as_ref();
if path.is_none() || path.is_some_and(|path| path.as_os_str().is_empty()) {
return Ok(());
}
let options =
BackupEngineOptions::new(path.expect("valid database backup path")).map_err(map_err)?;
let mut engine = BackupEngine::open(&options, &*self.ctx.env.lock()?).map_err(map_err)?;
if config.database_backups_to_keep > 0 {
let flush = !self.is_read_only();
engine
.create_new_backup_flush(&self.db, flush)
.map_err(map_err)?;
let engine_info = engine.get_backup_info();
let info = &engine_info.last().expect("backup engine info is not empty");
info!(
"Created database backup #{} using {} bytes in {} files",
info.backup_id, info.size, info.num_files,
);
}
if config.database_backups_to_keep >= 0 {
let keep = u32::try_from(config.database_backups_to_keep)?;
if let Err(e) = engine.purge_old_backups(keep.try_into()?) {
error!("Failed to purge old backup: {e:?}");
}
}
Ok(())
}
#[implement(Engine)]
pub fn backup_list(&self) -> Result<String> {
let server = &self.ctx.server;
let config = &server.config;
let path = config.database_backup_path.as_ref();
if path.is_none() || path.is_some_and(|path| path.as_os_str().is_empty()) {
return Ok("Configure database_backup_path to enable backups, or the path specified is \
not valid"
.to_owned());
}
let mut res = String::new();
let options =
BackupEngineOptions::new(path.expect("valid database backup path")).or_else(or_else)?;
let engine = BackupEngine::open(&options, &*self.ctx.env.lock()?).or_else(or_else)?;
for info in engine.get_backup_info() {
writeln!(
res,
"#{} {}: {} bytes, {} files",
info.backup_id,
rfc2822_from_seconds(info.timestamp),
info.size,
info.num_files,
)?;
}
Ok(res)
}
+233
View File
@@ -0,0 +1,233 @@
use conduwuit::{
err,
utils::{math::Expected, BoolExt},
Config, Result,
};
use rocksdb::{
BlockBasedIndexType, BlockBasedOptions, BlockBasedPinningTier, Cache,
DBCompressionType as CompressionType, DataBlockIndexType, LruCacheOptions, Options,
UniversalCompactOptions, UniversalCompactionStopStyle,
};
use super::descriptor::{CacheDisp, Descriptor};
use crate::{util::map_err, Context};
/// Adjust options for the specific column by name. Provide the result of
/// db_options() as the argument to this function and use the return value in
/// the arguments to open the specific column.
pub(crate) fn cf_options(ctx: &Context, opts: Options, desc: &Descriptor) -> Result<Options> {
let cache = get_cache(ctx, desc);
let config = &ctx.server.config;
descriptor_cf_options(opts, desc.clone(), config, cache.as_ref())
}
fn descriptor_cf_options(
mut opts: Options,
mut desc: Descriptor,
config: &Config,
cache: Option<&Cache>,
) -> Result<Options> {
set_compression(&mut desc, config);
set_table_options(&mut opts, &desc, cache)?;
opts.set_min_write_buffer_number(1);
opts.set_max_write_buffer_number(2);
if let Some(write_size) = desc.write_size {
opts.set_write_buffer_size(write_size);
}
opts.set_target_file_size_base(desc.file_size);
opts.set_target_file_size_multiplier(desc.file_shape[0]);
opts.set_level_zero_file_num_compaction_trigger(desc.level0_width);
opts.set_level_compaction_dynamic_level_bytes(false);
opts.set_ttl(desc.ttl);
opts.set_max_bytes_for_level_base(desc.level_size);
opts.set_max_bytes_for_level_multiplier(1.0);
opts.set_max_bytes_for_level_multiplier_additional(&desc.level_shape);
opts.set_compaction_style(desc.compaction);
opts.set_compaction_pri(desc.compaction_pri);
opts.set_universal_compaction_options(&uc_options(&desc));
opts.set_compression_type(desc.compression);
opts.set_compression_options(-14, desc.compression_level, 0, 0); // -14 w_bits used by zlib.
if let Some(&bottommost_level) = desc.bottommost_level.as_ref() {
opts.set_bottommost_compression_type(desc.compression);
opts.set_bottommost_zstd_max_train_bytes(0, true);
opts.set_bottommost_compression_options(
-14, // -14 w_bits is only read by zlib.
bottommost_level,
0,
0,
true,
);
}
opts.set_options_from_string("{{arena_block_size=2097152;}}")
.map_err(map_err)?;
Ok(opts)
}
fn set_table_options(opts: &mut Options, desc: &Descriptor, cache: Option<&Cache>) -> Result {
let mut table = table_options(desc, cache.is_some());
if let Some(cache) = cache {
table.set_block_cache(cache);
} else {
table.disable_cache();
}
opts.set_options_from_string(
"{{block_based_table_factory={num_file_reads_for_auto_readahead=0;\
max_auto_readahead_size=524288;initial_auto_readahead_size=16384}}}",
)
.map_err(map_err)?;
opts.set_block_based_table_factory(&table);
Ok(())
}
fn set_compression(desc: &mut Descriptor, config: &Config) {
desc.compression = match config.rocksdb_compression_algo.as_ref() {
| "snappy" => CompressionType::Snappy,
| "zlib" => CompressionType::Zlib,
| "bz2" => CompressionType::Bz2,
| "lz4" => CompressionType::Lz4,
| "lz4hc" => CompressionType::Lz4hc,
| "none" => CompressionType::None,
| _ => CompressionType::Zstd,
};
desc.compression_level = config.rocksdb_compression_level;
desc.bottommost_level = config
.rocksdb_bottommost_compression
.then_some(config.rocksdb_bottommost_compression_level);
}
fn uc_options(desc: &Descriptor) -> UniversalCompactOptions {
let mut opts = UniversalCompactOptions::default();
opts.set_stop_style(UniversalCompactionStopStyle::Total);
opts.set_min_merge_width(desc.merge_width.0);
opts.set_max_merge_width(desc.merge_width.1);
opts.set_max_size_amplification_percent(10000);
opts.set_compression_size_percent(-1);
opts.set_size_ratio(1);
opts
}
fn table_options(desc: &Descriptor, has_cache: bool) -> BlockBasedOptions {
let mut opts = BlockBasedOptions::default();
opts.set_block_size(desc.block_size);
opts.set_metadata_block_size(desc.index_size);
opts.set_cache_index_and_filter_blocks(has_cache);
opts.set_pin_top_level_index_and_filter(false);
opts.set_pin_l0_filter_and_index_blocks_in_cache(false);
opts.set_partition_pinning_tier(BlockBasedPinningTier::None);
opts.set_unpartitioned_pinning_tier(BlockBasedPinningTier::None);
opts.set_top_level_index_pinning_tier(BlockBasedPinningTier::None);
opts.set_partition_filters(true);
opts.set_use_delta_encoding(false);
opts.set_index_type(BlockBasedIndexType::TwoLevelIndexSearch);
opts.set_data_block_index_type(
desc.block_index_hashing
.map_or(DataBlockIndexType::BinarySearch, || DataBlockIndexType::BinaryAndHash),
);
opts
}
fn get_cache(ctx: &Context, desc: &Descriptor) -> Option<Cache> {
if desc.dropped {
return None;
}
// Some cache capacities are overriden by server config in a strange but
// legacy-compat way
let config = &ctx.server.config;
let cap = match desc.name {
| "eventid_pduid" => Some(config.eventid_pdu_cache_capacity),
| "eventid_shorteventid" => Some(config.eventidshort_cache_capacity),
| "shorteventid_eventid" => Some(config.shorteventid_cache_capacity),
| "shorteventid_authchain" => Some(config.auth_chain_cache_capacity),
| "shortstatekey_statekey" => Some(config.shortstatekey_cache_capacity),
| "statekey_shortstatekey" => Some(config.statekeyshort_cache_capacity),
| "servernameevent_data" => Some(config.servernameevent_data_cache_capacity),
| "pduid_pdu" | "eventid_outlierpdu" => Some(config.pdu_cache_capacity),
| _ => None,
}
.map(TryInto::try_into)
.transpose()
.expect("u32 to usize");
let ent_size: usize = desc
.key_size_hint
.unwrap_or_default()
.expected_add(desc.val_size_hint.unwrap_or_default());
let size = match cap {
| Some(cap) => cache_size(config, cap, ent_size),
| _ => desc.cache_size,
};
let shard_bits: i32 = desc
.cache_shards
.ilog2()
.try_into()
.expect("u32 to i32 conversion");
debug_assert!(shard_bits <= 6, "cache shards limited to 64");
let mut cache_opts = LruCacheOptions::default();
cache_opts.set_num_shard_bits(shard_bits);
cache_opts.set_capacity(size);
let mut caches = ctx.col_cache.lock().expect("locked");
match desc.cache_disp {
| CacheDisp::Unique if desc.cache_size == 0 => None,
| CacheDisp::Unique => {
let cache = Cache::new_lru_cache_opts(&cache_opts);
caches.insert(desc.name.into(), cache.clone());
Some(cache)
},
| CacheDisp::SharedWith(other) if !caches.contains_key(other) => {
let cache = Cache::new_lru_cache_opts(&cache_opts);
caches.insert(desc.name.into(), cache.clone());
Some(cache)
},
| CacheDisp::SharedWith(other) => Some(
caches
.get(other)
.cloned()
.expect("caches.contains_key(other) must be true"),
),
| CacheDisp::Shared => Some(
caches
.get("Shared")
.cloned()
.expect("shared cache must already exist"),
),
}
}
pub(crate) fn cache_size(config: &Config, base_size: u32, entity_size: usize) -> usize {
cache_size_f64(config, f64::from(base_size), entity_size)
}
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
pub(crate) fn cache_size_f64(config: &Config, base_size: f64, entity_size: usize) -> usize {
let ents = base_size * config.cache_capacity_modifier;
(ents as usize)
.checked_mul(entity_size)
.ok_or_else(|| err!(Config("cache_capacity_modifier", "Cache size is too large.")))
.expect("invalid cache size")
}
+73
View File
@@ -0,0 +1,73 @@
use std::{
collections::BTreeMap,
sync::{Arc, Mutex},
};
use conduwuit::{debug, utils::math::usize_from_f64, Result, Server};
use rocksdb::{Cache, Env};
use crate::{or_else, pool::Pool};
/// Some components are constructed prior to opening the database and must
/// outlive the database. These can also be shared between database instances
/// though at the time of this comment we only open one database per process.
/// These assets are housed in the shared Context.
pub(crate) struct Context {
pub(crate) pool: Arc<Pool>,
pub(crate) col_cache: Mutex<BTreeMap<String, Cache>>,
pub(crate) row_cache: Mutex<Cache>,
pub(crate) env: Mutex<Env>,
pub(crate) server: Arc<Server>,
}
impl Context {
pub(crate) fn new(server: &Arc<Server>) -> Result<Arc<Self>> {
let config = &server.config;
let cache_capacity_bytes = config.db_cache_capacity_mb * 1024.0 * 1024.0;
let row_cache_capacity_bytes = usize_from_f64(cache_capacity_bytes * 0.50)?;
let row_cache = Cache::new_lru_cache(row_cache_capacity_bytes);
let col_cache_capacity_bytes = usize_from_f64(cache_capacity_bytes * 0.50)?;
let col_cache = Cache::new_lru_cache(col_cache_capacity_bytes);
let col_cache: BTreeMap<_, _> = [("Shared".to_owned(), col_cache)].into();
let mut env = Env::new().or_else(or_else)?;
if config.rocksdb_compaction_prio_idle {
env.lower_thread_pool_cpu_priority();
}
if config.rocksdb_compaction_ioprio_idle {
env.lower_thread_pool_io_priority();
}
Ok(Arc::new(Self {
pool: Pool::new(server)?,
col_cache: col_cache.into(),
row_cache: row_cache.into(),
env: env.into(),
server: server.clone(),
}))
}
}
impl Drop for Context {
#[cold]
fn drop(&mut self) {
debug!("Closing frontend pool");
self.pool.close();
let mut env = self.env.lock().expect("locked");
debug!("Shutting down background threads");
env.set_high_priority_background_threads(0);
env.set_low_priority_background_threads(0);
env.set_bottom_priority_background_threads(0);
env.set_background_threads(0);
debug!("Joining background threads...");
env.join_all_threads();
}
}
+140
View File
@@ -0,0 +1,140 @@
use std::{cmp, convert::TryFrom};
use conduwuit::{utils, Config, Result};
use rocksdb::{statistics::StatsLevel, Cache, DBRecoveryMode, Env, LogLevel, Options};
use super::{cf_opts::cache_size_f64, logger::handle as handle_log};
/// Create database-wide options suitable for opening the database. This also
/// sets our default column options in case of opening a column with the same
/// resulting value. Note that we require special per-column options on some
/// columns, therefor columns should only be opened after passing this result
/// through cf_options().
pub(crate) fn db_options(config: &Config, env: &Env, row_cache: &Cache) -> Result<Options> {
const DEFAULT_STATS_LEVEL: StatsLevel = if cfg!(debug_assertions) {
StatsLevel::ExceptDetailedTimers
} else {
StatsLevel::DisableAll
};
let mut opts = Options::default();
// Logging
set_logging_defaults(&mut opts, config);
// Processing
opts.set_max_background_jobs(num_threads::<i32>(config)?);
opts.set_max_subcompactions(num_threads::<u32>(config)?);
opts.set_avoid_unnecessary_blocking_io(true);
opts.set_max_file_opening_threads(0);
// IO
opts.set_atomic_flush(true);
opts.set_manual_wal_flush(true);
opts.set_enable_pipelined_write(false);
if config.rocksdb_direct_io {
opts.set_use_direct_reads(true);
opts.set_use_direct_io_for_flush_and_compaction(true);
}
if config.rocksdb_optimize_for_spinning_disks {
// speeds up opening DB on hard drives
opts.set_skip_checking_sst_file_sizes_on_db_open(true);
opts.set_skip_stats_update_on_db_open(true);
//opts.set_max_file_opening_threads(threads.try_into().unwrap());
} else {
opts.set_compaction_readahead_size(1024 * 512);
}
// Blocks
opts.set_row_cache(row_cache);
opts.set_db_write_buffer_size(cache_size_f64(
config,
config.db_write_buffer_capacity_mb,
1_048_576,
));
// Files
opts.set_table_cache_num_shard_bits(7);
opts.set_wal_size_limit_mb(1024 * 1024 * 1024);
opts.set_max_total_wal_size(1024 * 1024 * 512);
opts.set_writable_file_max_buffer_size(1024 * 1024 * 2);
// Misc
opts.set_disable_auto_compactions(!config.rocksdb_compaction);
opts.create_missing_column_families(true);
opts.create_if_missing(true);
opts.set_statistics_level(match config.rocksdb_stats_level {
| 0 => StatsLevel::DisableAll,
| 1 => DEFAULT_STATS_LEVEL,
| 2 => StatsLevel::ExceptHistogramOrTimers,
| 3 => StatsLevel::ExceptTimers,
| 4 => StatsLevel::ExceptDetailedTimers,
| 5 => StatsLevel::ExceptTimeForMutex,
| 6_u8..=u8::MAX => StatsLevel::All,
});
opts.set_report_bg_io_stats(match config.rocksdb_stats_level {
| 0..=1 => false,
| 2_u8..=u8::MAX => true,
});
// Default: https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
//
// Unclean shutdowns of a Matrix homeserver are likely to be fine when
// recovered in this manner as it's likely any lost information will be
// restored via federation.
opts.set_wal_recovery_mode(match config.rocksdb_recovery_mode {
| 0 => DBRecoveryMode::AbsoluteConsistency,
| 1 => DBRecoveryMode::TolerateCorruptedTailRecords,
| 2 => DBRecoveryMode::PointInTime,
| 3 => DBRecoveryMode::SkipAnyCorruptedRecord,
| 4_u8..=u8::MAX => unimplemented!(),
});
// <https://github.com/facebook/rocksdb/wiki/Track-WAL-in-MANIFEST>
// "We recommend to set track_and_verify_wals_in_manifest to true for
// production, it has been enabled in production for the entire database cluster
// serving the social graph for all Meta apps."
opts.set_track_and_verify_wals_in_manifest(true);
opts.set_paranoid_checks(config.rocksdb_paranoid_file_checks);
opts.set_env(env);
Ok(opts)
}
fn set_logging_defaults(opts: &mut Options, config: &Config) {
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
| "debug" => LogLevel::Debug,
| "info" => LogLevel::Info,
| "warn" => LogLevel::Warn,
| "fatal" => LogLevel::Fatal,
| _ => LogLevel::Error,
};
opts.set_log_level(rocksdb_log_level);
opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
opts.set_keep_log_file_num(config.rocksdb_max_log_files);
opts.set_stats_dump_period_sec(0);
if config.rocksdb_log_stderr {
opts.set_stderr_logger(rocksdb_log_level, "rocksdb");
} else {
opts.set_callback_logger(rocksdb_log_level, &handle_log);
}
}
fn num_threads<T: TryFrom<usize>>(config: &Config) -> Result<T> {
const MIN_PARALLELISM: usize = 2;
let requested = if config.rocksdb_parallelism_threads != 0 {
config.rocksdb_parallelism_threads
} else {
utils::available_parallelism()
};
utils::math::try_into::<T, usize>(cmp::max(MIN_PARALLELISM, requested))
}
+91
View File
@@ -0,0 +1,91 @@
use conduwuit::utils::string::EMPTY;
use rocksdb::{
DBCompactionPri as CompactionPri, DBCompactionStyle as CompactionStyle,
DBCompressionType as CompressionType,
};
#[derive(Debug, Clone, Copy)]
pub(crate) enum CacheDisp {
Unique,
Shared,
SharedWith(&'static str),
}
#[derive(Debug, Clone)]
pub(crate) struct Descriptor {
pub(crate) name: &'static str,
pub(crate) dropped: bool,
pub(crate) cache_disp: CacheDisp,
pub(crate) key_size_hint: Option<usize>,
pub(crate) val_size_hint: Option<usize>,
pub(crate) block_size: usize,
pub(crate) index_size: usize,
pub(crate) write_size: Option<usize>,
pub(crate) cache_size: usize,
pub(crate) level_size: u64,
pub(crate) level_shape: [i32; 7],
pub(crate) file_size: u64,
pub(crate) file_shape: [i32; 1],
pub(crate) level0_width: i32,
pub(crate) merge_width: (i32, i32),
pub(crate) ttl: u64,
pub(crate) compaction: CompactionStyle,
pub(crate) compaction_pri: CompactionPri,
pub(crate) compression: CompressionType,
pub(crate) compression_level: i32,
pub(crate) bottommost_level: Option<i32>,
pub(crate) block_index_hashing: bool,
pub(crate) cache_shards: u32,
}
pub(crate) static BASE: Descriptor = Descriptor {
name: EMPTY,
dropped: false,
cache_disp: CacheDisp::Shared,
key_size_hint: None,
val_size_hint: None,
block_size: 1024 * 4,
index_size: 1024 * 4,
write_size: None,
cache_size: 1024 * 1024 * 4,
level_size: 1024 * 1024 * 8,
level_shape: [1, 1, 1, 3, 7, 15, 31],
file_size: 1024 * 1024,
file_shape: [2],
level0_width: 2,
merge_width: (2, 16),
ttl: 60 * 60 * 24 * 21,
compaction: CompactionStyle::Level,
compaction_pri: CompactionPri::MinOverlappingRatio,
compression: CompressionType::Zstd,
compression_level: 32767,
bottommost_level: Some(32767),
block_index_hashing: false,
cache_shards: 64,
};
pub(crate) static RANDOM: Descriptor = Descriptor {
compaction_pri: CompactionPri::OldestSmallestSeqFirst,
..BASE
};
pub(crate) static SEQUENTIAL: Descriptor = Descriptor {
compaction_pri: CompactionPri::OldestLargestSeqFirst,
level_size: 1024 * 1024 * 32,
file_size: 1024 * 1024 * 2,
..BASE
};
pub(crate) static RANDOM_SMALL: Descriptor = Descriptor {
compaction: CompactionStyle::Universal,
level_size: 1024 * 512,
file_size: 1024 * 128,
..RANDOM
};
pub(crate) static SEQUENTIAL_SMALL: Descriptor = Descriptor {
compaction: CompactionStyle::Universal,
level_size: 1024 * 1024,
file_size: 1024 * 512,
..SEQUENTIAL
};
+32
View File
@@ -0,0 +1,32 @@
use std::fmt::Write;
use conduwuit::{implement, Result};
use super::Engine;
#[implement(Engine)]
pub fn file_list(&self) -> Result<String> {
match self.db.live_files() {
| Err(e) => Ok(String::from(e)),
| Ok(mut files) => {
files.sort_by_key(|f| f.name.clone());
let mut res = String::new();
writeln!(res, "| lev | sst | keys | dels | size | column |")?;
writeln!(res, "| ---: | :--- | ---: | ---: | ---: | :--- |")?;
for file in files {
writeln!(
res,
"| {} | {:<13} | {:7}+ | {:4}- | {:9} | {} |",
file.level,
file.name,
file.num_entries,
file.num_deletions,
file.size,
file.column_family_name,
)?;
}
Ok(res)
},
}
}
+22
View File
@@ -0,0 +1,22 @@
use conduwuit::{debug, error, warn};
use rocksdb::LogLevel;
#[tracing::instrument(
parent = None,
name = "rocksdb",
level = "trace"
skip(msg),
)]
pub(crate) fn handle(level: LogLevel, msg: &str) {
let msg = msg.trim();
if msg.starts_with("Options") {
return;
}
match level {
| LogLevel::Header | LogLevel::Debug => debug!("{msg}"),
| LogLevel::Error | LogLevel::Fatal => error!("{msg}"),
| LogLevel::Info => debug!("{msg}"),
| LogLevel::Warn => warn!("{msg}"),
};
}
+30
View File
@@ -0,0 +1,30 @@
use std::fmt::Write;
use conduwuit::{implement, Result};
use rocksdb::perf::get_memory_usage_stats;
use super::Engine;
use crate::or_else;
#[implement(Engine)]
pub fn memory_usage(&self) -> Result<String> {
let mut res = String::new();
let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&*self.ctx.row_cache.lock()?]))
.or_else(or_else)?;
let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0;
writeln!(
res,
"Memory buffers: {:.2} MiB\nPending write: {:.2} MiB\nTable readers: {:.2} MiB\nRow \
cache: {:.2} MiB",
mibs(stats.mem_table_total),
mibs(stats.mem_table_unflushed),
mibs(stats.mem_table_readers_total),
mibs(u64::try_from(self.ctx.row_cache.lock()?.get_usage())?),
)?;
for (name, cache) in &*self.ctx.col_cache.lock()? {
writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?;
}
Ok(res)
}
+133
View File
@@ -0,0 +1,133 @@
use std::{
collections::BTreeSet,
path::Path,
sync::{atomic::AtomicU32, Arc},
};
use conduwuit::{debug, implement, info, warn, Result};
use rocksdb::{ColumnFamilyDescriptor, Options};
use super::{
cf_opts::cf_options,
db_opts::db_options,
descriptor::{self, Descriptor},
repair::repair,
Db, Engine,
};
use crate::{or_else, Context};
#[implement(Engine)]
#[tracing::instrument(skip_all)]
pub(crate) async fn open(ctx: Arc<Context>, desc: &[Descriptor]) -> Result<Arc<Self>> {
let server = &ctx.server;
let config = &server.config;
let path = &config.database_path;
let db_opts = db_options(
config,
&ctx.env.lock().expect("environment locked"),
&ctx.row_cache.lock().expect("row cache locked"),
)?;
let cfds = Self::configure_cfds(&ctx, &db_opts, desc)?;
let num_cfds = cfds.len();
debug!("Configured {num_cfds} column descriptors...");
let load_time = std::time::Instant::now();
if config.rocksdb_repair {
repair(&db_opts, &config.database_path)?;
}
debug!("Opening database...");
let db = if config.rocksdb_read_only {
Db::open_cf_descriptors_read_only(&db_opts, path, cfds, false)
} else if config.rocksdb_secondary {
Db::open_cf_descriptors_as_secondary(&db_opts, path, path, cfds)
} else {
Db::open_cf_descriptors(&db_opts, path, cfds)
}
.or_else(or_else)?;
info!(
columns = num_cfds,
sequence = %db.latest_sequence_number(),
time = ?load_time.elapsed(),
"Opened database."
);
Ok(Arc::new(Self {
read_only: config.rocksdb_read_only,
secondary: config.rocksdb_secondary,
corks: AtomicU32::new(0),
pool: ctx.pool.clone(),
db,
ctx,
}))
}
#[implement(Engine)]
#[tracing::instrument(name = "configure", skip_all)]
fn configure_cfds(
ctx: &Arc<Context>,
db_opts: &Options,
desc: &[Descriptor],
) -> Result<Vec<ColumnFamilyDescriptor>> {
let server = &ctx.server;
let config = &server.config;
let path = &config.database_path;
let existing = Self::discover_cfs(path, db_opts);
let creating = desc.iter().filter(|desc| !existing.contains(desc.name));
let missing = existing
.iter()
.filter(|&name| name != "default")
.filter(|&name| !desc.iter().any(|desc| desc.name == name));
debug!(
existing = existing.len(),
described = desc.len(),
missing = missing.clone().count(),
creating = creating.clone().count(),
"Discovered database columns"
);
missing.clone().for_each(|name| {
debug!("Found unrecognized column {name:?} in existing database.");
});
creating.map(|desc| desc.name).for_each(|name| {
debug!("Creating new column {name:?} not previously found in existing database.");
});
let missing_descriptors = missing
.clone()
.map(|_| Descriptor { dropped: true, ..descriptor::BASE });
let cfopts: Vec<_> = desc
.iter()
.cloned()
.chain(missing_descriptors)
.map(|ref desc| cf_options(ctx, db_opts.clone(), desc))
.collect::<Result<_>>()?;
let cfds: Vec<_> = desc
.iter()
.map(|desc| desc.name)
.map(ToOwned::to_owned)
.chain(missing.cloned())
.zip(cfopts.into_iter())
.map(|(name, opts)| ColumnFamilyDescriptor::new(name, opts))
.collect();
Ok(cfds)
}
#[implement(Engine)]
#[tracing::instrument(name = "discover", skip_all)]
fn discover_cfs(path: &Path, opts: &Options) -> BTreeSet<String> {
Db::list_cf(opts, path)
.unwrap_or_default()
.into_iter()
.collect::<BTreeSet<_>>()
}
+16
View File
@@ -0,0 +1,16 @@
use std::path::PathBuf;
use conduwuit::{info, warn, Err, Result};
use rocksdb::Options;
use super::Db;
pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result {
warn!("Starting database repair. This may take a long time...");
match Db::repair(db_opts, path) {
| Ok(()) => info!("Database repair successful."),
| Err(e) => return Err!("Repair failed: {e:?}"),
}
Ok(())
}
+16 -64
View File
@@ -6,6 +6,8 @@ mod insert;
mod keys;
mod keys_from;
mod keys_prefix;
mod open;
mod options;
mod remove;
mod rev_keys;
mod rev_keys_from;
@@ -28,12 +30,15 @@ use std::{
};
use conduwuit::Result;
use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, ReadTier, WriteOptions};
use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteOptions};
pub(crate) use self::options::{
cache_read_options_default, iter_options_default, read_options_default, write_options_default,
};
use crate::{watchers::Watchers, Engine};
pub struct Map {
name: String,
name: &'static str,
db: Arc<Engine>,
cf: Arc<ColumnFamily>,
watchers: Watchers,
@@ -43,11 +48,11 @@ pub struct Map {
}
impl Map {
pub(crate) fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<Self>> {
pub(crate) fn open(db: &Arc<Engine>, name: &'static str) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
name: name.to_owned(),
name,
db: db.clone(),
cf: open(db, name)?,
cf: open::open(db, name),
watchers: Watchers::default(),
write_options: write_options_default(),
read_options: read_options_default(),
@@ -75,9 +80,13 @@ impl Map {
pub fn property(&self, name: &str) -> Result<String> { self.db.property(&self.cf(), name) }
#[inline]
pub fn name(&self) -> &str { &self.name }
pub fn name(&self) -> &str { self.name }
fn cf(&self) -> impl AsColumnFamilyRef + '_ { &*self.cf }
#[inline]
pub(crate) fn db(&self) -> &Arc<Engine> { &self.db }
#[inline]
pub(crate) fn cf(&self) -> impl AsColumnFamilyRef + '_ { &*self.cf }
}
impl Debug for Map {
@@ -89,60 +98,3 @@ impl Debug for Map {
impl Display for Map {
fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "{0}", self.name) }
}
fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<ColumnFamily>> {
let bounded_arc = db.open_cf(name)?;
let bounded_ptr = Arc::into_raw(bounded_arc);
let cf_ptr = bounded_ptr.cast::<ColumnFamily>();
// SAFETY: Column family handles out of RocksDB are basic pointers and can
// be invalidated: 1. when the database closes. 2. when the column is dropped or
// closed. rust_rocksdb wraps this for us by storing handles in their own
// `RwLock<BTreeMap>` map and returning an Arc<BoundColumnFamily<'_>>` to
// provide expected safety. Similarly in "single-threaded mode" we would
// receive `&'_ ColumnFamily`.
//
// PROBLEM: We need to hold these handles in a field, otherwise we have to take
// a lock and get them by name from this map for every query, which is what
// conduit was doing, but we're not going to make a query for every query so we
// need to be holding it right. The lifetime parameter on these references makes
// that complicated. If this can be done without polluting the userspace
// with lifetimes on every instance of `Map` then this `unsafe` might not be
// necessary.
//
// SOLUTION: After investigating the underlying types it appears valid to
// Arc-swap `BoundColumnFamily<'_>` for `ColumnFamily`. They have the
// same inner data, the same Drop behavior, Deref, etc. We're just losing the
// lifetime parameter. We should not hold this handle, even in its Arc, after
// closing the database (dropping `Engine`). Since `Arc<Engine>` is a sibling
// member along with this handle in `Map`, that is prevented.
Ok(unsafe {
Arc::increment_strong_count(cf_ptr);
Arc::from_raw(cf_ptr)
})
}
#[inline]
pub(crate) fn iter_options_default() -> ReadOptions {
let mut read_options = read_options_default();
read_options.set_background_purge_on_iterator_cleanup(true);
//read_options.set_pin_data(true);
read_options
}
#[inline]
pub(crate) fn cache_read_options_default() -> ReadOptions {
let mut read_options = read_options_default();
read_options.set_read_tier(ReadTier::BlockCache);
read_options
}
#[inline]
pub(crate) fn read_options_default() -> ReadOptions {
let mut read_options = ReadOptions::default();
read_options.set_total_order_seek(true);
read_options
}
#[inline]
pub(crate) fn write_options_default() -> WriteOptions { WriteOptions::default() }
+1 -1
View File
@@ -23,7 +23,7 @@ pub fn raw_keys(self: &Arc<Self>) -> impl Stream<Item = Result<Key<'_>>> + Send
use crate::pool::Seek;
let opts = super::iter_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_fwd(None);
return task::consume_budget()
+1 -1
View File
@@ -54,7 +54,7 @@ where
use crate::pool::Seek;
let opts = super::iter_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
return stream::Keys::<'_>::from(state.init_fwd(from.as_ref().into())).boxed();
}
+37
View File
@@ -0,0 +1,37 @@
use std::sync::Arc;
use rocksdb::ColumnFamily;
use crate::Engine;
pub(super) fn open(db: &Arc<Engine>, name: &str) -> Arc<ColumnFamily> {
let bounded_arc = db.cf(name);
let bounded_ptr = Arc::into_raw(bounded_arc);
let cf_ptr = bounded_ptr.cast::<ColumnFamily>();
// SAFETY: Column family handles out of RocksDB are basic pointers and can
// be invalidated: 1. when the database closes. 2. when the column is dropped or
// closed. rust_rocksdb wraps this for us by storing handles in their own
// `RwLock<BTreeMap>` map and returning an Arc<BoundColumnFamily<'_>>` to
// provide expected safety. Similarly in "single-threaded mode" we would
// receive `&'_ ColumnFamily`.
//
// PROBLEM: We need to hold these handles in a field, otherwise we have to take
// a lock and get them by name from this map for every query, which is what
// conduit was doing, but we're not going to make a query for every query so we
// need to be holding it right. The lifetime parameter on these references makes
// that complicated. If this can be done without polluting the userspace
// with lifetimes on every instance of `Map` then this `unsafe` might not be
// necessary.
//
// SOLUTION: After investigating the underlying types it appears valid to
// Arc-swap `BoundColumnFamily<'_>` for `ColumnFamily`. They have the
// same inner data, the same Drop behavior, Deref, etc. We're just losing the
// lifetime parameter. We should not hold this handle, even in its Arc, after
// closing the database (dropping `Engine`). Since `Arc<Engine>` is a sibling
// member along with this handle in `Map`, that is prevented.
unsafe {
Arc::increment_strong_count(cf_ptr);
Arc::from_raw(cf_ptr)
}
}
+26
View File
@@ -0,0 +1,26 @@
use rocksdb::{ReadOptions, ReadTier, WriteOptions};
#[inline]
pub(crate) fn iter_options_default() -> ReadOptions {
let mut read_options = read_options_default();
read_options.set_background_purge_on_iterator_cleanup(true);
//read_options.set_pin_data(true);
read_options
}
#[inline]
pub(crate) fn cache_read_options_default() -> ReadOptions {
let mut read_options = read_options_default();
read_options.set_read_tier(ReadTier::BlockCache);
read_options
}
#[inline]
pub(crate) fn read_options_default() -> ReadOptions {
let mut read_options = ReadOptions::default();
read_options.set_total_order_seek(true);
read_options
}
#[inline]
pub(crate) fn write_options_default() -> WriteOptions { WriteOptions::default() }
+1 -1
View File
@@ -23,7 +23,7 @@ pub fn rev_raw_keys(self: &Arc<Self>) -> impl Stream<Item = Result<Key<'_>>> + S
use crate::pool::Seek;
let opts = super::iter_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_rev(None);
return task::consume_budget()
+1 -1
View File
@@ -62,7 +62,7 @@ where
use crate::pool::Seek;
let opts = super::iter_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
return stream::KeysRev::<'_>::from(state.init_rev(from.as_ref().into())).boxed();
}
+3 -3
View File
@@ -32,7 +32,7 @@ pub fn rev_raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>
use crate::pool::Seek;
let opts = super::read_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_rev(None);
return task::consume_budget()
@@ -65,9 +65,9 @@ pub fn rev_raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>
skip_all,
fields(%map),
)]
pub(super) fn is_cached(map: &super::Map) -> bool {
pub(super) fn is_cached(map: &Arc<super::Map>) -> bool {
let opts = super::cache_read_options_default();
let state = stream::State::new(&map.db, &map.cf, opts).init_rev(None);
let state = stream::State::new(map, opts).init_rev(None);
!state.is_incomplete()
}
+2 -2
View File
@@ -81,7 +81,7 @@ where
use crate::pool::Seek;
let opts = super::iter_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
let state = state.init_rev(from.as_ref().into());
return task::consume_budget()
@@ -119,7 +119,7 @@ where
P: AsRef<[u8]> + ?Sized,
{
let cache_opts = super::cache_read_options_default();
let cache_status = stream::State::new(&map.db, &map.cf, cache_opts)
let cache_status = stream::State::new(map, cache_opts)
.init_rev(from.as_ref().into())
.status();
+3 -3
View File
@@ -31,7 +31,7 @@ pub fn raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>> +
use crate::pool::Seek;
let opts = super::read_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self) {
let state = state.init_fwd(None);
return task::consume_budget()
@@ -64,9 +64,9 @@ pub fn raw_stream(self: &Arc<Self>) -> impl Stream<Item = Result<KeyVal<'_>>> +
skip_all,
fields(%map),
)]
pub(super) fn is_cached(map: &super::Map) -> bool {
pub(super) fn is_cached(map: &Arc<super::Map>) -> bool {
let opts = super::cache_read_options_default();
let state = stream::State::new(&map.db, &map.cf, opts).init_fwd(None);
let state = stream::State::new(map, opts).init_fwd(None);
!state.is_incomplete()
}
+2 -2
View File
@@ -78,7 +78,7 @@ where
use crate::pool::Seek;
let opts = super::read_options_default();
let state = stream::State::new(&self.db, &self.cf, opts);
let state = stream::State::new(self, opts);
if is_cached(self, from) {
let state = state.init_fwd(from.as_ref().into());
return task::consume_budget()
@@ -116,7 +116,7 @@ where
P: AsRef<[u8]> + ?Sized,
{
let opts = super::cache_read_options_default();
let state = stream::State::new(&map.db, &map.cf, opts).init_fwd(from.as_ref().into());
let state = stream::State::new(map, opts).init_fwd(from.as_ref().into());
!state.is_incomplete()
}
+373 -93
View File
@@ -2,103 +2,383 @@ use std::{collections::BTreeMap, sync::Arc};
use conduwuit::Result;
use crate::{Engine, Map};
use crate::{
engine::descriptor::{self, CacheDisp, Descriptor},
Engine, Map,
};
pub type Maps = BTreeMap<MapsKey, MapsVal>;
pub(crate) type MapsVal = Arc<Map>;
pub(crate) type MapsKey = String;
pub(super) type Maps = BTreeMap<MapsKey, MapsVal>;
pub(super) type MapsKey = &'static str;
pub(super) type MapsVal = Arc<Map>;
pub(crate) fn open(db: &Arc<Engine>) -> Result<Maps> { open_list(db, MAPS) }
pub(super) fn open(db: &Arc<Engine>) -> Result<Maps> { open_list(db, MAPS) }
#[tracing::instrument(name = "maps", level = "debug", skip_all)]
pub(crate) fn open_list(db: &Arc<Engine>, maps: &[&str]) -> Result<Maps> {
Ok(maps
.iter()
.map(|&name| (name.to_owned(), Map::open(db, name).expect("valid column opened")))
.collect::<Maps>())
pub(super) fn open_list(db: &Arc<Engine>, maps: &[Descriptor]) -> Result<Maps> {
maps.iter()
.map(|desc| Ok((desc.name, Map::open(db, desc.name)?)))
.collect()
}
pub const MAPS: &[&str] = &[
"alias_roomid",
"alias_userid",
"aliasid_alias",
"backupid_algorithm",
"backupid_etag",
"backupkeyid_backup",
"bannedroomids",
"disabledroomids",
"eventid_outlierpdu",
"eventid_pduid",
"eventid_shorteventid",
"global",
"id_appserviceregistrations",
"keychangeid_userid",
"keyid_key",
"lazyloadedids",
"mediaid_file",
"mediaid_user",
"onetimekeyid_onetimekeys",
"pduid_pdu",
"presenceid_presence",
"publicroomids",
"readreceiptid_readreceipt",
"referencedevents",
"roomid_invitedcount",
"roomid_inviteviaservers",
"roomid_joinedcount",
"roomid_pduleaves",
"roomid_shortroomid",
"roomid_shortstatehash",
"roomserverids",
"roomsynctoken_shortstatehash",
"roomuserdataid_accountdata",
"roomuserid_invitecount",
"roomuserid_joined",
"roomuserid_lastprivatereadupdate",
"roomuserid_leftcount",
"roomuserid_privateread",
"roomuseroncejoinedids",
"roomusertype_roomuserdataid",
"senderkey_pusher",
"server_signingkeys",
"servercurrentevent_data",
"servername_educount",
"servernameevent_data",
"serverroomids",
"shorteventid_authchain",
"shorteventid_eventid",
"shorteventid_shortstatehash",
"shortstatehash_statediff",
"shortstatekey_statekey",
"softfailedeventids",
"statehash_shortstatehash",
"statekey_shortstatekey",
"threadid_userids",
"todeviceid_events",
"tofrom_relation",
"token_userdeviceid",
"tokenids",
"url_previews",
"userdeviceid_metadata",
"userdeviceid_token",
"userdevicesessionid_uiaainfo",
"userdevicetxnid_response",
"userfilterid_filter",
"userid_avatarurl",
"userid_blurhash",
"userid_devicelistversion",
"userid_displayname",
"userid_lastonetimekeyupdate",
"userid_masterkeyid",
"userid_password",
"userid_presenceid",
"userid_selfsigningkeyid",
"userid_usersigningkeyid",
"useridprofilekey_value",
"openidtoken_expiresatuserid",
"userroomid_highlightcount",
"userroomid_invitestate",
"userroomid_joined",
"userroomid_leftstate",
"userroomid_notificationcount",
pub(super) static MAPS: &[Descriptor] = &[
Descriptor {
name: "alias_roomid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "alias_userid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "aliasid_alias",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "backupid_algorithm",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "backupid_etag",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "backupkeyid_backup",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "bannedroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "disabledroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "eventid_outlierpdu",
cache_disp: CacheDisp::SharedWith("pduid_pdu"),
key_size_hint: Some(48),
val_size_hint: Some(1488),
..descriptor::RANDOM
},
Descriptor {
name: "eventid_pduid",
cache_disp: CacheDisp::Unique,
key_size_hint: Some(48),
val_size_hint: Some(16),
..descriptor::RANDOM
},
Descriptor {
name: "eventid_shorteventid",
cache_disp: CacheDisp::Unique,
key_size_hint: Some(48),
val_size_hint: Some(8),
..descriptor::RANDOM
},
Descriptor {
name: "global",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "id_appserviceregistrations",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "keychangeid_userid",
..descriptor::RANDOM
},
Descriptor {
name: "keyid_key",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "lazyloadedids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "mediaid_file",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "mediaid_user",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "onetimekeyid_onetimekeys",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "pduid_pdu",
cache_disp: CacheDisp::SharedWith("eventid_outlierpdu"),
key_size_hint: Some(16),
val_size_hint: Some(1520),
..descriptor::SEQUENTIAL
},
Descriptor {
name: "presenceid_presence",
..descriptor::SEQUENTIAL_SMALL
},
Descriptor {
name: "publicroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "readreceiptid_readreceipt",
..descriptor::RANDOM
},
Descriptor {
name: "referencedevents",
..descriptor::RANDOM
},
Descriptor {
name: "roomid_invitedcount",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_inviteviaservers",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_joinedcount",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_pduleaves",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_shortroomid",
val_size_hint: Some(8),
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_shortstatehash",
val_size_hint: Some(8),
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomserverids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomsynctoken_shortstatehash",
val_size_hint: Some(8),
..descriptor::SEQUENTIAL
},
Descriptor {
name: "roomuserdataid_accountdata",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomuserid_invitecount",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomuserid_joined",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomuserid_lastprivatereadupdate",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomuserid_leftcount",
..descriptor::RANDOM
},
Descriptor {
name: "roomuserid_privateread",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomuseroncejoinedids",
..descriptor::RANDOM
},
Descriptor {
name: "roomusertype_roomuserdataid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "senderkey_pusher",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "server_signingkeys",
..descriptor::RANDOM
},
Descriptor {
name: "servercurrentevent_data",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "servername_educount",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "servernameevent_data",
cache_disp: CacheDisp::Unique,
val_size_hint: Some(128),
..descriptor::RANDOM
},
Descriptor {
name: "serverroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "shorteventid_authchain",
cache_disp: CacheDisp::Unique,
key_size_hint: Some(8),
..descriptor::SEQUENTIAL
},
Descriptor {
name: "shorteventid_eventid",
cache_disp: CacheDisp::Unique,
key_size_hint: Some(8),
val_size_hint: Some(48),
..descriptor::SEQUENTIAL_SMALL
},
Descriptor {
name: "shorteventid_shortstatehash",
key_size_hint: Some(8),
val_size_hint: Some(8),
..descriptor::SEQUENTIAL
},
Descriptor {
name: "shortstatehash_statediff",
key_size_hint: Some(8),
..descriptor::SEQUENTIAL_SMALL
},
Descriptor {
name: "shortstatekey_statekey",
cache_disp: CacheDisp::Unique,
key_size_hint: Some(8),
val_size_hint: Some(1016),
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "softfailedeventids",
key_size_hint: Some(48),
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "statehash_shortstatehash",
val_size_hint: Some(8),
..descriptor::RANDOM
},
Descriptor {
name: "statekey_shortstatekey",
cache_disp: CacheDisp::Unique,
key_size_hint: Some(1016),
val_size_hint: Some(8),
..descriptor::RANDOM
},
Descriptor {
name: "threadid_userids",
..descriptor::SEQUENTIAL_SMALL
},
Descriptor {
name: "todeviceid_events",
..descriptor::RANDOM
},
Descriptor {
name: "tofrom_relation",
key_size_hint: Some(8),
val_size_hint: Some(8),
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "token_userdeviceid",
..descriptor::RANDOM_SMALL
},
Descriptor { name: "tokenids", ..descriptor::RANDOM },
Descriptor {
name: "url_previews",
..descriptor::RANDOM
},
Descriptor {
name: "userdeviceid_metadata",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_token",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdevicesessionid_uiaainfo",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdevicetxnid_response",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userfilterid_filter",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_avatarurl",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_blurhash",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_devicelistversion",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_displayname",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_lastonetimekeyupdate",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_masterkeyid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_password",
..descriptor::RANDOM
},
Descriptor {
name: "userid_presenceid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_selfsigningkeyid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_usersigningkeyid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "useridprofilekey_value",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "openidtoken_expiresatuserid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userroomid_highlightcount",
..descriptor::RANDOM
},
Descriptor {
name: "userroomid_invitestate",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userroomid_joined",
..descriptor::RANDOM
},
Descriptor {
name: "userroomid_leftstate",
..descriptor::RANDOM
},
Descriptor {
name: "userroomid_notificationcount",
..descriptor::RANDOM
},
];
+66 -12
View File
@@ -1,5 +1,11 @@
extern crate conduwuit_core as conduwuit;
extern crate rust_rocksdb as rocksdb;
conduwuit::mod_ctor! {}
conduwuit::mod_dtor! {}
conduwuit::rustc_flags_capture! {}
mod cork;
mod database;
mod de;
mod deserialized;
mod engine;
@@ -7,24 +13,19 @@ mod handle;
pub mod keyval;
mod map;
pub mod maps;
mod opts;
mod pool;
mod ser;
mod stream;
#[cfg(test)]
mod tests;
pub(crate) mod util;
mod watchers;
pub(crate) use self::{
engine::Engine,
util::{or_else, result},
};
use std::{ops::Index, sync::Arc};
extern crate conduwuit_core as conduwuit;
extern crate rust_rocksdb as rocksdb;
use conduwuit::{err, Result, Server};
pub use self::{
database::Database,
de::{Ignore, IgnoreAll},
deserialized::Deserialized,
handle::Handle,
@@ -32,7 +33,60 @@ pub use self::{
map::Map,
ser::{serialize, serialize_to, serialize_to_vec, Interfix, Json, Separator, SEP},
};
pub(crate) use self::{
engine::{context::Context, Engine},
util::{or_else, result},
};
use crate::maps::{Maps, MapsKey, MapsVal};
conduwuit::mod_ctor! {}
conduwuit::mod_dtor! {}
conduwuit::rustc_flags_capture! {}
pub struct Database {
maps: Maps,
pub db: Arc<Engine>,
pub(crate) _ctx: Arc<Context>,
}
impl Database {
/// Load an existing database or create a new one.
pub async fn open(server: &Arc<Server>) -> Result<Arc<Self>> {
let ctx = Context::new(server)?;
let db = Engine::open(ctx.clone(), maps::MAPS).await?;
Ok(Arc::new(Self {
maps: maps::open(&db)?,
db: db.clone(),
_ctx: ctx,
}))
}
#[inline]
pub fn get(&self, name: &str) -> Result<&Arc<Map>> {
self.maps
.get(name)
.ok_or_else(|| err!(Request(NotFound("column not found"))))
}
#[inline]
pub fn iter(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + Send + '_ {
self.maps.iter()
}
#[inline]
pub fn keys(&self) -> impl Iterator<Item = &MapsKey> + Send + '_ { self.maps.keys() }
#[inline]
#[must_use]
pub fn is_read_only(&self) -> bool { self.db.is_read_only() }
#[inline]
#[must_use]
pub fn is_secondary(&self) -> bool { self.db.is_secondary() }
}
impl Index<&str> for Database {
type Output = Arc<Map>;
fn index(&self, name: &str) -> &Self::Output {
self.maps
.get(name)
.expect("column in database does not exist")
}
}
-433
View File
@@ -1,433 +0,0 @@
use std::{cmp, collections::HashMap, convert::TryFrom};
use conduwuit::{err, utils, Config, Result};
use rocksdb::{
statistics::StatsLevel, BlockBasedOptions, Cache, DBCompactionStyle, DBCompressionType,
DBRecoveryMode, Env, LogLevel, LruCacheOptions, Options, UniversalCompactOptions,
UniversalCompactionStopStyle,
};
/// Create database-wide options suitable for opening the database. This also
/// sets our default column options in case of opening a column with the same
/// resulting value. Note that we require special per-column options on some
/// columns, therefor columns should only be opened after passing this result
/// through cf_options().
pub(crate) fn db_options(
config: &Config,
env: &mut Env,
row_cache: &Cache,
col_cache: &Cache,
) -> Result<Options> {
const DEFAULT_STATS_LEVEL: StatsLevel = if cfg!(debug_assertions) {
StatsLevel::ExceptDetailedTimers
} else {
StatsLevel::DisableAll
};
let mut opts = Options::default();
// Logging
set_logging_defaults(&mut opts, config);
// Processing
opts.set_max_background_jobs(num_threads::<i32>(config)?);
opts.set_max_subcompactions(num_threads::<u32>(config)?);
opts.set_avoid_unnecessary_blocking_io(true);
opts.set_max_file_opening_threads(0);
if config.rocksdb_compaction_prio_idle {
env.lower_thread_pool_cpu_priority();
}
// IO
opts.set_atomic_flush(true);
opts.set_manual_wal_flush(true);
opts.set_enable_pipelined_write(false);
if config.rocksdb_direct_io {
opts.set_use_direct_reads(true);
opts.set_use_direct_io_for_flush_and_compaction(true);
}
if config.rocksdb_optimize_for_spinning_disks {
// speeds up opening DB on hard drives
opts.set_skip_checking_sst_file_sizes_on_db_open(true);
opts.set_skip_stats_update_on_db_open(true);
//opts.set_max_file_opening_threads(threads.try_into().unwrap());
}
if config.rocksdb_compaction_ioprio_idle {
env.lower_thread_pool_io_priority();
}
// Blocks
let mut table_opts = table_options(config);
table_opts.set_block_cache(col_cache);
opts.set_block_based_table_factory(&table_opts);
opts.set_row_cache(row_cache);
// Buffers
opts.set_write_buffer_size(2 * 1024 * 1024);
opts.set_max_write_buffer_number(2);
opts.set_min_write_buffer_number(1);
// Files
opts.set_table_cache_num_shard_bits(7);
opts.set_max_total_wal_size(96 * 1024 * 1024);
set_level_defaults(&mut opts, config);
// Compression
set_compression_defaults(&mut opts, config);
// Misc
opts.create_if_missing(true);
opts.set_disable_auto_compactions(!config.rocksdb_compaction);
opts.set_statistics_level(match config.rocksdb_stats_level {
| 0 => StatsLevel::DisableAll,
| 1 => DEFAULT_STATS_LEVEL,
| 2 => StatsLevel::ExceptHistogramOrTimers,
| 3 => StatsLevel::ExceptTimers,
| 4 => StatsLevel::ExceptDetailedTimers,
| 5 => StatsLevel::ExceptTimeForMutex,
| 6_u8..=u8::MAX => StatsLevel::All,
});
// Default: https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
//
// Unclean shutdowns of a Matrix homeserver are likely to be fine when
// recovered in this manner as it's likely any lost information will be
// restored via federation.
opts.set_wal_recovery_mode(match config.rocksdb_recovery_mode {
| 0 => DBRecoveryMode::AbsoluteConsistency,
| 1 => DBRecoveryMode::TolerateCorruptedTailRecords,
| 2 => DBRecoveryMode::PointInTime,
| 3 => DBRecoveryMode::SkipAnyCorruptedRecord,
| 4_u8..=u8::MAX => unimplemented!(),
});
// <https://github.com/facebook/rocksdb/wiki/Track-WAL-in-MANIFEST>
// "We recommend to set track_and_verify_wals_in_manifest to true for
// production, it has been enabled in production for the entire database cluster
// serving the social graph for all Meta apps."
opts.set_track_and_verify_wals_in_manifest(true);
opts.set_paranoid_checks(config.rocksdb_paranoid_file_checks);
opts.set_env(env);
Ok(opts)
}
/// Adjust options for the specific column by name. Provide the result of
/// db_options() as the argument to this function and use the return value in
/// the arguments to open the specific column.
pub(crate) fn cf_options(
cfg: &Config,
name: &str,
mut opts: Options,
cache: &mut HashMap<String, Cache>,
) -> Result<Options> {
// Columns with non-default compaction options
match name {
| "backupid_algorithm"
| "backupid_etag"
| "backupkeyid_backup"
| "roomid_shortroomid"
| "shorteventid_shortstatehash"
| "shorteventid_eventid"
| "shortstatekey_statekey"
| "shortstatehash_statediff"
| "userdevicetxnid_response"
| "userfilterid_filter" => set_for_sequential_small_uc(&mut opts, cfg),
| &_ => {},
}
// Columns with non-default table/cache configs
match name {
| "shorteventid_eventid" => set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.shorteventid_cache_capacity, 64)?,
),
| "eventid_shorteventid" => set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.eventidshort_cache_capacity, 64)?,
),
| "eventid_pduid" => set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.eventid_pdu_cache_capacity, 64)?,
),
| "shorteventid_authchain" => {
set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.auth_chain_cache_capacity, 192)?,
);
},
| "shortstatekey_statekey" => set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.shortstatekey_cache_capacity, 1024)?,
),
| "statekey_shortstatekey" => set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.statekeyshort_cache_capacity, 1024)?,
),
| "servernameevent_data" => set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.servernameevent_data_cache_capacity, 128)?, /* Raw average
* value size =
* 102, key
* size = 34 */
),
| "eventid_outlierpdu" => {
set_table_with_new_cache(
&mut opts,
cfg,
cache,
name,
cache_size(cfg, cfg.pdu_cache_capacity, 1536)?,
);
},
| "pduid_pdu" => {
set_table_with_shared_cache(&mut opts, cfg, cache, name, "eventid_outlierpdu");
},
| &_ => {},
}
Ok(opts)
}
fn set_logging_defaults(opts: &mut Options, config: &Config) {
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
| "debug" => LogLevel::Debug,
| "info" => LogLevel::Info,
| "warn" => LogLevel::Warn,
| "fatal" => LogLevel::Fatal,
| _ => LogLevel::Error,
};
opts.set_log_level(rocksdb_log_level);
opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
opts.set_keep_log_file_num(config.rocksdb_max_log_files);
opts.set_stats_dump_period_sec(0);
if config.rocksdb_log_stderr {
opts.set_stderr_logger(rocksdb_log_level, "rocksdb");
} else {
opts.set_callback_logger(rocksdb_log_level, &super::engine::handle_log);
}
}
fn set_compression_defaults(opts: &mut Options, config: &Config) {
let rocksdb_compression_algo = match config.rocksdb_compression_algo.as_ref() {
| "snappy" => DBCompressionType::Snappy,
| "zlib" => DBCompressionType::Zlib,
| "bz2" => DBCompressionType::Bz2,
| "lz4" => DBCompressionType::Lz4,
| "lz4hc" => DBCompressionType::Lz4hc,
| "none" => DBCompressionType::None,
| _ => DBCompressionType::Zstd,
};
if config.rocksdb_bottommost_compression {
opts.set_bottommost_compression_type(rocksdb_compression_algo);
opts.set_bottommost_zstd_max_train_bytes(0, true);
// -14 w_bits is only read by zlib.
opts.set_bottommost_compression_options(
-14,
config.rocksdb_bottommost_compression_level,
0,
0,
true,
);
}
// -14 w_bits is only read by zlib.
opts.set_compression_options(-14, config.rocksdb_compression_level, 0, 0);
opts.set_compression_type(rocksdb_compression_algo);
}
#[allow(dead_code)]
fn set_for_random_small_uc(opts: &mut Options, config: &Config) {
let uco = uc_options(config);
set_for_random_small(opts, config);
opts.set_universal_compaction_options(&uco);
opts.set_compaction_style(DBCompactionStyle::Universal);
}
fn set_for_sequential_small_uc(opts: &mut Options, config: &Config) {
let uco = uc_options(config);
set_for_sequential_small(opts, config);
opts.set_universal_compaction_options(&uco);
opts.set_compaction_style(DBCompactionStyle::Universal);
}
#[allow(dead_code)]
fn set_for_random_small(opts: &mut Options, config: &Config) {
set_for_random(opts, config);
opts.set_write_buffer_size(1024 * 128);
opts.set_target_file_size_base(1024 * 128);
opts.set_target_file_size_multiplier(2);
opts.set_max_bytes_for_level_base(1024 * 512);
opts.set_max_bytes_for_level_multiplier(2.0);
}
fn set_for_sequential_small(opts: &mut Options, config: &Config) {
set_for_sequential(opts, config);
opts.set_write_buffer_size(1024 * 512);
opts.set_target_file_size_base(1024 * 512);
opts.set_target_file_size_multiplier(2);
opts.set_max_bytes_for_level_base(1024 * 1024);
opts.set_max_bytes_for_level_multiplier(2.0);
}
fn set_for_random(opts: &mut Options, config: &Config) {
set_level_defaults(opts, config);
let pri = "compaction_pri=kOldestSmallestSeqFirst";
opts.set_options_from_string(pri)
.expect("set compaction priority string");
opts.set_max_bytes_for_level_base(8 * 1024 * 1024);
opts.set_max_bytes_for_level_multiplier(1.0);
opts.set_max_bytes_for_level_multiplier_additional(&[0, 1, 1, 3, 7, 15, 31]);
}
fn set_for_sequential(opts: &mut Options, config: &Config) {
set_level_defaults(opts, config);
let pri = "compaction_pri=kOldestLargestSeqFirst";
opts.set_options_from_string(pri)
.expect("set compaction priority string");
opts.set_target_file_size_base(2 * 1024 * 1024);
opts.set_target_file_size_multiplier(2);
opts.set_max_bytes_for_level_base(32 * 1024 * 1024);
opts.set_max_bytes_for_level_multiplier(1.0);
opts.set_max_bytes_for_level_multiplier_additional(&[0, 1, 1, 3, 7, 15, 31]);
}
fn set_level_defaults(opts: &mut Options, _config: &Config) {
opts.set_level_zero_file_num_compaction_trigger(2);
opts.set_target_file_size_base(1024 * 1024);
opts.set_target_file_size_multiplier(2);
opts.set_level_compaction_dynamic_level_bytes(false);
opts.set_max_bytes_for_level_base(16 * 1024 * 1024);
opts.set_max_bytes_for_level_multiplier(2.0);
opts.set_ttl(21 * 24 * 60 * 60);
}
fn uc_options(_config: &Config) -> UniversalCompactOptions {
let mut opts = UniversalCompactOptions::default();
opts.set_stop_style(UniversalCompactionStopStyle::Total);
opts.set_max_size_amplification_percent(10000);
opts.set_compression_size_percent(-1);
opts.set_size_ratio(1);
opts.set_min_merge_width(2);
opts.set_max_merge_width(16);
opts
}
fn set_table_with_new_cache(
opts: &mut Options,
config: &Config,
caches: &mut HashMap<String, Cache>,
name: &str,
size: usize,
) {
let mut cache_opts = LruCacheOptions::default();
cache_opts.set_capacity(size);
cache_opts.set_num_shard_bits(7);
let cache = Cache::new_lru_cache_opts(&cache_opts);
caches.insert(name.into(), cache);
set_table_with_shared_cache(opts, config, caches, name, name);
}
fn set_table_with_shared_cache(
opts: &mut Options,
config: &Config,
cache: &HashMap<String, Cache>,
_name: &str,
cache_name: &str,
) {
let mut table = table_options(config);
table.set_block_cache(
cache
.get(cache_name)
.expect("existing cache to share with this column"),
);
opts.set_block_based_table_factory(&table);
}
fn cache_size(config: &Config, base_size: u32, entity_size: usize) -> Result<usize> {
let ents = f64::from(base_size) * config.cache_capacity_modifier;
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
(ents as usize)
.checked_mul(entity_size)
.ok_or_else(|| err!(Config("cache_capacity_modifier", "Cache size is too large.")))
}
fn table_options(_config: &Config) -> BlockBasedOptions {
let mut opts = BlockBasedOptions::default();
opts.set_block_size(4 * 1024);
opts.set_metadata_block_size(4 * 1024);
opts.set_use_delta_encoding(false);
opts.set_optimize_filters_for_memory(true);
opts.set_cache_index_and_filter_blocks(true);
opts.set_pin_top_level_index_and_filter(true);
opts
}
fn num_threads<T: TryFrom<usize>>(config: &Config) -> Result<T> {
const MIN_PARALLELISM: usize = 2;
let requested = if config.rocksdb_parallelism_threads != 0 {
config.rocksdb_parallelism_threads
} else {
utils::available_parallelism()
};
utils::math::try_into::<T, usize>(cmp::max(MIN_PARALLELISM, requested))
}
+59 -49
View File
@@ -6,21 +6,22 @@ use std::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
thread,
thread::JoinHandle,
};
use async_channel::{QueueStrategy, Receiver, RecvError, Sender};
use conduwuit::{
debug, debug_warn, defer, err, implement,
debug, debug_warn, defer, err, error, implement,
result::DebugInspect,
trace,
utils::sys::compute::{get_affinity, nth_core_available, set_affinity},
Result, Server,
Error, Result, Server,
};
use futures::{channel::oneshot, TryFutureExt};
use oneshot::Sender as ResultSender;
use rocksdb::Direction;
use smallvec::SmallVec;
use tokio::task::JoinSet;
use self::configure::configure;
use crate::{keyval::KeyBuf, stream, Handle, Map};
@@ -31,7 +32,7 @@ use crate::{keyval::KeyBuf, stream, Handle, Map};
pub(crate) struct Pool {
server: Arc<Server>,
queues: Vec<Sender<Cmd>>,
workers: Mutex<JoinSet<()>>,
workers: Mutex<Vec<JoinHandle<()>>>,
topology: Vec<usize>,
busy: AtomicUsize,
queued_max: AtomicUsize,
@@ -66,42 +67,45 @@ pub(crate) type BatchResult<'a> = SmallVec<[ResultHandle<'a>; BATCH_INLINE]>;
pub(crate) type ResultHandle<'a> = Result<Handle<'a>>;
const WORKER_LIMIT: (usize, usize) = (1, 1024);
const QUEUE_LIMIT: (usize, usize) = (1, 2048);
const QUEUE_LIMIT: (usize, usize) = (1, 4096);
const BATCH_INLINE: usize = 1;
const WORKER_STACK_SIZE: usize = 1_048_576;
const WORKER_NAME: &str = "conduwuit:db";
#[implement(Pool)]
pub(crate) async fn new(server: &Arc<Server>) -> Result<Arc<Self>> {
pub(crate) fn new(server: &Arc<Server>) -> Result<Arc<Self>> {
const CHAN_SCHED: (QueueStrategy, QueueStrategy) = (QueueStrategy::Fifo, QueueStrategy::Lifo);
let (total_workers, queue_sizes, topology) = configure(server);
let (senders, receivers) = queue_sizes
let (senders, receivers): (Vec<_>, Vec<_>) = queue_sizes
.into_iter()
.map(|cap| async_channel::bounded_with_queue_strategy(cap, CHAN_SCHED))
.unzip();
let pool = Arc::new(Self {
server: server.clone(),
queues: senders,
workers: JoinSet::new().into(),
workers: Vec::new().into(),
topology,
busy: AtomicUsize::default(),
queued_max: AtomicUsize::default(),
});
pool.spawn_until(receivers, total_workers).await?;
pool.spawn_until(&receivers, total_workers)?;
Ok(pool)
}
impl Drop for Pool {
fn drop(&mut self) {
debug_assert!(self.queues.iter().all(Sender::is_empty), "channel must be empty on drop");
self.close();
debug_assert!(
self.queues.iter().all(Sender::is_empty),
"channel must should not have requests queued on drop"
);
debug_assert!(
self.queues.iter().all(Sender::is_closed),
"channel should be closed on drop"
@@ -110,17 +114,10 @@ impl Drop for Pool {
}
#[implement(Pool)]
pub(crate) async fn shutdown(self: &Arc<Self>) {
self.close();
let workers = take(&mut *self.workers.lock().expect("locked"));
debug!(workers = workers.len(), "Waiting for workers to join...");
workers.join_all().await;
}
#[implement(Pool)]
#[tracing::instrument(skip_all)]
pub(crate) fn close(&self) {
let workers = take(&mut *self.workers.lock().expect("locked"));
let senders = self.queues.iter().map(Sender::sender_count).sum::<usize>();
let receivers = self
@@ -129,27 +126,40 @@ pub(crate) fn close(&self) {
.map(Sender::receiver_count)
.sum::<usize>();
debug!(
queues = self.queues.len(),
workers = self.workers.lock().expect("locked").len(),
?senders,
?receivers,
"Closing pool..."
);
for queue in &self.queues {
queue.close();
}
self.workers.lock().expect("locked").abort_all();
std::thread::yield_now();
if workers.is_empty() {
return;
}
debug!(
senders,
receivers,
queues = self.queues.len(),
workers = workers.len(),
"Closing pool. Waiting for workers to join..."
);
workers
.into_iter()
.map(JoinHandle::join)
.map(|result| result.map_err(Error::from_panic))
.enumerate()
.for_each(|(id, result)| {
match result {
| Ok(()) => trace!(?id, "worker joined"),
| Err(error) => error!(?id, "worker joined with error: {error}"),
};
});
}
#[implement(Pool)]
async fn spawn_until(self: &Arc<Self>, recv: Vec<Receiver<Cmd>>, count: usize) -> Result {
fn spawn_until(self: &Arc<Self>, recv: &[Receiver<Cmd>], count: usize) -> Result {
let mut workers = self.workers.lock().expect("locked");
while workers.len() < count {
self.spawn_one(&mut workers, &recv)?;
self.clone().spawn_one(&mut workers, recv)?;
}
Ok(())
@@ -162,23 +172,24 @@ async fn spawn_until(self: &Arc<Self>, recv: Vec<Receiver<Cmd>>, count: usize) -
skip_all,
fields(id = %workers.len())
)]
fn spawn_one(self: &Arc<Self>, workers: &mut JoinSet<()>, recv: &[Receiver<Cmd>]) -> Result {
fn spawn_one(
self: Arc<Self>,
workers: &mut Vec<JoinHandle<()>>,
recv: &[Receiver<Cmd>],
) -> Result {
debug_assert!(!self.queues.is_empty(), "Must have at least one queue");
debug_assert!(!recv.is_empty(), "Must have at least one receiver");
let id = workers.len();
let group = id.overflowing_rem(self.queues.len()).0;
let recv = recv[group].clone();
let self_ = self.clone();
#[cfg(not(tokio_unstable))]
let _abort = workers.spawn_blocking_on(move || self_.worker(id, recv), self.server.runtime());
let handle = thread::Builder::new()
.name(WORKER_NAME.into())
.stack_size(WORKER_STACK_SIZE)
.spawn(move || self.worker(id, recv))?;
#[cfg(tokio_unstable)]
let _abort = workers
.build_task()
.name("conduwuit:dbpool")
.spawn_blocking_on(move || self_.worker(id, recv), self.server.runtime());
workers.push(handle);
Ok(())
}
@@ -196,8 +207,6 @@ pub(crate) async fn execute_get(self: &Arc<Self>, mut cmd: Get) -> Result<BatchR
.map_err(|e| err!(error!("recv failed {e:?}")))
})
.await
.map(Into::into)
.map_err(Into::into)
}
#[implement(Pool)]
@@ -258,7 +267,7 @@ async fn execute(&self, queue: &Sender<Cmd>, cmd: Cmd) -> Result {
level = "debug",
skip(self, recv),
fields(
tid = ?std::thread::current().id(),
tid = ?thread::current().id(),
),
)]
fn worker(self: Arc<Self>, id: usize, recv: Receiver<Cmd>) {
@@ -277,6 +286,7 @@ fn worker_init(&self, id: usize) {
.iter()
.enumerate()
.filter(|_| self.queues.len() > 1)
.filter(|_| self.server.config.db_pool_affinity)
.filter_map(|(core_id, &queue_id)| (group == queue_id).then_some(core_id))
.filter_map(nth_core_available);
+33 -32
View File
@@ -1,14 +1,13 @@
use std::{path::PathBuf, sync::Arc};
use conduwuit::{
debug, debug_info, expected,
debug, debug_info, expected, is_equal_to,
utils::{
math::usize_from_f64,
result::LogDebugErr,
stream,
stream::{AMPLIFICATION_LIMIT, WIDTH_LIMIT},
sys::{compute::is_core_available, storage},
BoolExt,
},
Server,
};
@@ -19,39 +18,32 @@ pub(super) fn configure(server: &Arc<Server>) -> (usize, Vec<usize>, Vec<usize>)
let config = &server.config;
// This finds the block device and gathers all the properties we need.
let (device_name, device_prop) = config
.db_pool_affinity
.and_then(|| {
let path: PathBuf = config.database_path.clone();
let name = storage::name_from_path(&path).log_debug_err().ok();
let prop = storage::parallelism(&path);
name.map(|name| (name, prop))
})
.unzip();
let path: PathBuf = config.database_path.clone();
let device_name = storage::name_from_path(&path).log_debug_err().ok();
let device_prop = storage::parallelism(&path);
// The default worker count is masked-on if we didn't find better information.
let default_worker_count = device_prop
.as_ref()
.is_none_or(|prop| prop.mq.is_empty())
.then_some(config.db_pool_workers);
let default_worker_count = device_prop.mq.is_empty().then_some(config.db_pool_workers);
// Determine the worker groupings. Each indice represents a hardware queue and
// contains the number of workers which will service it.
let worker_counts: Vec<_> = device_prop
.mq
.iter()
.map(|dev| &dev.mq)
.flat_map(|mq| mq.iter())
.filter(|mq| mq.cpu_list.iter().copied().any(is_core_available))
.map(|mq| {
mq.nr_tags.unwrap_or_default().min(
config.db_pool_workers_limit.saturating_mul(
mq.cpu_list
.iter()
.filter(|&&id| is_core_available(id))
.count()
.max(1),
),
)
let shares = mq
.cpu_list
.iter()
.filter(|&&id| is_core_available(id))
.count()
.max(1);
let limit = config.db_pool_workers_limit.saturating_mul(shares);
let limit = device_prop.nr_requests.map_or(limit, |nr| nr.min(limit));
mq.nr_tags.unwrap_or(WORKER_LIMIT.0).min(limit)
})
.chain(default_worker_count)
.collect();
@@ -72,9 +64,8 @@ pub(super) fn configure(server: &Arc<Server>) -> (usize, Vec<usize>, Vec<usize>)
// going on because cpu's which are not available to the process are filtered
// out, similar to the worker_counts.
let topology = device_prop
.mq
.iter()
.map(|dev| &dev.mq)
.flat_map(|mq| mq.iter())
.fold(vec![0; 128], |mut topology, mq| {
mq.cpu_list
.iter()
@@ -89,9 +80,12 @@ pub(super) fn configure(server: &Arc<Server>) -> (usize, Vec<usize>, Vec<usize>)
// Regardless of the capacity of all queues we establish some limit on the total
// number of workers; this is hopefully hinted by nr_requests.
let max_workers = device_prop
.as_ref()
.and_then(|prop| prop.nr_requests)
.unwrap_or(WORKER_LIMIT.1);
.mq
.iter()
.filter_map(|mq| mq.nr_tags)
.chain(default_worker_count)
.fold(0_usize, usize::saturating_add)
.clamp(WORKER_LIMIT.0, WORKER_LIMIT.1);
// Determine the final worker count which we'll be spawning.
let total_workers = worker_counts
@@ -102,7 +96,7 @@ pub(super) fn configure(server: &Arc<Server>) -> (usize, Vec<usize>, Vec<usize>)
// After computing all of the above we can update the global automatic stream
// width, hopefully with a better value tailored to this system.
if config.stream_width_scale > 0.0 {
let num_queues = queue_sizes.len();
let num_queues = queue_sizes.len().max(1);
update_stream_width(server, num_queues, total_workers);
}
@@ -117,6 +111,13 @@ pub(super) fn configure(server: &Arc<Server>) -> (usize, Vec<usize>, Vec<usize>)
"Frontend topology",
);
assert!(total_workers > 0, "some workers expected");
assert!(!queue_sizes.is_empty(), "some queues expected");
assert!(
!queue_sizes.iter().copied().any(is_equal_to!(0)),
"positive queue sizes expected"
);
(total_workers, queue_sizes, topology)
}
+2
View File
@@ -22,7 +22,9 @@ where
Ok(buf)
}
/// Serialize T into Writer W
#[inline]
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
pub fn serialize<'a, W, T>(out: &'a mut W, val: T) -> Result<&'a [u8]>
where
W: Write + AsRef<[u8]> + 'a,
+17 -4
View File
@@ -6,14 +6,14 @@ mod keys_rev;
use std::sync::Arc;
use conduwuit::{utils::exchange, Result};
use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, ReadOptions};
use rocksdb::{DBRawIteratorWithThreadMode, ReadOptions};
pub(crate) use self::{items::Items, items_rev::ItemsRev, keys::Keys, keys_rev::KeysRev};
use crate::{
engine::Db,
keyval::{Key, KeyVal, Val},
util::{is_incomplete, map_err},
Engine, Slice,
Map, Slice,
};
pub(crate) struct State<'a> {
@@ -29,12 +29,14 @@ pub(crate) trait Cursor<'a, T> {
fn seek(&mut self);
#[inline]
fn get(&self) -> Option<Result<T>> {
self.fetch()
.map(Ok)
.or_else(|| self.state().status().map(map_err).map(Err))
}
#[inline]
fn seek_and_get(&mut self) -> Option<Result<T>> {
self.seek();
self.get()
@@ -45,14 +47,17 @@ type Inner<'a> = DBRawIteratorWithThreadMode<'a, Db>;
type From<'a> = Option<Key<'a>>;
impl<'a> State<'a> {
pub(super) fn new(db: &'a Arc<Engine>, cf: &'a Arc<ColumnFamily>, opts: ReadOptions) -> Self {
#[inline]
pub(super) fn new(map: &'a Arc<Map>, opts: ReadOptions) -> Self {
Self {
inner: db.db.raw_iterator_cf_opt(&**cf, opts),
inner: map.db().db.raw_iterator_cf_opt(&map.cf(), opts),
init: true,
seek: false,
}
}
#[inline]
#[tracing::instrument(level = "trace", skip_all)]
pub(super) fn init_fwd(mut self, from: From<'_>) -> Self {
debug_assert!(self.init, "init must be set to make this call");
debug_assert!(!self.seek, "seek must not be set to make this call");
@@ -67,6 +72,8 @@ impl<'a> State<'a> {
self
}
#[inline]
#[tracing::instrument(level = "trace", skip_all)]
pub(super) fn init_rev(mut self, from: From<'_>) -> Self {
debug_assert!(self.init, "init must be set to make this call");
debug_assert!(!self.seek, "seek must not be set to make this call");
@@ -82,6 +89,7 @@ impl<'a> State<'a> {
}
#[inline]
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
pub(super) fn seek_fwd(&mut self) {
if !exchange(&mut self.init, false) {
self.inner.next();
@@ -91,6 +99,7 @@ impl<'a> State<'a> {
}
#[inline]
#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
pub(super) fn seek_rev(&mut self) {
if !exchange(&mut self.init, false) {
self.inner.prev();
@@ -103,12 +112,16 @@ impl<'a> State<'a> {
matches!(self.status(), Some(e) if is_incomplete(&e))
}
#[inline]
fn fetch_key(&self) -> Option<Key<'_>> { self.inner.key().map(Key::from) }
#[inline]
fn _fetch_val(&self) -> Option<Val<'_>> { self.inner.value().map(Val::from) }
#[inline]
fn fetch(&self) -> Option<KeyVal<'_>> { self.inner.item().map(KeyVal::from) }
#[inline]
pub(super) fn status(&self) -> Option<rocksdb::Error> { self.inner.status().err() }
#[inline]
+3
View File
@@ -15,12 +15,15 @@ pub(crate) struct Items<'a> {
}
impl<'a> From<State<'a>> for Items<'a> {
#[inline]
fn from(state: State<'a>) -> Self { Self { state } }
}
impl<'a> Cursor<'a, KeyVal<'a>> for Items<'a> {
#[inline]
fn state(&self) -> &State<'a> { &self.state }
#[inline]
fn fetch(&self) -> Option<KeyVal<'a>> { self.state.fetch().map(keyval_longevity) }
#[inline]
+3
View File
@@ -15,12 +15,15 @@ pub(crate) struct ItemsRev<'a> {
}
impl<'a> From<State<'a>> for ItemsRev<'a> {
#[inline]
fn from(state: State<'a>) -> Self { Self { state } }
}
impl<'a> Cursor<'a, KeyVal<'a>> for ItemsRev<'a> {
#[inline]
fn state(&self) -> &State<'a> { &self.state }
#[inline]
fn fetch(&self) -> Option<KeyVal<'a>> { self.state.fetch().map(keyval_longevity) }
#[inline]
+2
View File
@@ -15,10 +15,12 @@ pub(crate) struct Keys<'a> {
}
impl<'a> From<State<'a>> for Keys<'a> {
#[inline]
fn from(state: State<'a>) -> Self { Self { state } }
}
impl<'a> Cursor<'a, Key<'a>> for Keys<'a> {
#[inline]
fn state(&self) -> &State<'a> { &self.state }
#[inline]
+2
View File
@@ -15,10 +15,12 @@ pub(crate) struct KeysRev<'a> {
}
impl<'a> From<State<'a>> for KeysRev<'a> {
#[inline]
fn from(state: State<'a>) -> Self { Self { state } }
}
impl<'a> Cursor<'a, Key<'a>> for KeysRev<'a> {
#[inline]
fn state(&self) -> &State<'a> { &self.state }
#[inline]
+3 -4
View File
@@ -1,4 +1,3 @@
#![cfg(test)]
#![allow(clippy::needless_borrows_for_generic_args)]
use std::fmt::Debug;
@@ -14,7 +13,7 @@ use crate::{
};
#[test]
#[should_panic(expected = "serializing string at the top-level")]
#[cfg_attr(debug_assertions, should_panic(expected = "serializing string at the top-level"))]
fn ser_str() {
let user_id: &UserId = "@user:example.com".try_into().unwrap();
let s = serialize_to_vec(&user_id).expect("failed to serialize user_id");
@@ -139,7 +138,7 @@ fn ser_json_macro() {
}
#[test]
#[should_panic(expected = "serializing string at the top-level")]
#[cfg_attr(debug_assertions, should_panic(expected = "serializing string at the top-level"))]
fn ser_json_raw() {
use conduwuit::ruma::api::client::filter::FilterDefinition;
@@ -156,7 +155,7 @@ fn ser_json_raw() {
}
#[test]
#[should_panic(expected = "you can skip serialization instead")]
#[cfg_attr(debug_assertions, should_panic(expected = "you can skip serialization instead"))]
fn ser_json_raw_json() {
use conduwuit::ruma::api::client::filter::FilterDefinition;
+2 -2
View File
@@ -9,7 +9,7 @@ use syn::{
};
use crate::{
utils::{get_simple_settings, is_cargo_build},
utils::{get_simple_settings, is_cargo_build, is_cargo_test},
Result,
};
@@ -17,7 +17,7 @@ const UNDOCUMENTED: &str = "# This item is undocumented. Please contribute docum
#[allow(clippy::needless_pass_by_value)]
pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result<TokenStream> {
if is_cargo_build() {
if is_cargo_build() && !is_cargo_test() {
generate_example(&input, args)?;
}
+2
View File
@@ -32,6 +32,8 @@ pub(crate) fn is_cargo_build() -> bool {
.is_some()
}
pub(crate) fn is_cargo_test() -> bool { std::env::args().any(|flag| flag == "--test") }
pub(crate) fn get_named_generics(args: &[Meta], name: &str) -> Result<Generics> {
const DEFAULT: &str = "<>";
+11
View File
@@ -41,8 +41,10 @@ default = [
"gzip_compression",
"io_uring",
"jemalloc",
"media_thumbnail",
"release_max_log_level",
"systemd",
"url_preview",
"zstd_compression",
]
@@ -83,6 +85,9 @@ jemalloc_prof = [
jemalloc_stats = [
"conduwuit-core/jemalloc_stats",
]
media_thumbnail = [
"conduwuit-service/media_thumbnail",
]
perf_measurements = [
"dep:opentelemetry",
"dep:tracing-flame",
@@ -121,12 +126,18 @@ tokio_console = [
"dep:console-subscriber",
"tokio/tracing",
]
url_preview = [
"conduwuit-service/url_preview",
]
zstd_compression = [
"conduwuit-api/zstd_compression",
"conduwuit-core/zstd_compression",
"conduwuit-database/zstd_compression",
"conduwuit-router/zstd_compression",
]
conduwuit_mods = [
"conduwuit-core/conduwuit_mods",
]
[dependencies]
conduwuit-admin.workspace = true
+2 -2
View File
@@ -37,7 +37,7 @@ fn main() -> Result<(), Error> {
/// Operate the server normally in release-mode static builds. This will start,
/// run and stop the server within the asynchronous runtime.
#[cfg(not(conduwuit_mods))]
#[cfg(any(not(conduwuit_mods), not(feature = "conduwuit_mods")))]
#[tracing::instrument(
name = "main",
parent = None,
@@ -89,7 +89,7 @@ async fn async_main(server: &Arc<Server>) -> Result<(), Error> {
/// Operate the server in developer-mode dynamic builds. This will start, run,
/// and hot-reload portions of the server as-needed before returning for an
/// actual shutdown. This is not available in release-mode or static builds.
#[cfg(conduwuit_mods)]
#[cfg(all(conduwuit_mods, feature = "conduwuit_mods"))]
async fn async_main(server: &Arc<Server>) -> Result<(), Error> {
let mut starts = true;
let mut reloads = true;
+1 -1
View File
@@ -1,4 +1,4 @@
#![cfg(conduwuit_mods)]
#![cfg(all(conduwuit_mods, feature = "conduwuit_mods"))]
#[unsafe(no_link)]
extern crate conduwuit_service;
+1 -1
View File
@@ -19,7 +19,7 @@ use crate::clap::Args;
const WORKER_NAME: &str = "conduwuit:worker";
const WORKER_MIN: usize = 2;
const WORKER_KEEPALIVE: u64 = 36;
const MAX_BLOCKING_THREADS: usize = 2048;
const MAX_BLOCKING_THREADS: usize = 1024;
static WORKER_AFFINITY: OnceLock<bool> = OnceLock::new();
+2 -2
View File
@@ -23,7 +23,7 @@ pub(crate) struct Server {
#[cfg(feature = "sentry_telemetry")]
_sentry_guard: Option<::sentry::ClientInitGuard>,
#[cfg(conduwuit_mods)]
#[cfg(all(conduwuit_mods, feature = "conduwuit_mods"))]
// Module instances; TODO: move to mods::loaded mgmt vector
pub(crate) mods: tokio::sync::RwLock<Vec<conduwuit::mods::Module>>,
}
@@ -75,7 +75,7 @@ impl Server {
#[cfg(feature = "sentry_telemetry")]
_sentry_guard: sentry_guard,
#[cfg(conduwuit_mods)]
#[cfg(all(conduwuit_mods, feature = "conduwuit_mods"))]
mods: tokio::sync::RwLock::new(Vec::new()),
}))
}
+1 -1
View File
@@ -12,7 +12,7 @@ pub(super) async fn signal(server: Arc<Server>) {
use unix::SignalKind;
const CONSOLE: bool = cfg!(feature = "console");
const RELOADING: bool = cfg!(all(conduwuit_mods, not(CONSOLE)));
const RELOADING: bool = cfg!(all(conduwuit_mods, feature = "conduwuit_mods", not(CONSOLE)));
let mut quit = unix::signal(SignalKind::quit()).expect("SIGQUIT handler");
let mut term = unix::signal(SignalKind::terminate()).expect("SIGTERM handler");
+1 -1
View File
@@ -80,7 +80,7 @@ tower.workspace = true
tower-http.workspace = true
tracing.workspace = true
[target.'cfg(unix)'.dependencies]
[target.'cfg(all(unix, target_os = "linux"))'.dependencies]
sd-notify.workspace = true
sd-notify.optional = true
+8 -8
View File
@@ -3,7 +3,7 @@ extern crate conduwuit_core as conduwuit;
extern crate conduwuit_service as service;
use std::{
sync::{atomic::Ordering, Arc},
sync::{atomic::Ordering, Arc, Weak},
time::Duration,
};
@@ -63,7 +63,7 @@ pub(crate) async fn start(server: Arc<Server>) -> Result<Arc<Services>> {
let services = Services::build(server).await?.start().await?;
#[cfg(feature = "systemd")]
#[cfg(all(feature = "systemd", target_os = "linux"))]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("failed to notify systemd of ready state");
@@ -92,14 +92,14 @@ pub(crate) async fn stop(services: Arc<Services>) -> Result<()> {
);
}
// The db threadpool requires async join if we use tokio/spawn_blocking to
// manage the threads. Without async-drop we have to wait here; for symmetry
// with Services construction it can't be done in services.stop().
if let Some(db) = db.upgrade() {
db.db.shutdown_pool().await;
if Weak::strong_count(&db) > 0 {
debug_error!(
"{} dangling references to Database after shutdown",
Weak::strong_count(&db)
);
}
#[cfg(feature = "systemd")]
#[cfg(all(feature = "systemd", target_os = "linux"))]
sd_notify::notify(true, &[sd_notify::NotifyState::Stopping])
.expect("failed to notify systemd of stopping state");
+11 -2
View File
@@ -28,8 +28,8 @@ element_hacks = []
gzip_compression = [
"reqwest/gzip",
]
zstd_compression = [
"reqwest/zstd",
media_thumbnail = [
"dep:image",
]
release_max_log_level = [
"tracing/max_level_trace",
@@ -37,6 +37,13 @@ release_max_log_level = [
"log/max_level_trace",
"log/release_max_level_info",
]
url_preview = [
"dep:image",
"dep:webpage",
]
zstd_compression = [
"reqwest/zstd",
]
[dependencies]
arrayvec.workspace = true
@@ -51,6 +58,7 @@ futures.workspace = true
hickory-resolver.workspace = true
http.workspace = true
image.workspace = true
image.optional = true
ipaddress.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
@@ -73,6 +81,7 @@ tokio.workspace = true
tracing.workspace = true
url.workspace = true
webpage.workspace = true
webpage.optional = true
[lints]
workspace = true
+8
View File
@@ -58,6 +58,7 @@ impl Console {
pub async fn close(self: &Arc<Self>) {
self.interrupt();
let Some(worker_join) = self.worker_join.lock().expect("locked").take() else {
return;
};
@@ -92,6 +93,12 @@ impl Console {
#[tracing::instrument(skip_all, name = "console", level = "trace")]
async fn worker(self: Arc<Self>) {
debug!("session starting");
self.output
.print_inline(&format!("**conduwuit {}** admin console\n", conduwuit::version()));
self.output
.print_text("\"help\" for help, ^D to exit the console, ^\\ to stop the server\n");
while self.server.running() {
match self.readline().await {
| Ok(event) => match event {
@@ -147,6 +154,7 @@ impl Console {
self.add_history(line.clone());
let future = self.clone().process(line);
let (abort, abort_reg) = AbortHandle::new_pair();
let future = Abortable::new(future, abort_reg);
_ = self.command_abort.lock().expect("locked").insert(abort);
-3
View File
@@ -11,7 +11,6 @@ use conduwuit::{error, utils::bytes::pretty, Config, Result};
use data::Data;
use regex::RegexSet;
use ruma::{OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, ServerName, UserId};
use tokio::sync::Mutex;
use crate::service;
@@ -21,7 +20,6 @@ pub struct Service {
pub config: Config,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey>,
pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
pub stateres_mutex: Arc<Mutex<()>>,
pub server_user: OwnedUserId,
pub admin_alias: OwnedRoomAliasId,
pub turn_secret: String,
@@ -70,7 +68,6 @@ impl crate::Service for Service {
config: config.clone(),
jwt_decoding_key,
bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
stateres_mutex: Arc::new(Mutex::new(())),
admin_alias: OwnedRoomAliasId::try_from(format!("#admins:{}", &config.server_name))
.expect("#admins:server_name is valid alias name"),
server_user: UserId::parse_with_server_name(
+14 -23
View File
@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use conduwuit::{
debug, debug_info, err,
utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt},
Err, Error, Result,
Err, Result,
};
use database::{Database, Interfix, Map};
use futures::StreamExt;
@@ -123,30 +123,21 @@ impl Data {
let content_type = parts
.next()
.map(|bytes| {
string_from_bytes(bytes).map_err(|_| {
Error::bad_database("Content type in mediaid_file is invalid unicode.")
})
})
.transpose()?;
.map(string_from_bytes)
.transpose()
.map_err(|e| err!(Database(error!(?mxc, "Content-type is invalid: {e}"))))?;
let content_disposition_bytes = parts
let content_disposition = parts
.next()
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let content_disposition = if content_disposition_bytes.is_empty() {
None
} else {
Some(
string_from_bytes(content_disposition_bytes)
.map_err(|_| {
Error::bad_database(
"Content Disposition in mediaid_file is invalid unicode.",
)
})?
.parse()?,
)
};
.map(Some)
.ok_or_else(|| err!(Database(error!(?mxc, "Media ID in db is invalid."))))?
.filter(|bytes| !bytes.is_empty())
.map(string_from_bytes)
.transpose()
.map_err(|e| err!(Database(error!(?mxc, "Content-type is invalid: {e}"))))?
.as_deref()
.map(str::parse)
.transpose()?;
Ok(Metadata { content_disposition, content_type, key })
}
+59 -34
View File
@@ -1,15 +1,19 @@
use std::{io::Cursor, time::SystemTime};
//! URL Previews
//!
//! This functionality is gated by 'url_preview', but not at the unit level for
//! historical and simplicity reasons. Instead the feature gates the inclusion
//! of dependencies and nulls out results through the existing interface when
//! not featured.
use conduwuit::{debug, utils, Err, Result};
use std::time::SystemTime;
use conduwuit::{debug, Err, Result};
use conduwuit_core::implement;
use image::ImageReader as ImgReader;
use ipaddress::IPAddress;
use ruma::Mxc;
use serde::Serialize;
use url::Url;
use webpage::HTML;
use super::{Service, MXC_LENGTH};
use super::Service;
#[derive(Serialize, Default)]
pub struct UrlPreviewData {
@@ -41,34 +45,6 @@ pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<
self.db.set_url_preview(url, data, now)
}
#[implement(Service)]
pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
let client = &self.services.client.url_preview;
let image = client.get(url).send().await?.bytes().await?;
let mxc = Mxc {
server_name: self.services.globals.server_name(),
media_id: &utils::random_string(MXC_LENGTH),
};
self.create(&mxc, None, None, None, &image).await?;
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
| Err(_) => (None, None),
| Ok(reader) => match reader.into_dimensions() {
| Err(_) => (None, None),
| Ok((width, height)) => (Some(width), Some(height)),
},
};
Ok(UrlPreviewData {
image: Some(mxc.to_string()),
image_size: Some(image.len()),
image_width: width,
image_height: height,
..Default::default()
})
}
#[implement(Service)]
pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Ok(preview) = self.db.get_url_preview(url.as_str()).await {
@@ -121,8 +97,51 @@ async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
Ok(data)
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
use conduwuit::utils::random_string;
use image::ImageReader;
use ruma::Mxc;
let image = self.services.client.url_preview.get(url).send().await?;
let image = image.bytes().await?;
let mxc = Mxc {
server_name: self.services.globals.server_name(),
media_id: &random_string(super::MXC_LENGTH),
};
self.create(&mxc, None, None, None, &image).await?;
let cursor = std::io::Cursor::new(&image);
let (width, height) = match ImageReader::new(cursor).with_guessed_format() {
| Err(_) => (None, None),
| Ok(reader) => match reader.into_dimensions() {
| Err(_) => (None, None),
| Ok((width, height)) => (Some(width), Some(height)),
},
};
Ok(UrlPreviewData {
image: Some(mxc.to_string()),
image_size: Some(image.len()),
image_width: width,
image_height: height,
..Default::default()
})
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
pub async fn download_image(&self, _url: &str) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[cfg(feature = "url_preview")]
#[implement(Service)]
async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
use webpage::HTML;
let client = &self.services.client.url_preview;
let mut response = client.get(url).send().await?;
@@ -159,6 +178,12 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
Ok(data)
}
#[cfg(not(feature = "url_preview"))]
#[implement(Service)]
async fn download_html(&self, _url: &str) -> Result<UrlPreviewData> {
Err!(FeatureDisabled("url_preview"))
}
#[implement(Service)]
pub fn url_preview_allowed(&self, url: &Url) -> bool {
if ["http", "https"]
+90 -60
View File
@@ -1,7 +1,13 @@
use std::{cmp, io::Cursor, num::Saturating as Sat};
//! Media Thumbnails
//!
//! This functionality is gated by 'media_thumbnail', but not at the unit level
//! for historical and simplicity reasons. Instead the feature gates the
//! inclusion of dependencies and nulls out results using the existing interface
//! when not featured.
use conduwuit::{checked, err, Result};
use image::{imageops::FilterType, DynamicImage};
use std::{cmp, num::Saturating as Sat};
use conduwuit::{checked, err, implement, Result};
use ruma::{http_headers::ContentDisposition, media::Method, Mxc, UInt, UserId};
use tokio::{
fs,
@@ -67,65 +73,89 @@ impl super::Service {
Ok(None)
}
}
/// Using saved thumbnail
#[tracing::instrument(skip(self), name = "saved", level = "debug")]
async fn get_thumbnail_saved(&self, data: Metadata) -> Result<Option<FileMeta>> {
let mut content = Vec::new();
let path = self.get_media_file(&data.key);
fs::File::open(path)
.await?
.read_to_end(&mut content)
.await?;
Ok(Some(into_filemeta(data, content)))
}
/// Generate a thumbnail
#[tracing::instrument(skip(self), name = "generate", level = "debug")]
async fn get_thumbnail_generate(
&self,
mxc: &Mxc<'_>,
dim: &Dim,
data: Metadata,
) -> Result<Option<FileMeta>> {
let mut content = Vec::new();
let path = self.get_media_file(&data.key);
fs::File::open(path)
.await?
.read_to_end(&mut content)
.await?;
let Ok(image) = image::load_from_memory(&content) else {
// Couldn't parse file to generate thumbnail, send original
return Ok(Some(into_filemeta(data, content)));
};
if dim.width > image.width() || dim.height > image.height() {
return Ok(Some(into_filemeta(data, content)));
}
let mut thumbnail_bytes = Vec::new();
let thumbnail = thumbnail_generate(&image, dim)?;
thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageFormat::Png)?;
// Save thumbnail in database so we don't have to generate it again next time
let thumbnail_key = self.db.create_file_metadata(
mxc,
None,
dim,
data.content_disposition.as_ref(),
data.content_type.as_deref(),
)?;
let mut f = self.create_media_file(&thumbnail_key).await?;
f.write_all(&thumbnail_bytes).await?;
Ok(Some(into_filemeta(data, thumbnail_bytes)))
}
}
fn thumbnail_generate(image: &DynamicImage, requested: &Dim) -> Result<DynamicImage> {
/// Using saved thumbnail
#[implement(super::Service)]
#[tracing::instrument(name = "saved", level = "debug", skip(self, data))]
async fn get_thumbnail_saved(&self, data: Metadata) -> Result<Option<FileMeta>> {
let mut content = Vec::new();
let path = self.get_media_file(&data.key);
fs::File::open(path)
.await?
.read_to_end(&mut content)
.await?;
Ok(Some(into_filemeta(data, content)))
}
/// Generate a thumbnail
#[cfg(feature = "media_thumbnail")]
#[implement(super::Service)]
#[tracing::instrument(name = "generate", level = "debug", skip(self, data))]
async fn get_thumbnail_generate(
&self,
mxc: &Mxc<'_>,
dim: &Dim,
data: Metadata,
) -> Result<Option<FileMeta>> {
let mut content = Vec::new();
let path = self.get_media_file(&data.key);
fs::File::open(path)
.await?
.read_to_end(&mut content)
.await?;
let Ok(image) = image::load_from_memory(&content) else {
// Couldn't parse file to generate thumbnail, send original
return Ok(Some(into_filemeta(data, content)));
};
if dim.width > image.width() || dim.height > image.height() {
return Ok(Some(into_filemeta(data, content)));
}
let mut thumbnail_bytes = Vec::new();
let thumbnail = thumbnail_generate(&image, dim)?;
let mut cursor = std::io::Cursor::new(&mut thumbnail_bytes);
thumbnail
.write_to(&mut cursor, image::ImageFormat::Png)
.map_err(|error| err!(error!(?error, "Error writing PNG thumbnail.")))?;
// Save thumbnail in database so we don't have to generate it again next time
let thumbnail_key = self.db.create_file_metadata(
mxc,
None,
dim,
data.content_disposition.as_ref(),
data.content_type.as_deref(),
)?;
let mut f = self.create_media_file(&thumbnail_key).await?;
f.write_all(&thumbnail_bytes).await?;
Ok(Some(into_filemeta(data, thumbnail_bytes)))
}
#[cfg(not(feature = "media_thumbnail"))]
#[implement(super::Service)]
#[tracing::instrument(name = "fallback", level = "debug", skip_all)]
async fn get_thumbnail_generate(
&self,
_mxc: &Mxc<'_>,
_dim: &Dim,
data: Metadata,
) -> Result<Option<FileMeta>> {
self.get_thumbnail_saved(data).await
}
#[cfg(feature = "media_thumbnail")]
fn thumbnail_generate(
image: &image::DynamicImage,
requested: &Dim,
) -> Result<image::DynamicImage> {
use image::imageops::FilterType;
let thumbnail = if !requested.crop() {
let Dim { width, height, .. } = requested.scaled(&Dim {
width: image.width(),
+4 -4
View File
@@ -379,7 +379,7 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<
})
.await;
db.db.cleanup()?;
db.db.sort()?;
db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
info!("Finished fixing");
@@ -465,7 +465,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
.await;
}
db.db.cleanup()?;
db.db.sort()?;
db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
info!("Finished fixing");
@@ -511,7 +511,7 @@ async fn fix_referencedevents_missing_sep(services: &Services) -> Result {
info!(?total, ?fixed, "Fixed missing record separators in 'referencedevents'.");
db["global"].insert(b"fix_referencedevents_missing_sep", []);
db.db.cleanup()
db.db.sort()
}
async fn fix_readreceiptid_readreceipt_duplicates(services: &Services) -> Result {
@@ -561,5 +561,5 @@ async fn fix_readreceiptid_readreceipt_duplicates(services: &Services) -> Result
info!(?total, ?fixed, "Fixed undeleted entries in readreceiptid_readreceipt.");
db["global"].insert(b"fix_readreceiptid_readreceipt_duplicates", []);
db.db.cleanup()
db.db.sort()
}
@@ -15,8 +15,12 @@ use ruma::{
use super::check_room_id;
#[implement(super::Service)]
#[tracing::instrument(
level = "warn",
skip_all,
fields(%origin),
)]
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)]
pub(super) async fn fetch_prev(
&self,
origin: &ServerName,
@@ -1,6 +1,6 @@
use std::collections::{hash_map, HashMap};
use conduwuit::{debug, implement, warn, Err, Error, PduEvent, Result};
use conduwuit::{debug, debug_warn, implement, Err, Error, PduEvent, Result};
use futures::FutureExt;
use ruma::{
api::federation::event::get_room_state_ids, events::StateEventType, EventId, OwnedEventId,
@@ -13,7 +13,11 @@ use crate::rooms::short::ShortStateKey;
/// server's response to some extend (sic), but we still do a lot of checks
/// on the events
#[implement(super::Service)]
#[tracing::instrument(skip(self, create_event, room_version_id))]
#[tracing::instrument(
level = "warn",
skip_all,
fields(%origin),
)]
pub(super) async fn fetch_state(
&self,
origin: &ServerName,
@@ -22,7 +26,6 @@ pub(super) async fn fetch_state(
room_version_id: &RoomVersionId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
debug!("Fetching state ids");
let res = self
.services
.sending
@@ -31,7 +34,7 @@ pub(super) async fn fetch_state(
event_id: event_id.to_owned(),
})
.await
.inspect_err(|e| warn!("Fetching state for event failed: {e}"))?;
.inspect_err(|e| debug_warn!("Fetching state for event failed: {e}"))?;
debug!("Fetching state events");
let state_vec = self
@@ -39,7 +39,12 @@ use crate::rooms::timeline::RawPduId;
/// 14. Check if the event passes auth based on the "current state" of the room,
/// if not soft fail it
#[implement(super::Service)]
#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")]
#[tracing::instrument(
name = "pdu",
level = "warn",
skip_all,
fields(%room_id, %event_id),
)]
pub async fn handle_incoming_pdu<'a>(
&self,
origin: &'a ServerName,
@@ -13,8 +13,10 @@ use ruma::{CanonicalJsonValue, EventId, OwnedEventId, RoomId, ServerName};
#[allow(clippy::type_complexity)]
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(
skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room),
name = "prev"
name = "prev",
level = "warn",
skip_all,
fields(%prev_id),
)]
pub(super) async fn handle_prev_pdu<'a>(
&self,
@@ -26,7 +28,7 @@ pub(super) async fn handle_prev_pdu<'a>(
(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>),
>,
create_event: &Arc<PduEvent>,
first_pdu_in_room: &Arc<PduEvent>,
first_pdu_in_room: &PduEvent,
prev_id: &EventId,
) -> Result {
// Check for disabled again because it might have changed
@@ -116,9 +116,6 @@ pub async fn state_resolution(
state_sets: &[StateMap<OwnedEventId>],
auth_chain_sets: &[HashSet<OwnedEventId>],
) -> Result<StateMap<OwnedEventId>> {
//TODO: ???
let _lock = self.services.globals.stateres_mutex.lock();
state_res::resolve(
room_version,
state_sets.iter(),
+1
View File
@@ -155,6 +155,7 @@ where
}
let content = ReceiptEventContent::from_iter(json);
conduwuit::trace!(?content);
Raw::from_json(
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent { content })
.expect("received valid json"),
+32 -17
View File
@@ -95,7 +95,16 @@ impl crate::Service for Service {
impl Service {
/// Update current membership data.
#[tracing::instrument(skip(self, last_state))]
#[tracing::instrument(
level = "debug",
skip_all,
fields(
%room_id,
%user_id,
%sender,
?membership_event,
),
)]
#[allow(clippy::too_many_arguments)]
pub async fn update_membership(
&self,
@@ -265,7 +274,7 @@ impl Service {
Ok(())
}
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
#[tracing::instrument(level = "trace", skip_all)]
pub async fn appservice_in_room(
&self,
room_id: &RoomId,
@@ -383,7 +392,7 @@ impl Service {
.map(|(_, server): (Ignore, &ServerName)| server)
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn server_in_room<'a>(
&'a self,
server: &'a ServerName,
@@ -409,7 +418,7 @@ impl Service {
}
/// Returns true if server can see user by sharing at least one room.
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool {
self.server_rooms(server)
.any(|room_id| self.is_joined(user_id, room_id))
@@ -417,7 +426,7 @@ impl Service {
}
/// Returns true if user_a and user_b share at least one room.
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool {
let get_shared_rooms = self.get_shared_rooms(user_a, user_b);
@@ -426,6 +435,7 @@ impl Service {
}
/// List the rooms common between two users
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_shared_rooms<'a>(
&'a self,
user_a: &'a UserId,
@@ -453,7 +463,7 @@ impl Service {
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn room_joined_count(&self, room_id: &RoomId) -> Result<u64> {
self.db.roomid_joinedcount.get(room_id).await.deserialized()
}
@@ -469,9 +479,9 @@ impl Service {
.ready_filter(|user| self.services.globals.user_is_local(user))
}
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
#[tracing::instrument(skip(self), level = "trace")]
pub fn active_local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
@@ -481,7 +491,7 @@ impl Service {
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn room_invited_count(&self, room_id: &RoomId) -> Result<u64> {
self.db
.roomid_invitedcount
@@ -518,7 +528,7 @@ impl Service {
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db
@@ -528,7 +538,7 @@ impl Service {
.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db.roomuserid_leftcount.qry(&key).await.deserialized()
@@ -566,7 +576,7 @@ impl Service {
.ignore_err()
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn invite_state(
&self,
user_id: &UserId,
@@ -583,7 +593,7 @@ impl Service {
})
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn left_state(
&self,
user_id: &UserId,
@@ -625,24 +635,25 @@ impl Service {
self.db.roomuseroncejoinedids.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_joined.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_invitestate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_leftstate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn user_membership(
&self,
user_id: &UserId,
@@ -683,7 +694,7 @@ impl Service {
/// distant future.
///
/// See <https://spec.matrix.org/latest/appendices/#routing>
#[tracing::instrument(skip(self), level = "debug")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = self
.services
@@ -724,6 +735,7 @@ impl Service {
(cache.len(), cache.capacity())
}
#[tracing::instrument(level = "debug", skip_all)]
pub fn clear_appservice_in_room_cache(&self) {
self.appservice_in_room_cache
.write()
@@ -731,6 +743,7 @@ impl Service {
.clear();
}
#[tracing::instrument(level = "debug", skip(self))]
pub async fn update_joined_count(&self, room_id: &RoomId) {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
@@ -784,11 +797,13 @@ impl Service {
.remove(room_id);
}
#[tracing::instrument(level = "debug", skip(self))]
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
let key = (user_id, room_id);
self.db.roomuseroncejoinedids.put_raw(key, []);
}
#[tracing::instrument(level = "debug", skip(self, last_state, invite_via))]
pub async fn mark_as_invited(
&self,
user_id: &UserId,
@@ -821,7 +836,7 @@ impl Service {
}
}
#[tracing::instrument(skip(self, servers), level = "debug")]
#[tracing::instrument(level = "debug", skip(self, servers))]
pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec<OwnedServerName>) {
let mut servers: Vec<_> = self
.servers_invite_via(room_id)
+64 -72
View File
@@ -1,22 +1,15 @@
use std::{
borrow::Borrow,
collections::{hash_map, HashMap},
sync::Arc,
};
use std::{borrow::Borrow, sync::Arc};
use conduwuit::{
at, err,
result::{LogErr, NotFound},
utils,
utils::{future::TryExtExt, stream::TryIgnore, ReadyExt},
utils::stream::TryReadyExt,
Err, PduCount, PduEvent, Result,
};
use database::{Database, Deserialized, Json, KeyVal, Map};
use futures::{future::select_ok, FutureExt, Stream, StreamExt};
use ruma::{
api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
};
use tokio::sync::Mutex;
use futures::{future::select_ok, pin_mut, FutureExt, Stream, TryFutureExt, TryStreamExt};
use ruma::{api::Direction, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use super::{PduId, RawPduId};
use crate::{rooms, rooms::short::ShortRoomId, Dep};
@@ -27,7 +20,6 @@ pub(super) struct Data {
pduid_pdu: Arc<Map>,
userroomid_highlightcount: Arc<Map>,
userroomid_notificationcount: Arc<Map>,
pub(super) lasttimelinecount_cache: LastTimelineCountCache,
pub(super) db: Arc<Database>,
services: Services,
}
@@ -37,7 +29,6 @@ struct Services {
}
pub type PdusIterItem = (PduCount, PduEvent);
type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>;
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
@@ -48,7 +39,6 @@ impl Data {
pduid_pdu: db["pduid_pdu"].clone(),
userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
db: args.db.clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
@@ -56,27 +46,39 @@ impl Data {
}
}
#[inline]
pub(super) async fn last_timeline_count(
&self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.await
.entry(room_id.into())
{
| hash_map::Entry::Occupied(o) => Ok(*o.get()),
| hash_map::Entry::Vacant(v) => Ok(self
.pdus_rev(sender_user, room_id, PduCount::max())
.await?
.next()
.await
.map(at!(0))
.filter(|&count| matches!(count, PduCount::Normal(_)))
.map_or_else(PduCount::max, |count| *v.insert(count))),
}
let pdus_rev = self.pdus_rev(sender_user, room_id, PduCount::max());
pin_mut!(pdus_rev);
let last_count = pdus_rev
.try_next()
.await?
.map(at!(0))
.filter(|&count| matches!(count, PduCount::Normal(_)))
.unwrap_or_else(PduCount::max);
Ok(last_count)
}
#[inline]
pub(super) async fn latest_pdu_in_room(
&self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduEvent> {
let pdus_rev = self.pdus_rev(sender_user, room_id, PduCount::max());
pin_mut!(pdus_rev);
pdus_rev
.try_next()
.await?
.map(at!(1))
.ok_or_else(|| err!(Request(NotFound("no PDU's found in room"))))
}
/// Returns the `count` of this pdu's id.
@@ -129,7 +131,7 @@ impl Data {
pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.get(&pduid).await.map(|_| ())
self.pduid_pdu.exists(&pduid).await
}
/// Returns the pdu.
@@ -148,17 +150,17 @@ impl Data {
/// Like get_non_outlier_pdu(), but without the expense of fetching and
/// parsing the PduEvent
#[inline]
pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result {
self.eventid_outlierpdu.get(event_id).await.map(|_| ())
self.eventid_outlierpdu.exists(event_id).await
}
/// Like get_pdu(), but without the expense of fetching and parsing the data
pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool {
let non_outlier = self.non_outlier_pdu_exists(event_id).is_ok();
let outlier = self.outlier_pdu_exists(event_id).is_ok();
pub(super) async fn pdu_exists(&self, event_id: &EventId) -> Result {
let non_outlier = self.non_outlier_pdu_exists(event_id).boxed();
let outlier = self.outlier_pdu_exists(event_id).boxed();
//TODO: parallelize
non_outlier.await || outlier.await
select_ok([non_outlier, outlier]).await.map(at!(0))
}
/// Returns the pdu.
@@ -186,11 +188,6 @@ impl Data {
debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal");
self.pduid_pdu.raw_put(pdu_id, Json(json));
self.lasttimelinecount_cache
.lock()
.await
.insert(pdu.room_id.clone(), count);
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
}
@@ -225,49 +222,44 @@ impl Data {
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
pub(super) async fn pdus_rev<'a>(
pub(super) fn pdus_rev<'a>(
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
until: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let current = self
.count_to_id(room_id, until, Direction::Backward)
.await?;
let prefix = current.shortroomid();
let stream = self
.pduid_pdu
.rev_raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
Ok(stream)
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
self.count_to_id(room_id, until, Direction::Backward)
.map_ok(move |current| {
let prefix = current.shortroomid();
self.pduid_pdu
.rev_raw_stream_from(&current)
.ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix)))
.ready_and_then(move |item| Self::each_pdu(item, user_id))
})
.try_flatten_stream()
}
pub(super) async fn pdus<'a>(
pub(super) fn pdus<'a>(
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
from: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + Unpin + 'a> {
let current = self.count_to_id(room_id, from, Direction::Forward).await?;
let prefix = current.shortroomid();
let stream = self
.pduid_pdu
.raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
Ok(stream)
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
self.count_to_id(room_id, from, Direction::Forward)
.map_ok(move |current| {
let prefix = current.shortroomid();
self.pduid_pdu
.raw_stream_from(&current)
.ready_try_take_while(move |(key, _)| Ok(key.starts_with(&prefix)))
.ready_and_then(move |item| Self::each_pdu(item, user_id))
})
.try_flatten_stream()
}
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> PdusIterItem {
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> Result<PdusIterItem> {
let pdu_id: RawPduId = pdu_id.into();
let mut pdu = serde_json::from_slice::<PduEvent>(pdu)
.expect("PduEvent in pduid_pdu database column is invalid JSON");
let mut pdu = serde_json::from_slice::<PduEvent>(pdu)?;
if Some(pdu.sender.borrow()) != user_id {
pdu.remove_transaction_id().log_err().ok();
@@ -275,7 +267,7 @@ impl Data {
pdu.add_age().log_err().ok();
(pdu_id.pdu_count(), pdu)
Ok((pdu_id.pdu_count(), pdu))
}
pub(super) fn increment_notification_counts(
+41 -94
View File
@@ -9,14 +9,16 @@ use std::{
};
use conduwuit::{
debug, debug_warn, err, error, implement, info,
at, debug, debug_warn, err, error, implement, info,
pdu::{gen_event_id, EventHash, PduBuilder, PduCount, PduEvent},
utils::{self, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt},
utils::{
self, future::TryExtExt, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt,
},
validated, warn, Err, Error, Result, Server,
};
pub use conduwuit::{PduId, RawPduId};
use futures::{
future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
future, future::ready, pin_mut, Future, FutureExt, Stream, StreamExt, TryStreamExt,
};
use ruma::{
api::federation,
@@ -34,7 +36,7 @@ use ruma::{
},
push::{Action, Ruleset, Tweak},
state_res::{self, Event, RoomVersion},
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId,
uint, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId,
OwnedServerName, OwnedUserId, RoomId, RoomVersionId, ServerName, UserId,
};
use serde::Deserialize;
@@ -139,53 +141,34 @@ impl crate::Service for Service {
}
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
/*
let lasttimelinecount_cache = self
.db
.lasttimelinecount_cache
.lock()
.expect("locked")
.len();
writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?;
*/
let mutex_insert = self.mutex_insert.len();
writeln!(out, "insert_mutex: {mutex_insert}")?;
Ok(())
}
fn clear_cache(&self) {
/*
self.db
.lasttimelinecount_cache
.lock()
.expect("locked")
.clear();
*/
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Arc<PduEvent>> {
self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)
.next()
.await
.map(|(_, p)| Arc::new(p))
pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<PduEvent> {
self.first_item_in_room(room_id).await.map(at!(1))
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn first_item_in_room(&self, room_id: &RoomId) -> Result<(PduCount, PduEvent)> {
let pdus = self.pdus(None, room_id, None);
pin_mut!(pdus);
pdus.try_next()
.await?
.ok_or_else(|| err!(Request(NotFound("No PDU found in room"))))
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<Arc<PduEvent>> {
self.pdus_rev(None, room_id, None)
.await?
.next()
.await
.map(|(_, p)| Arc::new(p))
.ok_or_else(|| err!(Request(NotFound("No PDU found in room"))))
pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<PduEvent> {
self.db.latest_pdu_in_room(None, room_id).await
}
#[tracing::instrument(skip(self), level = "debug")]
@@ -202,29 +185,6 @@ impl Service {
self.db.get_pdu_count(event_id).await
}
// TODO Is this the same as the function above?
/*
#[tracing::instrument(skip(self))]
pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> {
let prefix = self
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.pduid_pdu
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|b| self.pdu_count(&b.0))
.transpose()
.map(|op| op.unwrap_or_default())
}
*/
/// Returns the json of a pdu.
pub async fn get_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
self.db.get_pdu_json(event_id).await
@@ -260,16 +220,6 @@ impl Service {
self.db.get_pdu(event_id).await
}
/// Checks if pdu exists
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn pdu_exists<'a>(
&'a self,
event_id: &'a EventId,
) -> impl Future<Output = bool> + Send + 'a {
self.db.pdu_exists(event_id)
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
@@ -282,6 +232,16 @@ impl Service {
self.db.get_pdu_json_from_id(pdu_id).await
}
/// Checks if pdu exists
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn pdu_exists<'a>(
&'a self,
event_id: &'a EventId,
) -> impl Future<Output = bool> + Send + 'a {
self.db.pdu_exists(event_id).is_ok()
}
/// Removes a pdu and creates a new one with the same id.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn replace_pdu(
@@ -1027,38 +987,32 @@ impl Service {
&'a self,
user_id: &'a UserId,
room_id: &'a RoomId,
) -> impl Stream<Item = PdusIterItem> + Send + Unpin + 'a {
self.pdus(Some(user_id), room_id, None)
.map_ok(|stream| stream.map(Ok))
.try_flatten_stream()
.ignore_err()
.boxed()
) -> impl Stream<Item = PdusIterItem> + Send + 'a {
self.pdus(Some(user_id), room_id, None).ignore_err()
}
/// Reverse iteration starting at from.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn pdus_rev<'a>(
pub fn pdus_rev<'a>(
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
until: Option<PduCount>,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
self.db
.pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max))
.await
}
/// Forward iteration starting at from.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn pdus<'a>(
pub fn pdus<'a>(
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
from: Option<PduCount>,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
) -> impl Stream<Item = Result<PdusIterItem>> + Send + 'a {
self.db
.pdus(user_id, room_id, from.unwrap_or_else(PduCount::min))
.await
}
/// Replace a PDU with the redacted form.
@@ -1117,8 +1071,7 @@ impl Service {
}
let first_pdu = self
.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)
.next()
.first_item_in_room(room_id)
.await
.expect("Room is not empty");
@@ -1232,20 +1185,14 @@ impl Service {
self.services
.event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, false)
.boxed()
.await?;
let value = self
.get_pdu_json(&event_id)
.await
.expect("We just created it");
let pdu = self.get_pdu(&event_id).await.expect("We just created it");
let value = self.get_pdu_json(&event_id).await?;
let shortroomid = self
.services
.short
.get_shortroomid(&room_id)
.await
.expect("room exists");
let pdu = self.get_pdu(&event_id).await?;
let shortroomid = self.services.short.get_shortroomid(&room_id).await?;
let insert_lock = self.mutex_insert.lock(&room_id).await;
+3 -1
View File
@@ -80,7 +80,9 @@ impl Service {
self.work_loop(id, &mut futures, &mut statuses).await;
self.finish_responses(&mut futures).boxed().await;
if !futures.is_empty() {
self.finish_responses(&mut futures).boxed().await;
}
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More