mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
add rustfmt.toml, format entire codebase
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
@@ -1,114 +1,94 @@
|
||||
use crate::{services, utils, Error, Result};
|
||||
use bytes::BytesMut;
|
||||
use ruma::api::{
|
||||
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
};
|
||||
use std::{fmt::Debug, mem, time::Duration};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{services, utils, Error, Result};
|
||||
|
||||
/// Sends a request to an appservice
|
||||
///
|
||||
/// Only returns None if there is no url specified in the appservice registration file
|
||||
pub(crate) async fn send_request<T>(
|
||||
registration: Registration,
|
||||
request: T,
|
||||
) -> Option<Result<T::IncomingResponse>>
|
||||
/// Only returns None if there is no url specified in the appservice
|
||||
/// registration file
|
||||
pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Option<Result<T::IncomingResponse>>
|
||||
where
|
||||
T: OutgoingRequest + Debug,
|
||||
T: OutgoingRequest + Debug,
|
||||
{
|
||||
if let Some(destination) = registration.url {
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
if let Some(destination) = registration.url {
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(hs_token),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})
|
||||
.unwrap()
|
||||
.map(bytes::BytesMut::freeze);
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(hs_token),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})
|
||||
.unwrap()
|
||||
.map(bytes::BytesMut::freeze);
|
||||
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
"?"
|
||||
};
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
"?"
|
||||
};
|
||||
|
||||
parts.path_and_query = Some(
|
||||
(old_path_and_query + symbol + "access_token=" + hs_token)
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
|
||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||
|
||||
let mut reqwest_request = reqwest::Request::try_from(http_request)
|
||||
.expect("all http requests are valid reqwest requests");
|
||||
let mut reqwest_request =
|
||||
reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");
|
||||
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = match services()
|
||||
.globals
|
||||
.default_client()
|
||||
.execute(reqwest_request)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not send request to appservice {} at {}: {}",
|
||||
registration.id, destination, e
|
||||
);
|
||||
return Some(Err(e.into()));
|
||||
}
|
||||
};
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = match services().globals.default_client().execute(reqwest_request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not send request to appservice {} at {}: {}",
|
||||
registration.id, destination, e
|
||||
);
|
||||
return Some(Err(e.into()));
|
||||
},
|
||||
};
|
||||
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder()
|
||||
.status(status)
|
||||
.version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder
|
||||
.headers_mut()
|
||||
.expect("http::response::Builder is usable"),
|
||||
);
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error: {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error: {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if !status.is_success() {
|
||||
warn!(
|
||||
"Appservice returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
if !status.is_success() {
|
||||
warn!(
|
||||
"Appservice returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder
|
||||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
Some(response.map_err(|_| {
|
||||
warn!(
|
||||
"Appservice returned invalid response bytes {}\n{}",
|
||||
destination, url
|
||||
);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||
);
|
||||
Some(response.map_err(|_| {
|
||||
warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
+346
-397
@@ -1,21 +1,21 @@
|
||||
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{api::client_server, services, utils, Error, Result, Ruma};
|
||||
use register::RegistrationKind;
|
||||
use ruma::{
|
||||
api::client::{
|
||||
account::{
|
||||
change_password, deactivate, get_3pids, get_username_availability, register,
|
||||
request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn,
|
||||
whoami, ThirdPartyIdRemovalStatus,
|
||||
},
|
||||
error::ErrorKind,
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
},
|
||||
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
|
||||
push, UserId,
|
||||
api::client::{
|
||||
account::{
|
||||
change_password, deactivate, get_3pids, get_username_availability, register,
|
||||
request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, whoami,
|
||||
ThirdPartyIdRemovalStatus,
|
||||
},
|
||||
error::ErrorKind,
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
},
|
||||
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
|
||||
push, UserId,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use register::RegistrationKind;
|
||||
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{api::client_server, services, utils, Error, Result, Ruma};
|
||||
|
||||
const RANDOM_USER_ID_LENGTH: usize = 10;
|
||||
|
||||
@@ -28,303 +28,266 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
|
||||
/// - The server name of the user id matches this server
|
||||
/// - No user or appservice on this server already claimed this username
|
||||
///
|
||||
/// Note: This will not reserve the username, so the username might become invalid when trying to register
|
||||
/// Note: This will not reserve the username, so the username might become
|
||||
/// invalid when trying to register
|
||||
pub async fn get_register_available_route(
|
||||
body: Ruma<get_username_availability::v3::Request>,
|
||||
body: Ruma<get_username_availability::v3::Request>,
|
||||
) -> Result<get_username_availability::v3::Response> {
|
||||
// Validate user id
|
||||
let user_id = UserId::parse_with_server_name(
|
||||
body.username.to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
// Validate user id
|
||||
let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name())
|
||||
.ok()
|
||||
.filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name())
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||
|
||||
// Check if username is creative enough
|
||||
if services().users.exists(&user_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserInUse,
|
||||
"Desired user ID is already taken.",
|
||||
));
|
||||
}
|
||||
// Check if username is creative enough
|
||||
if services().users.exists(&user_id)? {
|
||||
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_usernames()
|
||||
.is_match(user_id.localpart())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Username is forbidden.",
|
||||
));
|
||||
}
|
||||
if services().globals.forbidden_usernames().is_match(user_id.localpart()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
|
||||
}
|
||||
|
||||
// TODO add check for appservice namespaces
|
||||
// TODO add check for appservice namespaces
|
||||
|
||||
// If no if check is true we have an username that's available to be used.
|
||||
Ok(get_username_availability::v3::Response { available: true })
|
||||
// If no if check is true we have an username that's available to be used.
|
||||
Ok(get_username_availability::v3::Response {
|
||||
available: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/v3/register`
|
||||
///
|
||||
/// Register an account on this homeserver.
|
||||
///
|
||||
/// You can use [`GET /_matrix/client/v3/register/available`](fn.get_register_available_route.html)
|
||||
/// to check if the user id is valid and available.
|
||||
/// You can use [`GET
|
||||
/// /_matrix/client/v3/register/available`](fn.get_register_available_route.
|
||||
/// html) to check if the user id is valid and available.
|
||||
///
|
||||
/// - Only works if registration is enabled
|
||||
/// - If type is guest: ignores all parameters except initial_device_display_name
|
||||
/// - If type is guest: ignores all parameters except
|
||||
/// initial_device_display_name
|
||||
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
|
||||
/// - If type is not guest and no username is given: Always fails after UIAA check
|
||||
/// - If type is not guest and no username is given: Always fails after UIAA
|
||||
/// check
|
||||
/// - Creates a new account and populates it with default account data
|
||||
/// - If `inhibit_login` is false: Creates a device and returns device id and access_token
|
||||
/// - If `inhibit_login` is false: Creates a device and returns device id and
|
||||
/// access_token
|
||||
pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> {
|
||||
if !services().globals.allow_registration() && !body.from_appservice {
|
||||
info!("Registration disabled and request not from known appservice, rejecting registration attempt for username {:?}", body.username);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Registration has been disabled.",
|
||||
));
|
||||
}
|
||||
if !services().globals.allow_registration() && !body.from_appservice {
|
||||
info!(
|
||||
"Registration disabled and request not from known appservice, rejecting registration attempt for username \
|
||||
{:?}",
|
||||
body.username
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration has been disabled."));
|
||||
}
|
||||
|
||||
let is_guest = body.kind == RegistrationKind::Guest;
|
||||
let is_guest = body.kind == RegistrationKind::Guest;
|
||||
|
||||
if is_guest
|
||||
&& (!services().globals.allow_guest_registration()
|
||||
|| (services().globals.allow_registration()
|
||||
&& services().globals.config.registration_token.is_some()))
|
||||
{
|
||||
info!("Guest registration disabled / registration enabled with token configured, rejecting guest registration, initial device name: {:?}", body.initial_device_display_name);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::GuestAccessForbidden,
|
||||
"Guest registration is disabled.",
|
||||
));
|
||||
}
|
||||
if is_guest
|
||||
&& (!services().globals.allow_guest_registration()
|
||||
|| (services().globals.allow_registration() && services().globals.config.registration_token.is_some()))
|
||||
{
|
||||
info!(
|
||||
"Guest registration disabled / registration enabled with token configured, rejecting guest registration, \
|
||||
initial device name: {:?}",
|
||||
body.initial_device_display_name
|
||||
);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::GuestAccessForbidden,
|
||||
"Guest registration is disabled.",
|
||||
));
|
||||
}
|
||||
|
||||
// forbid guests from registering if there is not a real admin user yet. give generic user error.
|
||||
if is_guest && services().users.count()? < 2 {
|
||||
warn!("Guest account attempted to register before a real admin user has been registered, rejecting registration. Guest's initial device name: {:?}", body.initial_device_display_name);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Registration temporarily disabled.",
|
||||
));
|
||||
}
|
||||
// forbid guests from registering if there is not a real admin user yet. give
|
||||
// generic user error.
|
||||
if is_guest && services().users.count()? < 2 {
|
||||
warn!(
|
||||
"Guest account attempted to register before a real admin user has been registered, rejecting \
|
||||
registration. Guest's initial device name: {:?}",
|
||||
body.initial_device_display_name
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration temporarily disabled."));
|
||||
}
|
||||
|
||||
let user_id = match (&body.username, is_guest) {
|
||||
(Some(username), false) => {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
username.to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical()
|
||||
&& user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
let user_id = match (&body.username, is_guest) {
|
||||
(Some(username), false) => {
|
||||
let proposed_user_id =
|
||||
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||
|
||||
if services().users.exists(&proposed_user_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserInUse,
|
||||
"Desired user ID is already taken.",
|
||||
));
|
||||
}
|
||||
if services().users.exists(&proposed_user_id)? {
|
||||
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_usernames()
|
||||
.is_match(proposed_user_id.localpart())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Username is forbidden.",
|
||||
));
|
||||
}
|
||||
if services().globals.forbidden_usernames().is_match(proposed_user_id.localpart()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
|
||||
}
|
||||
|
||||
proposed_user_id
|
||||
}
|
||||
_ => loop {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap();
|
||||
if !services().users.exists(&proposed_user_id)? {
|
||||
break proposed_user_id;
|
||||
}
|
||||
},
|
||||
};
|
||||
proposed_user_id
|
||||
},
|
||||
_ => loop {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap();
|
||||
if !services().users.exists(&proposed_user_id)? {
|
||||
break proposed_user_id;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// UIAA
|
||||
let mut uiaainfo;
|
||||
let skip_auth;
|
||||
if services().globals.config.registration_token.is_some() {
|
||||
// Registration token required
|
||||
uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::RegistrationToken],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
skip_auth = body.from_appservice;
|
||||
} else {
|
||||
// No registration token necessary, but clients must still go through the flow
|
||||
uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Dummy],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
skip_auth = body.from_appservice || is_guest;
|
||||
}
|
||||
// UIAA
|
||||
let mut uiaainfo;
|
||||
let skip_auth;
|
||||
if services().globals.config.registration_token.is_some() {
|
||||
// Registration token required
|
||||
uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::RegistrationToken],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
skip_auth = body.from_appservice;
|
||||
} else {
|
||||
// No registration token necessary, but clients must still go through the flow
|
||||
uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Dummy],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
skip_auth = body.from_appservice || is_guest;
|
||||
}
|
||||
|
||||
if !skip_auth {
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid"),
|
||||
"".into(),
|
||||
auth,
|
||||
&uiaainfo,
|
||||
)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid"),
|
||||
"".into(),
|
||||
&uiaainfo,
|
||||
&json,
|
||||
)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
}
|
||||
if !skip_auth {
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
|
||||
"".into(),
|
||||
auth,
|
||||
&uiaainfo,
|
||||
)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
|
||||
"".into(),
|
||||
&uiaainfo,
|
||||
&json,
|
||||
)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
}
|
||||
|
||||
let password = if is_guest {
|
||||
None
|
||||
} else {
|
||||
body.password.as_deref()
|
||||
};
|
||||
let password = if is_guest {
|
||||
None
|
||||
} else {
|
||||
body.password.as_deref()
|
||||
};
|
||||
|
||||
// Create user
|
||||
services().users.create(&user_id, password)?;
|
||||
// Create user
|
||||
services().users.create(&user_id, password)?;
|
||||
|
||||
// Default to pretty displayname
|
||||
let mut displayname = user_id.localpart().to_owned();
|
||||
// Default to pretty displayname
|
||||
let mut displayname = user_id.localpart().to_owned();
|
||||
|
||||
// If `new_user_displayname_suffix` is set, registration will push whatever content is set to the user's display name with a space before it
|
||||
if !services().globals.new_user_displayname_suffix().is_empty() {
|
||||
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
|
||||
}
|
||||
// If `new_user_displayname_suffix` is set, registration will push whatever
|
||||
// content is set to the user's display name with a space before it
|
||||
if !services().globals.new_user_displayname_suffix().is_empty() {
|
||||
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&user_id, Some(displayname.clone()))
|
||||
.await?;
|
||||
services().users.set_displayname(&user_id, Some(displayname.clone())).await?;
|
||||
|
||||
// Initial account data
|
||||
services().account_data.update(
|
||||
None,
|
||||
&user_id,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
|
||||
content: ruma::events::push_rules::PushRulesEventContent {
|
||||
global: push::Ruleset::server_default(&user_id),
|
||||
},
|
||||
})
|
||||
.expect("to json always works"),
|
||||
)?;
|
||||
// Initial account data
|
||||
services().account_data.update(
|
||||
None,
|
||||
&user_id,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
|
||||
content: ruma::events::push_rules::PushRulesEventContent {
|
||||
global: push::Ruleset::server_default(&user_id),
|
||||
},
|
||||
})
|
||||
.expect("to json always works"),
|
||||
)?;
|
||||
|
||||
// Inhibit login does not work for guests
|
||||
if !is_guest && body.inhibit_login {
|
||||
return Ok(register::v3::Response {
|
||||
access_token: None,
|
||||
user_id,
|
||||
device_id: None,
|
||||
refresh_token: None,
|
||||
expires_in: None,
|
||||
});
|
||||
}
|
||||
// Inhibit login does not work for guests
|
||||
if !is_guest && body.inhibit_login {
|
||||
return Ok(register::v3::Response {
|
||||
access_token: None,
|
||||
user_id,
|
||||
device_id: None,
|
||||
refresh_token: None,
|
||||
expires_in: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Generate new device id if the user didn't specify one
|
||||
let device_id = if is_guest {
|
||||
None
|
||||
} else {
|
||||
body.device_id.clone()
|
||||
}
|
||||
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
// Generate new device id if the user didn't specify one
|
||||
let device_id = if is_guest {
|
||||
None
|
||||
} else {
|
||||
body.device_id.clone()
|
||||
}
|
||||
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
|
||||
// Generate new token for the device
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
// Generate new token for the device
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
// Create device for this account
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
body.initial_device_display_name.clone(),
|
||||
)?;
|
||||
// Create device for this account
|
||||
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
|
||||
|
||||
info!("New user \"{}\" registered on this server.", user_id);
|
||||
info!("New user \"{}\" registered on this server.", user_id);
|
||||
|
||||
// log in conduit admin channel if a non-guest user registered
|
||||
if !body.from_appservice && !is_guest {
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"New user \"{user_id}\" registered on this server."
|
||||
)));
|
||||
}
|
||||
// log in conduit admin channel if a non-guest user registered
|
||||
if !body.from_appservice && !is_guest {
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"New user \"{user_id}\" registered on this server."
|
||||
)));
|
||||
}
|
||||
|
||||
// log in conduit admin channel if a guest registered
|
||||
if !body.from_appservice && is_guest {
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
|
||||
body.initial_device_display_name
|
||||
)));
|
||||
}
|
||||
// log in conduit admin channel if a guest registered
|
||||
if !body.from_appservice && is_guest {
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
|
||||
body.initial_device_display_name
|
||||
)));
|
||||
}
|
||||
|
||||
// If this is the first real user, grant them admin privileges except for guest users
|
||||
// Note: the server user, @conduit:servername, is generated first
|
||||
if services().users.count()? == 2 && !is_guest {
|
||||
services()
|
||||
.admin
|
||||
.make_user_admin(&user_id, displayname)
|
||||
.await?;
|
||||
// If this is the first real user, grant them admin privileges except for guest
|
||||
// users Note: the server user, @conduit:servername, is generated first
|
||||
if services().users.count()? == 2 && !is_guest {
|
||||
services().admin.make_user_admin(&user_id, displayname).await?;
|
||||
|
||||
warn!("Granting {} admin privileges as the first user", user_id);
|
||||
}
|
||||
warn!("Granting {} admin privileges as the first user", user_id);
|
||||
}
|
||||
|
||||
Ok(register::v3::Response {
|
||||
access_token: Some(token),
|
||||
user_id,
|
||||
device_id: Some(device_id),
|
||||
refresh_token: None,
|
||||
expires_in: None,
|
||||
})
|
||||
Ok(register::v3::Response {
|
||||
access_token: Some(token),
|
||||
user_id,
|
||||
device_id: Some(device_id),
|
||||
refresh_token: None,
|
||||
expires_in: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/account/password`
|
||||
@@ -333,73 +296,65 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
||||
///
|
||||
/// - Requires UIAA to verify user password
|
||||
/// - Changes the password of the sender user
|
||||
/// - The password hash is calculated using argon2 with 32 character salt, the plain password is
|
||||
/// - The password hash is calculated using argon2 with 32 character salt, the
|
||||
/// plain password is
|
||||
/// not saved
|
||||
///
|
||||
/// If logout_devices is true it does the following for each device except the sender device:
|
||||
/// If logout_devices is true it does the following for each device except the
|
||||
/// sender device:
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn change_password_route(
|
||||
body: Ruma<change_password::v3::Request>,
|
||||
) -> Result<change_password::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
pub async fn change_password_route(body: Ruma<change_password::v3::Request>) -> Result<change_password::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_password(sender_user, Some(&body.new_password))?;
|
||||
services().users.set_password(sender_user, Some(&body.new_password))?;
|
||||
|
||||
if body.logout_devices {
|
||||
// Logout all devices except the current one
|
||||
for id in services()
|
||||
.users
|
||||
.all_device_ids(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.filter(|id| id != sender_device)
|
||||
{
|
||||
services().users.remove_device(sender_user, &id)?;
|
||||
}
|
||||
}
|
||||
if body.logout_devices {
|
||||
// Logout all devices except the current one
|
||||
for id in services()
|
||||
.users
|
||||
.all_device_ids(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.filter(|id| id != sender_device)
|
||||
{
|
||||
services().users.remove_device(sender_user, &id)?;
|
||||
}
|
||||
}
|
||||
|
||||
info!("User {} changed their password.", sender_user);
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {sender_user} changed their password."
|
||||
)));
|
||||
info!("User {} changed their password.", sender_user);
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {sender_user} changed their password."
|
||||
)));
|
||||
|
||||
Ok(change_password::v3::Response {})
|
||||
Ok(change_password::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET _matrix/client/r0/account/whoami`
|
||||
@@ -408,14 +363,14 @@ pub async fn change_password_route(
|
||||
///
|
||||
/// Note: Also works for Application Services
|
||||
pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let device_id = body.sender_device.clone();
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let device_id = body.sender_device.clone();
|
||||
|
||||
Ok(whoami::v3::Response {
|
||||
user_id: sender_user.clone(),
|
||||
device_id,
|
||||
is_guest: services().users.is_deactivated(sender_user)? && !body.from_appservice,
|
||||
})
|
||||
Ok(whoami::v3::Response {
|
||||
user_id: sender_user.clone(),
|
||||
device_id,
|
||||
is_guest: services().users.is_deactivated(sender_user)? && !body.from_appservice,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/account/deactivate`
|
||||
@@ -424,61 +379,53 @@ pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3:
|
||||
///
|
||||
/// - Leaves all rooms and rejects all invitations
|
||||
/// - Invalidates all access tokens
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets all to-device events
|
||||
/// - Triggers device list updates
|
||||
/// - Removes ability to log in again
|
||||
pub async fn deactivate_route(
|
||||
body: Ruma<deactivate::v3::Request>,
|
||||
) -> Result<deactivate::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
pub async fn deactivate_route(body: Ruma<deactivate::v3::Request>) -> Result<deactivate::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
// Make the user leave all rooms before deactivation
|
||||
client_server::leave_all_rooms(sender_user).await?;
|
||||
// Make the user leave all rooms before deactivation
|
||||
client_server::leave_all_rooms(sender_user).await?;
|
||||
|
||||
// Remove devices and mark account as deactivated
|
||||
services().users.deactivate_account(sender_user)?;
|
||||
// Remove devices and mark account as deactivated
|
||||
services().users.deactivate_account(sender_user)?;
|
||||
|
||||
info!("User {} deactivated their account.", sender_user);
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {sender_user} deactivated their account."
|
||||
)));
|
||||
info!("User {} deactivated their account.", sender_user);
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {sender_user} deactivated their account."
|
||||
)));
|
||||
|
||||
Ok(deactivate::v3::Response {
|
||||
id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
|
||||
})
|
||||
Ok(deactivate::v3::Response {
|
||||
id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET _matrix/client/v3/account/3pid`
|
||||
@@ -486,38 +433,40 @@ pub async fn deactivate_route(
|
||||
/// Get a list of third party identifiers associated with this account.
|
||||
///
|
||||
/// - Currently always returns empty list
|
||||
pub async fn third_party_route(
|
||||
body: Ruma<get_3pids::v3::Request>,
|
||||
) -> Result<get_3pids::v3::Response> {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn third_party_route(body: Ruma<get_3pids::v3::Request>) -> Result<get_3pids::v3::Response> {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
Ok(get_3pids::v3::Response::new(Vec::new()))
|
||||
Ok(get_3pids::v3::Response::new(Vec::new()))
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
|
||||
///
|
||||
/// "This API should be used to request validation tokens when adding an email address to an account"
|
||||
/// "This API should be used to request validation tokens when adding an email
|
||||
/// address to an account"
|
||||
///
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier
|
||||
/// as a contact option.
|
||||
pub async fn request_3pid_management_token_via_email_route(
|
||||
_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
|
||||
_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
|
||||
) -> Result<request_3pid_management_token_via_email::v3::Response> {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::ThreepidDenied,
|
||||
"Third party identifier is not allowed",
|
||||
))
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::ThreepidDenied,
|
||||
"Third party identifier is not allowed",
|
||||
))
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
|
||||
///
|
||||
/// "This API should be used to request validation tokens when adding an phone number to an account"
|
||||
/// "This API should be used to request validation tokens when adding an phone
|
||||
/// number to an account"
|
||||
///
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier
|
||||
/// as a contact option.
|
||||
pub async fn request_3pid_management_token_via_msisdn_route(
|
||||
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
|
||||
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
|
||||
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::ThreepidDenied,
|
||||
"Third party identifier is not allowed",
|
||||
))
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::ThreepidDenied,
|
||||
"Third party identifier is not allowed",
|
||||
))
|
||||
}
|
||||
|
||||
+134
-201
@@ -1,64 +1,43 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use rand::seq::SliceRandom;
|
||||
use regex::Regex;
|
||||
use ruma::{
|
||||
api::{
|
||||
appservice,
|
||||
client::{
|
||||
alias::{create_alias, delete_alias, get_alias},
|
||||
error::ErrorKind,
|
||||
},
|
||||
federation,
|
||||
},
|
||||
OwnedRoomAliasId, OwnedServerName,
|
||||
api::{
|
||||
appservice,
|
||||
client::{
|
||||
alias::{create_alias, delete_alias, get_alias},
|
||||
error::ErrorKind,
|
||||
},
|
||||
federation,
|
||||
},
|
||||
OwnedRoomAliasId, OwnedServerName,
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
|
||||
///
|
||||
/// Creates a new room alias on this server.
|
||||
pub async fn create_alias_route(
|
||||
body: Ruma<create_alias::v3::Request>,
|
||||
) -> Result<create_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
}
|
||||
pub async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_room_names()
|
||||
.is_match(body.room_alias.alias())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Room alias is forbidden.",
|
||||
));
|
||||
}
|
||||
if services().globals.forbidden_room_names().is_match(body.room_alias.alias()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&body.room_alias)?
|
||||
.is_some()
|
||||
{
|
||||
return Err(Error::Conflict("Alias already exists."));
|
||||
}
|
||||
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
|
||||
return Err(Error::Conflict("Alias already exists."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.set_alias(&body.room_alias, &body.room_id)
|
||||
.is_err()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||
));
|
||||
};
|
||||
if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||
));
|
||||
};
|
||||
|
||||
Ok(create_alias::v3::Response::new())
|
||||
Ok(create_alias::v3::Response::new())
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/v3/directory/room/{roomAlias}`
|
||||
@@ -67,183 +46,137 @@ pub async fn create_alias_route(
|
||||
///
|
||||
/// - TODO: additional access control checks
|
||||
/// - TODO: Update canonical alias event
|
||||
pub async fn delete_alias_route(
|
||||
body: Ruma<delete_alias::v3::Request>,
|
||||
) -> Result<delete_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
}
|
||||
pub async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&body.room_alias)?
|
||||
.is_none()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Alias does not exist.",
|
||||
));
|
||||
}
|
||||
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.remove_alias(&body.room_alias)
|
||||
.is_err()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||
));
|
||||
};
|
||||
if services().rooms.alias.remove_alias(&body.room_alias).is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||
));
|
||||
};
|
||||
|
||||
// TODO: update alt_aliases?
|
||||
// TODO: update alt_aliases?
|
||||
|
||||
Ok(delete_alias::v3::Response::new())
|
||||
Ok(delete_alias::v3::Response::new())
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/directory/room/{roomAlias}`
|
||||
///
|
||||
/// Resolve an alias locally or over federation.
|
||||
pub async fn get_alias_route(
|
||||
body: Ruma<get_alias::v3::Request>,
|
||||
) -> Result<get_alias::v3::Response> {
|
||||
get_alias_helper(body.body.room_alias).await
|
||||
pub async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Result<get_alias::v3::Response> {
|
||||
get_alias_helper(body.body.room_alias).await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_alias_helper(
|
||||
room_alias: OwnedRoomAliasId,
|
||||
) -> Result<get_alias::v3::Response> {
|
||||
if room_alias.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
room_alias.server_name(),
|
||||
federation::query::get_room_information::v1::Request {
|
||||
room_alias: room_alias.clone(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get_alias::v3::Response> {
|
||||
if room_alias.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
room_alias.server_name(),
|
||||
federation::query::get_room_information::v1::Request {
|
||||
room_alias: room_alias.clone(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let room_id = response.room_id;
|
||||
let room_id = response.room_id;
|
||||
|
||||
let mut servers = response.servers;
|
||||
let mut servers = response.servers;
|
||||
|
||||
// find active servers in room state cache to suggest
|
||||
for extra_servers in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_servers(&room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
servers.push(extra_servers);
|
||||
}
|
||||
// find active servers in room state cache to suggest
|
||||
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
|
||||
servers.push(extra_servers);
|
||||
}
|
||||
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) = servers
|
||||
.clone()
|
||||
.into_iter()
|
||||
.position(|server| server == services().globals.server_name())
|
||||
{
|
||||
servers.remove(server_index);
|
||||
servers.insert(0, services().globals.server_name().to_owned());
|
||||
}
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) =
|
||||
servers.clone().into_iter().position(|server| server == services().globals.server_name())
|
||||
{
|
||||
servers.remove(server_index);
|
||||
servers.insert(0, services().globals.server_name().to_owned());
|
||||
}
|
||||
|
||||
servers.sort_unstable();
|
||||
servers.dedup();
|
||||
servers.sort_unstable();
|
||||
servers.dedup();
|
||||
|
||||
// shuffle list of servers randomly after sort and dedupe
|
||||
servers.shuffle(&mut rand::thread_rng());
|
||||
// shuffle list of servers randomly after sort and dedupe
|
||||
servers.shuffle(&mut rand::thread_rng());
|
||||
|
||||
return Ok(get_alias::v3::Response::new(room_id, servers));
|
||||
}
|
||||
return Ok(get_alias::v3::Response::new(room_id, servers));
|
||||
}
|
||||
|
||||
let mut room_id = None;
|
||||
match services().rooms.alias.resolve_local_alias(&room_alias)? {
|
||||
Some(r) => room_id = Some(r),
|
||||
None => {
|
||||
for (_id, registration) in services().appservice.all()? {
|
||||
let aliases = registration
|
||||
.namespaces
|
||||
.aliases
|
||||
.iter()
|
||||
.filter_map(|alias| Regex::new(alias.regex.as_str()).ok())
|
||||
.collect::<Vec<_>>();
|
||||
let mut room_id = None;
|
||||
match services().rooms.alias.resolve_local_alias(&room_alias)? {
|
||||
Some(r) => room_id = Some(r),
|
||||
None => {
|
||||
for (_id, registration) in services().appservice.all()? {
|
||||
let aliases = registration
|
||||
.namespaces
|
||||
.aliases
|
||||
.iter()
|
||||
.filter_map(|alias| Regex::new(alias.regex.as_str()).ok())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if aliases
|
||||
.iter()
|
||||
.any(|aliases| aliases.is_match(room_alias.as_str()))
|
||||
&& if let Some(opt_result) = services()
|
||||
.sending
|
||||
.send_appservice_request(
|
||||
registration,
|
||||
appservice::query::query_room_alias::v1::Request {
|
||||
room_alias: room_alias.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
opt_result.is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
{
|
||||
room_id = Some(
|
||||
services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&room_alias)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_config("Appservice lied to us. Room does not exist.")
|
||||
})?,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
if aliases.iter().any(|aliases| aliases.is_match(room_alias.as_str()))
|
||||
&& if let Some(opt_result) = services()
|
||||
.sending
|
||||
.send_appservice_request(
|
||||
registration,
|
||||
appservice::query::query_room_alias::v1::Request {
|
||||
room_alias: room_alias.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
opt_result.is_ok()
|
||||
} else {
|
||||
false
|
||||
} {
|
||||
room_id = Some(
|
||||
services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&room_alias)?
|
||||
.ok_or_else(|| Error::bad_config("Appservice lied to us. Room does not exist."))?,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let room_id = match room_id {
|
||||
Some(room_id) => room_id,
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Room with alias not found.",
|
||||
))
|
||||
}
|
||||
};
|
||||
let room_id = match room_id {
|
||||
Some(room_id) => room_id,
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")),
|
||||
};
|
||||
|
||||
let mut servers: Vec<OwnedServerName> = Vec::new();
|
||||
let mut servers: Vec<OwnedServerName> = Vec::new();
|
||||
|
||||
// find active servers in room state cache to suggest
|
||||
for extra_servers in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_servers(&room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
servers.push(extra_servers);
|
||||
}
|
||||
// find active servers in room state cache to suggest
|
||||
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
|
||||
servers.push(extra_servers);
|
||||
}
|
||||
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) = servers
|
||||
.clone()
|
||||
.into_iter()
|
||||
.position(|server| server == services().globals.server_name())
|
||||
{
|
||||
servers.remove(server_index);
|
||||
servers.insert(0, services().globals.server_name().to_owned());
|
||||
}
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) =
|
||||
servers.clone().into_iter().position(|server| server == services().globals.server_name())
|
||||
{
|
||||
servers.remove(server_index);
|
||||
servers.insert(0, services().globals.server_name().to_owned());
|
||||
}
|
||||
|
||||
servers.sort_unstable();
|
||||
servers.dedup();
|
||||
servers.sort_unstable();
|
||||
servers.dedup();
|
||||
|
||||
// shuffle list of servers randomly after sort and dedupe
|
||||
servers.shuffle(&mut rand::thread_rng());
|
||||
// shuffle list of servers randomly after sort and dedupe
|
||||
servers.shuffle(&mut rand::thread_rng());
|
||||
|
||||
Ok(get_alias::v3::Response::new(room_id, servers))
|
||||
Ok(get_alias::v3::Response::new(room_id, servers))
|
||||
}
|
||||
|
||||
+144
-231
@@ -1,362 +1,275 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
backup::{
|
||||
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
|
||||
create_backup_version, delete_backup_keys, delete_backup_keys_for_room,
|
||||
delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys,
|
||||
get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info,
|
||||
update_backup_version,
|
||||
},
|
||||
error::ErrorKind,
|
||||
backup::{
|
||||
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
|
||||
delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
|
||||
get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
|
||||
get_latest_backup_info, update_backup_version,
|
||||
},
|
||||
error::ErrorKind,
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/room_keys/version`
|
||||
///
|
||||
/// Creates a new backup.
|
||||
pub async fn create_backup_version_route(
|
||||
body: Ruma<create_backup_version::v3::Request>,
|
||||
body: Ruma<create_backup_version::v3::Request>,
|
||||
) -> Result<create_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let version = services()
|
||||
.key_backups
|
||||
.create_backup(sender_user, &body.algorithm)?;
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let version = services().key_backups.create_backup(sender_user, &body.algorithm)?;
|
||||
|
||||
Ok(create_backup_version::v3::Response { version })
|
||||
Ok(create_backup_version::v3::Response {
|
||||
version,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/room_keys/version/{version}`
|
||||
///
|
||||
/// Update information about an existing backup. Only `auth_data` can be modified.
|
||||
/// Update information about an existing backup. Only `auth_data` can be
|
||||
/// modified.
|
||||
pub async fn update_backup_version_route(
|
||||
body: Ruma<update_backup_version::v3::Request>,
|
||||
body: Ruma<update_backup_version::v3::Request>,
|
||||
) -> Result<update_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
services()
|
||||
.key_backups
|
||||
.update_backup(sender_user, &body.version, &body.algorithm)?;
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?;
|
||||
|
||||
Ok(update_backup_version::v3::Response {})
|
||||
Ok(update_backup_version::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/version`
|
||||
///
|
||||
/// Get information about the latest backup version.
|
||||
pub async fn get_latest_backup_info_route(
|
||||
body: Ruma<get_latest_backup_info::v3::Request>,
|
||||
body: Ruma<get_latest_backup_info::v3::Request>,
|
||||
) -> Result<get_latest_backup_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let (version, algorithm) = services()
|
||||
.key_backups
|
||||
.get_latest_backup(sender_user)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Key backup does not exist.",
|
||||
))?;
|
||||
let (version, algorithm) = services()
|
||||
.key_backups
|
||||
.get_latest_backup(sender_user)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
|
||||
|
||||
Ok(get_latest_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &version)?,
|
||||
version,
|
||||
})
|
||||
Ok(get_latest_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &version)?,
|
||||
version,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/version`
|
||||
///
|
||||
/// Get information about an existing backup.
|
||||
pub async fn get_backup_info_route(
|
||||
body: Ruma<get_backup_info::v3::Request>,
|
||||
) -> Result<get_backup_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let algorithm = services()
|
||||
.key_backups
|
||||
.get_backup(sender_user, &body.version)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Key backup does not exist.",
|
||||
))?;
|
||||
pub async fn get_backup_info_route(body: Ruma<get_backup_info::v3::Request>) -> Result<get_backup_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let algorithm = services()
|
||||
.key_backups
|
||||
.get_backup(sender_user, &body.version)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
|
||||
|
||||
Ok(get_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
version: body.version.clone(),
|
||||
})
|
||||
Ok(get_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
version: body.version.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/room_keys/version/{version}`
|
||||
///
|
||||
/// Delete an existing key backup.
|
||||
///
|
||||
/// - Deletes both information about the backup, as well as all key data related to the backup
|
||||
/// - Deletes both information about the backup, as well as all key data related
|
||||
/// to the backup
|
||||
pub async fn delete_backup_version_route(
|
||||
body: Ruma<delete_backup_version::v3::Request>,
|
||||
body: Ruma<delete_backup_version::v3::Request>,
|
||||
) -> Result<delete_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.key_backups
|
||||
.delete_backup(sender_user, &body.version)?;
|
||||
services().key_backups.delete_backup(sender_user, &body.version)?;
|
||||
|
||||
Ok(delete_backup_version::v3::Response {})
|
||||
Ok(delete_backup_version::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/room_keys/keys`
|
||||
///
|
||||
/// Add the received backup keys to the database.
|
||||
///
|
||||
/// - Only manipulating the most recently created version of the backup is allowed
|
||||
/// - Only manipulating the most recently created version of the backup is
|
||||
/// allowed
|
||||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_route(
|
||||
body: Ruma<add_backup_keys::v3::Request>,
|
||||
) -> Result<add_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) -> Result<add_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
|
||||
for (room_id, room) in &body.rooms {
|
||||
for (session_id, key_data) in &room.sessions {
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
room_id,
|
||||
session_id,
|
||||
key_data,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
for (room_id, room) in &body.rooms {
|
||||
for (session_id, key_data) in &room.sessions {
|
||||
services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(add_backup_keys::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
Ok(add_backup_keys::v3::Response {
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}`
|
||||
///
|
||||
/// Add the received backup keys to the database.
|
||||
///
|
||||
/// - Only manipulating the most recently created version of the backup is allowed
|
||||
/// - Only manipulating the most recently created version of the backup is
|
||||
/// allowed
|
||||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_for_room_route(
|
||||
body: Ruma<add_backup_keys_for_room::v3::Request>,
|
||||
body: Ruma<add_backup_keys_for_room::v3::Request>,
|
||||
) -> Result<add_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
|
||||
for (session_id, key_data) in &body.sessions {
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
session_id,
|
||||
key_data,
|
||||
)?;
|
||||
}
|
||||
for (session_id, key_data) in &body.sessions {
|
||||
services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
|
||||
}
|
||||
|
||||
Ok(add_backup_keys_for_room::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
Ok(add_backup_keys_for_room::v3::Response {
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
|
||||
///
|
||||
/// Add the received backup key to the database.
|
||||
///
|
||||
/// - Only manipulating the most recently created version of the backup is allowed
|
||||
/// - Only manipulating the most recently created version of the backup is
|
||||
/// allowed
|
||||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_for_session_route(
|
||||
body: Ruma<add_backup_keys_for_session::v3::Request>,
|
||||
body: Ruma<add_backup_keys_for_session::v3::Request>,
|
||||
) -> Result<add_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
&body.session_id,
|
||||
&body.session_data,
|
||||
)?;
|
||||
services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
|
||||
|
||||
Ok(add_backup_keys_for_session::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
Ok(add_backup_keys_for_session::v3::Response {
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/keys`
|
||||
///
|
||||
/// Retrieves all keys from the backup.
|
||||
pub async fn get_backup_keys_route(
|
||||
body: Ruma<get_backup_keys::v3::Request>,
|
||||
) -> Result<get_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_backup_keys_route(body: Ruma<get_backup_keys::v3::Request>) -> Result<get_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
|
||||
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
|
||||
|
||||
Ok(get_backup_keys::v3::Response { rooms })
|
||||
Ok(get_backup_keys::v3::Response {
|
||||
rooms,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
|
||||
///
|
||||
/// Retrieves all keys from the backup for a given room.
|
||||
pub async fn get_backup_keys_for_room_route(
|
||||
body: Ruma<get_backup_keys_for_room::v3::Request>,
|
||||
body: Ruma<get_backup_keys_for_room::v3::Request>,
|
||||
) -> Result<get_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let sessions = services()
|
||||
.key_backups
|
||||
.get_room(sender_user, &body.version, &body.room_id)?;
|
||||
let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?;
|
||||
|
||||
Ok(get_backup_keys_for_room::v3::Response { sessions })
|
||||
Ok(get_backup_keys_for_room::v3::Response {
|
||||
sessions,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
|
||||
///
|
||||
/// Retrieves a key from the backup.
|
||||
pub async fn get_backup_keys_for_session_route(
|
||||
body: Ruma<get_backup_keys_for_session::v3::Request>,
|
||||
body: Ruma<get_backup_keys_for_session::v3::Request>,
|
||||
) -> Result<get_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let key_data = services()
|
||||
.key_backups
|
||||
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Backup key not found for this user's session.",
|
||||
))?;
|
||||
let key_data =
|
||||
services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or(
|
||||
Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."),
|
||||
)?;
|
||||
|
||||
Ok(get_backup_keys_for_session::v3::Response { key_data })
|
||||
Ok(get_backup_keys_for_session::v3::Response {
|
||||
key_data,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/room_keys/keys`
|
||||
///
|
||||
/// Delete the keys from the backup.
|
||||
pub async fn delete_backup_keys_route(
|
||||
body: Ruma<delete_backup_keys::v3::Request>,
|
||||
body: Ruma<delete_backup_keys::v3::Request>,
|
||||
) -> Result<delete_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.key_backups
|
||||
.delete_all_keys(sender_user, &body.version)?;
|
||||
services().key_backups.delete_all_keys(sender_user, &body.version)?;
|
||||
|
||||
Ok(delete_backup_keys::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
Ok(delete_backup_keys::v3::Response {
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}`
|
||||
///
|
||||
/// Delete the keys from the backup for a given room.
|
||||
pub async fn delete_backup_keys_for_room_route(
|
||||
body: Ruma<delete_backup_keys_for_room::v3::Request>,
|
||||
body: Ruma<delete_backup_keys_for_room::v3::Request>,
|
||||
) -> Result<delete_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.key_backups
|
||||
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
||||
services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
||||
|
||||
Ok(delete_backup_keys_for_room::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
Ok(delete_backup_keys_for_room::v3::Response {
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
|
||||
///
|
||||
/// Delete a key from the backup.
|
||||
pub async fn delete_backup_keys_for_session_route(
|
||||
body: Ruma<delete_backup_keys_for_session::v3::Request>,
|
||||
body: Ruma<delete_backup_keys_for_session::v3::Request>,
|
||||
) -> Result<delete_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services().key_backups.delete_room_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
&body.session_id,
|
||||
)?;
|
||||
services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
|
||||
|
||||
Ok(delete_backup_keys_for_session::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
Ok(delete_backup_keys_for_session::v3::Response {
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,28 +1,33 @@
|
||||
use crate::{services, Result, Ruma};
|
||||
use ruma::api::client::discovery::get_capabilities::{
|
||||
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::api::client::discovery::get_capabilities::{
|
||||
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
|
||||
};
|
||||
|
||||
use crate::{services, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/capabilities`
|
||||
///
|
||||
/// Get information on the supported feature set and other relevent capabilities of this server.
|
||||
/// Get information on the supported feature set and other relevent capabilities
|
||||
/// of this server.
|
||||
pub async fn get_capabilities_route(
|
||||
_body: Ruma<get_capabilities::v3::Request>,
|
||||
_body: Ruma<get_capabilities::v3::Request>,
|
||||
) -> Result<get_capabilities::v3::Response> {
|
||||
let mut available = BTreeMap::new();
|
||||
for room_version in &services().globals.unstable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Unstable);
|
||||
}
|
||||
for room_version in &services().globals.stable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Stable);
|
||||
}
|
||||
let mut available = BTreeMap::new();
|
||||
for room_version in &services().globals.unstable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Unstable);
|
||||
}
|
||||
for room_version in &services().globals.stable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Stable);
|
||||
}
|
||||
|
||||
let mut capabilities = Capabilities::new();
|
||||
capabilities.room_versions = RoomVersionsCapability {
|
||||
default: services().globals.default_room_version(),
|
||||
available,
|
||||
};
|
||||
let mut capabilities = Capabilities::new();
|
||||
capabilities.room_versions = RoomVersionsCapability {
|
||||
default: services().globals.default_room_version(),
|
||||
available,
|
||||
};
|
||||
|
||||
Ok(get_capabilities::v3::Response { capabilities })
|
||||
Ok(get_capabilities::v3::Response {
|
||||
capabilities,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,116 +1,118 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
config::{
|
||||
get_global_account_data, get_room_account_data, set_global_account_data,
|
||||
set_room_account_data,
|
||||
},
|
||||
error::ErrorKind,
|
||||
},
|
||||
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
|
||||
serde::Raw,
|
||||
api::client::{
|
||||
config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
|
||||
error::ErrorKind,
|
||||
},
|
||||
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
|
||||
serde::Raw,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, value::RawValue as RawJsonValue};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
|
||||
///
|
||||
/// Sets some account data for the sender user.
|
||||
pub async fn set_global_account_data_route(
|
||||
body: Ruma<set_global_account_data::v3::Request>,
|
||||
body: Ruma<set_global_account_data::v3::Request>,
|
||||
) -> Result<set_global_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let data: serde_json::Value = serde_json::from_str(body.data.json().get())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
|
||||
let data: serde_json::Value = serde_json::from_str(body.data.json().get())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
|
||||
|
||||
let event_type = body.event_type.to_string();
|
||||
let event_type = body.event_type.to_string();
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
event_type.clone().into(),
|
||||
&json!({
|
||||
"type": event_type,
|
||||
"content": data,
|
||||
}),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
event_type.clone().into(),
|
||||
&json!({
|
||||
"type": event_type,
|
||||
"content": data,
|
||||
}),
|
||||
)?;
|
||||
|
||||
Ok(set_global_account_data::v3::Response {})
|
||||
Ok(set_global_account_data::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
|
||||
///
|
||||
/// Sets some room account data for the sender user.
|
||||
pub async fn set_room_account_data_route(
|
||||
body: Ruma<set_room_account_data::v3::Request>,
|
||||
body: Ruma<set_room_account_data::v3::Request>,
|
||||
) -> Result<set_room_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let data: serde_json::Value = serde_json::from_str(body.data.json().get())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
|
||||
let data: serde_json::Value = serde_json::from_str(body.data.json().get())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
|
||||
|
||||
let event_type = body.event_type.to_string();
|
||||
let event_type = body.event_type.to_string();
|
||||
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
event_type.clone().into(),
|
||||
&json!({
|
||||
"type": event_type,
|
||||
"content": data,
|
||||
}),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
event_type.clone().into(),
|
||||
&json!({
|
||||
"type": event_type,
|
||||
"content": data,
|
||||
}),
|
||||
)?;
|
||||
|
||||
Ok(set_room_account_data::v3::Response {})
|
||||
Ok(set_room_account_data::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/user/{userId}/account_data/{type}`
|
||||
///
|
||||
/// Gets some account data for the sender user.
|
||||
pub async fn get_global_account_data_route(
|
||||
body: Ruma<get_global_account_data::v3::Request>,
|
||||
body: Ruma<get_global_account_data::v3::Request>,
|
||||
) -> Result<get_global_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event: Box<RawJsonValue> = services()
|
||||
.account_data
|
||||
.get(None, sender_user, body.event_type.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
|
||||
let event: Box<RawJsonValue> = services()
|
||||
.account_data
|
||||
.get(None, sender_user, body.event_type.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
Ok(get_global_account_data::v3::Response { account_data })
|
||||
Ok(get_global_account_data::v3::Response {
|
||||
account_data,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
|
||||
///
|
||||
/// Gets some room account data for the sender user.
|
||||
pub async fn get_room_account_data_route(
|
||||
body: Ruma<get_room_account_data::v3::Request>,
|
||||
body: Ruma<get_room_account_data::v3::Request>,
|
||||
) -> Result<get_room_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event: Box<RawJsonValue> = services()
|
||||
.account_data
|
||||
.get(Some(&body.room_id), sender_user, body.event_type.clone())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
|
||||
let event: Box<RawJsonValue> = services()
|
||||
.account_data
|
||||
.get(Some(&body.room_id), sender_user, body.event_type.clone())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
Ok(get_room_account_data::v3::Response { account_data })
|
||||
Ok(get_room_account_data::v3::Response {
|
||||
account_data,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExtractRoomEventContent {
|
||||
content: Raw<AnyRoomAccountDataEventContent>,
|
||||
content: Raw<AnyRoomAccountDataEventContent>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExtractGlobalEventContent {
|
||||
content: Raw<AnyGlobalAccountDataEventContent>,
|
||||
content: Raw<AnyGlobalAccountDataEventContent>,
|
||||
}
|
||||
|
||||
+144
-176
@@ -1,209 +1,177 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
|
||||
events::StateEventType,
|
||||
};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use ruma::{
|
||||
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
|
||||
events::StateEventType,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/context`
|
||||
///
|
||||
/// Allows loading room history around an event.
|
||||
///
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events
|
||||
/// if the user was
|
||||
/// joined, depending on history_visibility)
|
||||
pub async fn get_context_route(
|
||||
body: Ruma<get_context::v3::Request>,
|
||||
) -> Result<get_context::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<get_context::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options {
|
||||
LazyLoadOptions::Enabled {
|
||||
include_redundant_members,
|
||||
} => (true, *include_redundant_members),
|
||||
LazyLoadOptions::Disabled => (false, false),
|
||||
};
|
||||
let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options {
|
||||
LazyLoadOptions::Enabled {
|
||||
include_redundant_members,
|
||||
} => (true, *include_redundant_members),
|
||||
LazyLoadOptions::Disabled => (false, false),
|
||||
};
|
||||
|
||||
let mut lazy_loaded = HashSet::new();
|
||||
let mut lazy_loaded = HashSet::new();
|
||||
|
||||
let base_token = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Base event id not found.",
|
||||
))?;
|
||||
let base_token = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?;
|
||||
|
||||
let base_event =
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Base event not found.",
|
||||
))?;
|
||||
let base_event = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?;
|
||||
|
||||
let room_id = base_event.room_id.clone();
|
||||
let room_id = base_event.room_id.clone();
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &room_id, &body.event_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this event.",
|
||||
));
|
||||
}
|
||||
if !services().rooms.state_accessor.user_can_see_event(sender_user, &room_id, &body.event_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this event.",
|
||||
));
|
||||
}
|
||||
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
&base_event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
lazy_loaded.insert(base_event.sender.as_str().to_owned());
|
||||
}
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
&base_event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
lazy_loaded.insert(base_event.sender.as_str().to_owned());
|
||||
}
|
||||
|
||||
// Use limit with maximum 100
|
||||
let limit = u64::from(body.limit).min(100) as usize;
|
||||
// Use limit with maximum 100
|
||||
let limit = u64::from(body.limit).min(100) as usize;
|
||||
|
||||
let base_event = base_event.to_room_event();
|
||||
let base_event = base_event.to_room_event();
|
||||
|
||||
let events_before: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(sender_user, &room_id, base_token)?
|
||||
.take(limit / 2)
|
||||
.filter_map(std::result::Result::ok) // Remove buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
let events_before: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(sender_user, &room_id, base_token)?
|
||||
.take(limit / 2)
|
||||
.filter_map(std::result::Result::ok) // Remove buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (_, event) in &events_before {
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
&event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
lazy_loaded.insert(event.sender.as_str().to_owned());
|
||||
}
|
||||
}
|
||||
for (_, event) in &events_before {
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
&event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
lazy_loaded.insert(event.sender.as_str().to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let start_token = events_before
|
||||
.last()
|
||||
.map(|(count, _)| count.stringify())
|
||||
.unwrap_or_else(|| base_token.stringify());
|
||||
let start_token =
|
||||
events_before.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
|
||||
|
||||
let events_before: Vec<_> = events_before
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_after(sender_user, &room_id, base_token)?
|
||||
.take(limit / 2)
|
||||
.filter_map(std::result::Result::ok) // Remove buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_after(sender_user, &room_id, base_token)?
|
||||
.take(limit / 2)
|
||||
.filter_map(std::result::Result::ok) // Remove buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (_, event) in &events_after {
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
&event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
lazy_loaded.insert(event.sender.as_str().to_owned());
|
||||
}
|
||||
}
|
||||
for (_, event) in &events_after {
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
&event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
lazy_loaded.insert(event.sender.as_str().to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(
|
||||
events_after
|
||||
.last()
|
||||
.map_or(&*body.event_id, |(_, e)| &*e.event_id),
|
||||
)? {
|
||||
Some(s) => s,
|
||||
None => services()
|
||||
.rooms
|
||||
.state
|
||||
.get_room_shortstatehash(&room_id)?
|
||||
.expect("All rooms have state"),
|
||||
};
|
||||
let shortstatehash = match services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.pdu_shortstatehash(events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id))?
|
||||
{
|
||||
Some(s) => s,
|
||||
None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"),
|
||||
};
|
||||
|
||||
let state_ids = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_full_ids(shortstatehash)
|
||||
.await?;
|
||||
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
|
||||
|
||||
let end_token = events_after
|
||||
.last()
|
||||
.map(|(count, _)| count.stringify())
|
||||
.unwrap_or_else(|| base_token.stringify());
|
||||
let end_token = events_after.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
|
||||
|
||||
let events_after: Vec<_> = events_after
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||
|
||||
let mut state = Vec::new();
|
||||
let mut state = Vec::new();
|
||||
|
||||
for (shortstatekey, id) in state_ids {
|
||||
let (event_type, state_key) = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_statekey_from_short(shortstatekey)?;
|
||||
for (shortstatekey, id) in state_ids {
|
||||
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?;
|
||||
|
||||
if event_type != StateEventType::RoomMember {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
state.push(pdu.to_state_event());
|
||||
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
state.push(pdu.to_state_event());
|
||||
}
|
||||
}
|
||||
if event_type != StateEventType::RoomMember {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
state.push(pdu.to_state_event());
|
||||
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
state.push(pdu.to_state_event());
|
||||
}
|
||||
}
|
||||
|
||||
let resp = get_context::v3::Response {
|
||||
start: Some(start_token),
|
||||
end: Some(end_token),
|
||||
events_before,
|
||||
event: Some(base_event),
|
||||
events_after,
|
||||
state,
|
||||
};
|
||||
let resp = get_context::v3::Response {
|
||||
start: Some(start_token),
|
||||
end: Some(end_token),
|
||||
events_before,
|
||||
event: Some(base_event),
|
||||
events_after,
|
||||
state,
|
||||
};
|
||||
|
||||
Ok(resp)
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
+94
-112
@@ -1,65 +1,61 @@
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
|
||||
error::ErrorKind,
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
|
||||
error::ErrorKind,
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
};
|
||||
|
||||
use super::SESSION_ID_LENGTH;
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/devices`
|
||||
///
|
||||
/// Get metadata on all devices of the sender user.
|
||||
pub async fn get_devices_route(
|
||||
body: Ruma<get_devices::v3::Request>,
|
||||
) -> Result<get_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> Result<get_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let devices: Vec<device::Device> = services()
|
||||
.users
|
||||
.all_devices_metadata(sender_user)
|
||||
.filter_map(std::result::Result::ok) // Filter out buggy devices
|
||||
.collect();
|
||||
let devices: Vec<device::Device> = services()
|
||||
.users
|
||||
.all_devices_metadata(sender_user)
|
||||
.filter_map(std::result::Result::ok) // Filter out buggy devices
|
||||
.collect();
|
||||
|
||||
Ok(get_devices::v3::Response { devices })
|
||||
Ok(get_devices::v3::Response {
|
||||
devices,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/devices/{deviceId}`
|
||||
///
|
||||
/// Get metadata on a single device of the sender user.
|
||||
pub async fn get_device_route(
|
||||
body: Ruma<get_device::v3::Request>,
|
||||
) -> Result<get_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Result<get_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let device = services()
|
||||
.users
|
||||
.get_device_metadata(sender_user, &body.body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
let device = services()
|
||||
.users
|
||||
.get_device_metadata(sender_user, &body.body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
|
||||
Ok(get_device::v3::Response { device })
|
||||
Ok(get_device::v3::Response {
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/devices/{deviceId}`
|
||||
///
|
||||
/// Updates the metadata on a given device of the sender user.
|
||||
pub async fn update_device_route(
|
||||
body: Ruma<update_device::v3::Request>,
|
||||
) -> Result<update_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Result<update_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut device = services()
|
||||
.users
|
||||
.get_device_metadata(sender_user, &body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
let mut device = services()
|
||||
.users
|
||||
.get_device_metadata(sender_user, &body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
|
||||
device.display_name = body.display_name.clone();
|
||||
device.display_name = body.display_name.clone();
|
||||
|
||||
services()
|
||||
.users
|
||||
.update_device_metadata(sender_user, &body.device_id, &device)?;
|
||||
services().users.update_device_metadata(sender_user, &body.device_id, &device)?;
|
||||
|
||||
Ok(update_device::v3::Response {})
|
||||
Ok(update_device::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/devices/{deviceId}`
|
||||
@@ -68,50 +64,42 @@ pub async fn update_device_route(
|
||||
///
|
||||
/// - Requires UIAA to verify user password
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn delete_device_route(
|
||||
body: Ruma<delete_device::v3::Request>,
|
||||
) -> Result<delete_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
pub async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Result<delete_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.remove_device(sender_user, &body.device_id)?;
|
||||
services().users.remove_device(sender_user, &body.device_id)?;
|
||||
|
||||
Ok(delete_device::v3::Response {})
|
||||
Ok(delete_device::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/devices/{deviceId}`
|
||||
@@ -122,48 +110,42 @@ pub async fn delete_device_route(
|
||||
///
|
||||
/// For each device:
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn delete_devices_route(
|
||||
body: Ruma<delete_devices::v3::Request>,
|
||||
) -> Result<delete_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
pub async fn delete_devices_route(body: Ruma<delete_devices::v3::Request>) -> Result<delete_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
for device_id in &body.devices {
|
||||
services().users.remove_device(sender_user, device_id)?;
|
||||
}
|
||||
for device_id in &body.devices {
|
||||
services().users.remove_device(sender_user, device_id)?;
|
||||
}
|
||||
|
||||
Ok(delete_devices::v3::Response {})
|
||||
Ok(delete_devices::v3::Response {})
|
||||
}
|
||||
|
||||
+279
-322
@@ -1,57 +1,51 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
directory::{
|
||||
get_public_rooms, get_public_rooms_filtered, get_room_visibility,
|
||||
set_room_visibility,
|
||||
},
|
||||
error::ErrorKind,
|
||||
room,
|
||||
},
|
||||
federation,
|
||||
},
|
||||
directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork},
|
||||
events::{
|
||||
room::{
|
||||
avatar::RoomAvatarEventContent,
|
||||
canonical_alias::RoomCanonicalAliasEventContent,
|
||||
create::RoomCreateEventContent,
|
||||
guest_access::{GuestAccess, RoomGuestAccessEventContent},
|
||||
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
|
||||
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
topic::RoomTopicEventContent,
|
||||
},
|
||||
StateEventType,
|
||||
},
|
||||
ServerName, UInt,
|
||||
api::{
|
||||
client::{
|
||||
directory::{get_public_rooms, get_public_rooms_filtered, get_room_visibility, set_room_visibility},
|
||||
error::ErrorKind,
|
||||
room,
|
||||
},
|
||||
federation,
|
||||
},
|
||||
directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork},
|
||||
events::{
|
||||
room::{
|
||||
avatar::RoomAvatarEventContent,
|
||||
canonical_alias::RoomCanonicalAliasEventContent,
|
||||
create::RoomCreateEventContent,
|
||||
guest_access::{GuestAccess, RoomGuestAccessEventContent},
|
||||
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
|
||||
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
topic::RoomTopicEventContent,
|
||||
},
|
||||
StateEventType,
|
||||
},
|
||||
ServerName, UInt,
|
||||
};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/v3/publicRooms`
|
||||
///
|
||||
/// Lists the public rooms on this server.
|
||||
///
|
||||
/// - Rooms are ordered by the number of joined members
|
||||
pub async fn get_public_rooms_filtered_route(
|
||||
body: Ruma<get_public_rooms_filtered::v3::Request>,
|
||||
body: Ruma<get_public_rooms_filtered::v3::Request>,
|
||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||
if !services()
|
||||
.globals
|
||||
.config
|
||||
.allow_public_room_directory_without_auth
|
||||
{
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
}
|
||||
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
}
|
||||
|
||||
get_public_rooms_filtered_helper(
|
||||
body.server.as_deref(),
|
||||
body.limit,
|
||||
body.since.as_deref(),
|
||||
&body.filter,
|
||||
&body.room_network,
|
||||
)
|
||||
.await
|
||||
get_public_rooms_filtered_helper(
|
||||
body.server.as_deref(),
|
||||
body.limit,
|
||||
body.since.as_deref(),
|
||||
&body.filter,
|
||||
&body.room_network,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/publicRooms`
|
||||
@@ -60,31 +54,27 @@ pub async fn get_public_rooms_filtered_route(
|
||||
///
|
||||
/// - Rooms are ordered by the number of joined members
|
||||
pub async fn get_public_rooms_route(
|
||||
body: Ruma<get_public_rooms::v3::Request>,
|
||||
body: Ruma<get_public_rooms::v3::Request>,
|
||||
) -> Result<get_public_rooms::v3::Response> {
|
||||
if !services()
|
||||
.globals
|
||||
.config
|
||||
.allow_public_room_directory_without_auth
|
||||
{
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
}
|
||||
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
}
|
||||
|
||||
let response = get_public_rooms_filtered_helper(
|
||||
body.server.as_deref(),
|
||||
body.limit,
|
||||
body.since.as_deref(),
|
||||
&Filter::default(),
|
||||
&RoomNetwork::Matrix,
|
||||
)
|
||||
.await?;
|
||||
let response = get_public_rooms_filtered_helper(
|
||||
body.server.as_deref(),
|
||||
body.limit,
|
||||
body.since.as_deref(),
|
||||
&Filter::default(),
|
||||
&RoomNetwork::Matrix,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(get_public_rooms::v3::Response {
|
||||
chunk: response.chunk,
|
||||
prev_batch: response.prev_batch,
|
||||
next_batch: response.next_batch,
|
||||
total_room_count_estimate: response.total_room_count_estimate,
|
||||
})
|
||||
Ok(get_public_rooms::v3::Response {
|
||||
chunk: response.chunk,
|
||||
prev_batch: response.prev_batch,
|
||||
next_batch: response.next_batch,
|
||||
total_room_count_estimate: response.total_room_count_estimate,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/directory/list/room/{roomId}`
|
||||
@@ -93,294 +83,261 @@ pub async fn get_public_rooms_route(
|
||||
///
|
||||
/// - TODO: Access control checks
|
||||
pub async fn set_room_visibility_route(
|
||||
body: Ruma<set_room_visibility::v3::Request>,
|
||||
body: Ruma<set_room_visibility::v3::Request>,
|
||||
) -> Result<set_room_visibility::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services().rooms.metadata.exists(&body.room_id)? {
|
||||
// Return 404 if the room doesn't exist
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
|
||||
}
|
||||
if !services().rooms.metadata.exists(&body.room_id)? {
|
||||
// Return 404 if the room doesn't exist
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
|
||||
}
|
||||
|
||||
match &body.visibility {
|
||||
room::Visibility::Public => {
|
||||
services().rooms.directory.set_public(&body.room_id)?;
|
||||
info!("{} made {} public", sender_user, body.room_id);
|
||||
}
|
||||
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Room visibility type is not supported.",
|
||||
));
|
||||
}
|
||||
}
|
||||
match &body.visibility {
|
||||
room::Visibility::Public => {
|
||||
services().rooms.directory.set_public(&body.room_id)?;
|
||||
info!("{} made {} public", sender_user, body.room_id);
|
||||
},
|
||||
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Room visibility type is not supported.",
|
||||
));
|
||||
},
|
||||
}
|
||||
|
||||
Ok(set_room_visibility::v3::Response {})
|
||||
Ok(set_room_visibility::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/directory/list/room/{roomId}`
|
||||
///
|
||||
/// Gets the visibility of a given room in the room directory.
|
||||
pub async fn get_room_visibility_route(
|
||||
body: Ruma<get_room_visibility::v3::Request>,
|
||||
body: Ruma<get_room_visibility::v3::Request>,
|
||||
) -> Result<get_room_visibility::v3::Response> {
|
||||
if !services().rooms.metadata.exists(&body.room_id)? {
|
||||
// Return 404 if the room doesn't exist
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
|
||||
}
|
||||
if !services().rooms.metadata.exists(&body.room_id)? {
|
||||
// Return 404 if the room doesn't exist
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
|
||||
}
|
||||
|
||||
Ok(get_room_visibility::v3::Response {
|
||||
visibility: if services().rooms.directory.is_public_room(&body.room_id)? {
|
||||
room::Visibility::Public
|
||||
} else {
|
||||
room::Visibility::Private
|
||||
},
|
||||
})
|
||||
Ok(get_room_visibility::v3::Response {
|
||||
visibility: if services().rooms.directory.is_public_room(&body.room_id)? {
|
||||
room::Visibility::Public
|
||||
} else {
|
||||
room::Visibility::Private
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn get_public_rooms_filtered_helper(
|
||||
server: Option<&ServerName>,
|
||||
limit: Option<UInt>,
|
||||
since: Option<&str>,
|
||||
filter: &Filter,
|
||||
_network: &RoomNetwork,
|
||||
server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork,
|
||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||
if let Some(other_server) =
|
||||
server.filter(|server| *server != services().globals.server_name().as_str())
|
||||
{
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
other_server,
|
||||
federation::directory::get_public_rooms_filtered::v1::Request {
|
||||
limit,
|
||||
since: since.map(ToOwned::to_owned),
|
||||
filter: Filter {
|
||||
generic_search_term: filter.generic_search_term.clone(),
|
||||
room_types: filter.room_types.clone(),
|
||||
},
|
||||
room_network: RoomNetwork::Matrix,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
other_server,
|
||||
federation::directory::get_public_rooms_filtered::v1::Request {
|
||||
limit,
|
||||
since: since.map(ToOwned::to_owned),
|
||||
filter: Filter {
|
||||
generic_search_term: filter.generic_search_term.clone(),
|
||||
room_types: filter.room_types.clone(),
|
||||
},
|
||||
room_network: RoomNetwork::Matrix,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok(get_public_rooms_filtered::v3::Response {
|
||||
chunk: response.chunk,
|
||||
prev_batch: response.prev_batch,
|
||||
next_batch: response.next_batch,
|
||||
total_room_count_estimate: response.total_room_count_estimate,
|
||||
});
|
||||
}
|
||||
return Ok(get_public_rooms_filtered::v3::Response {
|
||||
chunk: response.chunk,
|
||||
prev_batch: response.prev_batch,
|
||||
next_batch: response.next_batch,
|
||||
total_room_count_estimate: response.total_room_count_estimate,
|
||||
});
|
||||
}
|
||||
|
||||
let limit = limit.map_or(10, u64::from);
|
||||
let mut num_since = 0_u64;
|
||||
let limit = limit.map_or(10, u64::from);
|
||||
let mut num_since = 0_u64;
|
||||
|
||||
if let Some(s) = &since {
|
||||
let mut characters = s.chars();
|
||||
let backwards = match characters.next() {
|
||||
Some('n') => false,
|
||||
Some('p') => true,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid `since` token",
|
||||
))
|
||||
}
|
||||
};
|
||||
if let Some(s) = &since {
|
||||
let mut characters = s.chars();
|
||||
let backwards = match characters.next() {
|
||||
Some('n') => false,
|
||||
Some('p') => true,
|
||||
_ => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token")),
|
||||
};
|
||||
|
||||
num_since = characters
|
||||
.collect::<String>()
|
||||
.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?;
|
||||
num_since = characters
|
||||
.collect::<String>()
|
||||
.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?;
|
||||
|
||||
if backwards {
|
||||
num_since = num_since.saturating_sub(limit);
|
||||
}
|
||||
}
|
||||
if backwards {
|
||||
num_since = num_since.saturating_sub(limit);
|
||||
}
|
||||
}
|
||||
|
||||
let mut all_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.directory
|
||||
.public_rooms()
|
||||
.map(|room_id| {
|
||||
let room_id = room_id?;
|
||||
let mut all_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.directory
|
||||
.public_rooms()
|
||||
.map(|room_id| {
|
||||
let room_id = room_id?;
|
||||
|
||||
let chunk = PublicRoomsChunk {
|
||||
canonical_alias: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomCanonicalAliasEventContent| c.alias)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid canonical alias event in database.")
|
||||
})
|
||||
})?,
|
||||
name: services().rooms.state_accessor.get_name(&room_id)?,
|
||||
num_joined_members: services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_joined_count(&room_id)?
|
||||
.unwrap_or_else(|| {
|
||||
warn!("Room {} has no member count", room_id);
|
||||
0
|
||||
})
|
||||
.try_into()
|
||||
.expect("user count should not be that big"),
|
||||
topic: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomTopic, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomTopicEventContent| Some(c.topic))
|
||||
.map_err(|_| {
|
||||
error!("Invalid room topic event in database for room {}", room_id);
|
||||
Error::bad_database("Invalid room topic event in database.")
|
||||
})
|
||||
})
|
||||
.unwrap_or(None),
|
||||
world_readable: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomHistoryVisibilityEventContent| {
|
||||
c.history_visibility == HistoryVisibility::WorldReadable
|
||||
})
|
||||
.map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid room history visibility event in database.",
|
||||
)
|
||||
})
|
||||
})?,
|
||||
guest_can_join: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomGuestAccessEventContent| {
|
||||
c.guest_access == GuestAccess::CanJoin
|
||||
})
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room guest access event in database.")
|
||||
})
|
||||
})?,
|
||||
avatar_url: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomAvatarEventContent| c.url)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room avatar event in database.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
// url is now an Option<String> so we must flatten
|
||||
.flatten(),
|
||||
join_rule: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomJoinRulesEventContent| match c.join_rule {
|
||||
JoinRule::Public => Some(PublicRoomJoinRule::Public),
|
||||
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
|
||||
_ => None,
|
||||
})
|
||||
.map_err(|e| {
|
||||
error!("Invalid room join rule event in database: {}", e);
|
||||
Error::BadDatabase("Invalid room join rule event in database.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.flatten()
|
||||
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
|
||||
room_type: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(
|
||||
|e| {
|
||||
error!("Invalid room create event in database: {}", e);
|
||||
Error::BadDatabase("Invalid room create event in database.")
|
||||
},
|
||||
)
|
||||
})
|
||||
.transpose()?
|
||||
.and_then(|e| e.room_type),
|
||||
room_id,
|
||||
};
|
||||
Ok(chunk)
|
||||
})
|
||||
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
|
||||
.filter(|chunk| {
|
||||
if let Some(query) = filter
|
||||
.generic_search_term
|
||||
.as_ref()
|
||||
.map(|q| q.to_lowercase())
|
||||
{
|
||||
if let Some(name) = &chunk.name {
|
||||
if name.as_str().to_lowercase().contains(&query) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
let chunk = PublicRoomsChunk {
|
||||
canonical_alias: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomCanonicalAliasEventContent| c.alias)
|
||||
.map_err(|_| Error::bad_database("Invalid canonical alias event in database."))
|
||||
})?,
|
||||
name: services().rooms.state_accessor.get_name(&room_id)?,
|
||||
num_joined_members: services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_joined_count(&room_id)?
|
||||
.unwrap_or_else(|| {
|
||||
warn!("Room {} has no member count", room_id);
|
||||
0
|
||||
})
|
||||
.try_into()
|
||||
.expect("user count should not be that big"),
|
||||
topic: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomTopic, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomTopicEventContent| Some(c.topic))
|
||||
.map_err(|_| {
|
||||
error!("Invalid room topic event in database for room {}", room_id);
|
||||
Error::bad_database("Invalid room topic event in database.")
|
||||
})
|
||||
})
|
||||
.unwrap_or(None),
|
||||
world_readable: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomHistoryVisibilityEventContent| {
|
||||
c.history_visibility == HistoryVisibility::WorldReadable
|
||||
})
|
||||
.map_err(|_| Error::bad_database("Invalid room history visibility event in database."))
|
||||
})?,
|
||||
guest_can_join: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
|
||||
.map_err(|_| Error::bad_database("Invalid room guest access event in database."))
|
||||
})?,
|
||||
avatar_url: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomAvatarEventContent| c.url)
|
||||
.map_err(|_| Error::bad_database("Invalid room avatar event in database."))
|
||||
})
|
||||
.transpose()?
|
||||
// url is now an Option<String> so we must flatten
|
||||
.flatten(),
|
||||
join_rule: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomJoinRulesEventContent| match c.join_rule {
|
||||
JoinRule::Public => Some(PublicRoomJoinRule::Public),
|
||||
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
|
||||
_ => None,
|
||||
})
|
||||
.map_err(|e| {
|
||||
error!("Invalid room join rule event in database: {}", e);
|
||||
Error::BadDatabase("Invalid room join rule event in database.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.flatten()
|
||||
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
|
||||
room_type: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(|e| {
|
||||
error!("Invalid room create event in database: {}", e);
|
||||
Error::BadDatabase("Invalid room create event in database.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.and_then(|e| e.room_type),
|
||||
room_id,
|
||||
};
|
||||
Ok(chunk)
|
||||
})
|
||||
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
|
||||
.filter(|chunk| {
|
||||
if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) {
|
||||
if let Some(name) = &chunk.name {
|
||||
if name.as_str().to_lowercase().contains(&query) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(topic) = &chunk.topic {
|
||||
if topic.to_lowercase().contains(&query) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if let Some(topic) = &chunk.topic {
|
||||
if topic.to_lowercase().contains(&query) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(canonical_alias) = &chunk.canonical_alias {
|
||||
if canonical_alias.as_str().to_lowercase().contains(&query) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if let Some(canonical_alias) = &chunk.canonical_alias {
|
||||
if canonical_alias.as_str().to_lowercase().contains(&query) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
} else {
|
||||
// No search term
|
||||
true
|
||||
}
|
||||
})
|
||||
// We need to collect all, so we can sort by member count
|
||||
.collect();
|
||||
false
|
||||
} else {
|
||||
// No search term
|
||||
true
|
||||
}
|
||||
})
|
||||
// We need to collect all, so we can sort by member count
|
||||
.collect();
|
||||
|
||||
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
|
||||
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
|
||||
|
||||
let total_room_count_estimate = (all_rooms.len() as u32).into();
|
||||
let total_room_count_estimate = (all_rooms.len() as u32).into();
|
||||
|
||||
let chunk: Vec<_> = all_rooms
|
||||
.into_iter()
|
||||
.skip(num_since as usize)
|
||||
.take(limit as usize)
|
||||
.collect();
|
||||
let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect();
|
||||
|
||||
let prev_batch = if num_since == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(format!("p{num_since}"))
|
||||
};
|
||||
let prev_batch = if num_since == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(format!("p{num_since}"))
|
||||
};
|
||||
|
||||
let next_batch = if chunk.len() < limit as usize {
|
||||
None
|
||||
} else {
|
||||
Some(format!("n{}", num_since + limit))
|
||||
};
|
||||
let next_batch = if chunk.len() < limit as usize {
|
||||
None
|
||||
} else {
|
||||
Some(format!("n{}", num_since + limit))
|
||||
};
|
||||
|
||||
Ok(get_public_rooms_filtered::v3::Response {
|
||||
chunk,
|
||||
prev_batch,
|
||||
next_batch,
|
||||
total_room_count_estimate: Some(total_room_count_estimate),
|
||||
})
|
||||
Ok(get_public_rooms_filtered::v3::Response {
|
||||
chunk,
|
||||
prev_batch,
|
||||
next_batch,
|
||||
total_room_count_estimate: Some(total_room_count_estimate),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,34 +1,31 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
filter::{create_filter, get_filter},
|
||||
error::ErrorKind,
|
||||
filter::{create_filter, get_filter},
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
|
||||
///
|
||||
/// Loads a filter that was previously created.
|
||||
///
|
||||
/// - A user can only access their own filters
|
||||
pub async fn get_filter_route(
|
||||
body: Ruma<get_filter::v3::Request>,
|
||||
) -> Result<get_filter::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
|
||||
Some(filter) => filter,
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
|
||||
};
|
||||
pub async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Result<get_filter::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
|
||||
Some(filter) => filter,
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
|
||||
};
|
||||
|
||||
Ok(get_filter::v3::Response::new(filter))
|
||||
Ok(get_filter::v3::Response::new(filter))
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/user/{userId}/filter`
|
||||
///
|
||||
/// Creates a new filter to be used by other endpoints.
|
||||
pub async fn create_filter_route(
|
||||
body: Ruma<create_filter::v3::Request>,
|
||||
) -> Result<create_filter::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
Ok(create_filter::v3::Response::new(
|
||||
services().users.create_filter(sender_user, &body.filter)?,
|
||||
))
|
||||
pub async fn create_filter_route(body: Ruma<create_filter::v3::Request>) -> Result<create_filter::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
Ok(create_filter::v3::Response::new(
|
||||
services().users.create_filter(sender_user, &body.filter)?,
|
||||
))
|
||||
}
|
||||
|
||||
+356
-448
@@ -1,65 +1,53 @@
|
||||
use super::SESSION_ID_LENGTH;
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use std::{
|
||||
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
error::ErrorKind,
|
||||
keys::{
|
||||
claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures,
|
||||
upload_signing_keys,
|
||||
},
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
},
|
||||
federation,
|
||||
},
|
||||
serde::Raw,
|
||||
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
|
||||
api::{
|
||||
client::{
|
||||
error::ErrorKind,
|
||||
keys::{claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, upload_signing_keys},
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
},
|
||||
federation,
|
||||
},
|
||||
serde::Raw,
|
||||
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
|
||||
};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use super::SESSION_ID_LENGTH;
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/upload`
|
||||
///
|
||||
/// Publish end-to-end encryption keys for the sender device.
|
||||
///
|
||||
/// - Adds one time keys
|
||||
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?)
|
||||
pub async fn upload_keys_route(
|
||||
body: Ruma<upload_keys::v3::Request>,
|
||||
) -> Result<upload_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
/// - If there are no device keys yet: Adds device keys (TODO: merge with
|
||||
/// existing keys?)
|
||||
pub async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<upload_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
for (key_key, key_value) in &body.one_time_keys {
|
||||
services()
|
||||
.users
|
||||
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
||||
}
|
||||
for (key_key, key_value) in &body.one_time_keys {
|
||||
services().users.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
||||
}
|
||||
|
||||
if let Some(device_keys) = &body.device_keys {
|
||||
// TODO: merge this and the existing event?
|
||||
// This check is needed to assure that signatures are kept
|
||||
if services()
|
||||
.users
|
||||
.get_device_keys(sender_user, sender_device)?
|
||||
.is_none()
|
||||
{
|
||||
services()
|
||||
.users
|
||||
.add_device_keys(sender_user, sender_device, device_keys)?;
|
||||
}
|
||||
}
|
||||
if let Some(device_keys) = &body.device_keys {
|
||||
// TODO: merge this and the existing event?
|
||||
// This check is needed to assure that signatures are kept
|
||||
if services().users.get_device_keys(sender_user, sender_device)?.is_none() {
|
||||
services().users.add_device_keys(sender_user, sender_device, device_keys)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(upload_keys::v3::Response {
|
||||
one_time_key_counts: services()
|
||||
.users
|
||||
.count_one_time_keys(sender_user, sender_device)?,
|
||||
})
|
||||
Ok(upload_keys::v3::Response {
|
||||
one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/query`
|
||||
@@ -68,30 +56,29 @@ pub async fn upload_keys_route(
|
||||
///
|
||||
/// - Always fetches users from other servers over federation
|
||||
/// - Gets master keys, self-signing keys, user signing keys and device keys.
|
||||
/// - The master and self-signing keys contain signatures that the user is allowed to see
|
||||
/// - The master and self-signing keys contain signatures that the user is
|
||||
/// allowed to see
|
||||
pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let response = get_keys_helper(
|
||||
Some(sender_user),
|
||||
&body.device_keys,
|
||||
|u| u == sender_user,
|
||||
true, // Always allow local users to see device names of other local users
|
||||
)
|
||||
.await?;
|
||||
let response = get_keys_helper(
|
||||
Some(sender_user),
|
||||
&body.device_keys,
|
||||
|u| u == sender_user,
|
||||
true, // Always allow local users to see device names of other local users
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/claim`
|
||||
///
|
||||
/// Claims one-time keys
|
||||
pub async fn claim_keys_route(
|
||||
body: Ruma<claim_keys::v3::Request>,
|
||||
) -> Result<claim_keys::v3::Response> {
|
||||
let response = claim_keys_helper(&body.one_time_keys).await?;
|
||||
pub async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Result<claim_keys::v3::Response> {
|
||||
let response = claim_keys_helper(&body.one_time_keys).await?;
|
||||
|
||||
Ok(response)
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/device_signing/upload`
|
||||
@@ -100,452 +87,373 @@ pub async fn claim_keys_route(
|
||||
///
|
||||
/// - Requires UIAA to verify password
|
||||
pub async fn upload_signing_keys_route(
|
||||
body: Ruma<upload_signing_keys::v3::Request>,
|
||||
body: Ruma<upload_signing_keys::v3::Request>,
|
||||
) -> Result<upload_signing_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
stages: vec![AuthType::Password],
|
||||
}],
|
||||
completed: Vec::new(),
|
||||
params: Box::default(),
|
||||
session: None,
|
||||
auth_error: None,
|
||||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
if let Some(master_key) = &body.master_key {
|
||||
services().users.add_cross_signing_keys(
|
||||
sender_user,
|
||||
master_key,
|
||||
&body.self_signing_key,
|
||||
&body.user_signing_key,
|
||||
true, // notify so that other users see the new keys
|
||||
)?;
|
||||
}
|
||||
if let Some(master_key) = &body.master_key {
|
||||
services().users.add_cross_signing_keys(
|
||||
sender_user,
|
||||
master_key,
|
||||
&body.self_signing_key,
|
||||
&body.user_signing_key,
|
||||
true, // notify so that other users see the new keys
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(upload_signing_keys::v3::Response {})
|
||||
Ok(upload_signing_keys::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/signatures/upload`
|
||||
///
|
||||
/// Uploads end-to-end key signatures from the sender user.
|
||||
pub async fn upload_signatures_route(
|
||||
body: Ruma<upload_signatures::v3::Request>,
|
||||
body: Ruma<upload_signatures::v3::Request>,
|
||||
) -> Result<upload_signatures::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
for (user_id, keys) in &body.signed_keys {
|
||||
for (key_id, key) in keys {
|
||||
let key = serde_json::to_value(key)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?;
|
||||
for (user_id, keys) in &body.signed_keys {
|
||||
for (key_id, key) in keys {
|
||||
let key = serde_json::to_value(key)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?;
|
||||
|
||||
for signature in key
|
||||
.get("signatures")
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Missing signatures field.",
|
||||
))?
|
||||
.get(sender_user.to_string())
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid user in signatures field.",
|
||||
))?
|
||||
.as_object()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid signature.",
|
||||
))?
|
||||
.clone()
|
||||
.into_iter()
|
||||
{
|
||||
// Signature validation?
|
||||
let signature = (
|
||||
signature.0,
|
||||
signature
|
||||
.1
|
||||
.as_str()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid signature value.",
|
||||
))?
|
||||
.to_owned(),
|
||||
);
|
||||
services()
|
||||
.users
|
||||
.sign_key(user_id, key_id, signature, sender_user)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
for signature in key
|
||||
.get("signatures")
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Missing signatures field."))?
|
||||
.get(sender_user.to_string())
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid user in signatures field."))?
|
||||
.as_object()
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature."))?
|
||||
.clone()
|
||||
.into_iter()
|
||||
{
|
||||
// Signature validation?
|
||||
let signature = (
|
||||
signature.0,
|
||||
signature
|
||||
.1
|
||||
.as_str()
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
|
||||
.to_owned(),
|
||||
);
|
||||
services().users.sign_key(user_id, key_id, signature, sender_user)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(upload_signatures::v3::Response {
|
||||
failures: BTreeMap::new(), // TODO: integrate
|
||||
})
|
||||
Ok(upload_signatures::v3::Response {
|
||||
failures: BTreeMap::new(), // TODO: integrate
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/changes`
|
||||
///
|
||||
/// Gets a list of users who have updated their device identity keys since the previous sync token.
|
||||
/// Gets a list of users who have updated their device identity keys since the
|
||||
/// previous sync token.
|
||||
///
|
||||
/// - TODO: left users
|
||||
pub async fn get_key_changes_route(
|
||||
body: Ruma<get_key_changes::v3::Request>,
|
||||
) -> Result<get_key_changes::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_key_changes_route(body: Ruma<get_key_changes::v3::Request>) -> Result<get_key_changes::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut device_list_updates = HashSet::new();
|
||||
let mut device_list_updates = HashSet::new();
|
||||
|
||||
device_list_updates.extend(
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
sender_user.as_str(),
|
||||
body.from
|
||||
.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||
Some(
|
||||
body.to
|
||||
.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
|
||||
),
|
||||
)
|
||||
.filter_map(std::result::Result::ok),
|
||||
);
|
||||
device_list_updates.extend(
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
sender_user.as_str(),
|
||||
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
|
||||
)
|
||||
.filter_map(std::result::Result::ok),
|
||||
);
|
||||
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
device_list_updates.extend(
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
room_id.as_ref(),
|
||||
body.from.parse().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.")
|
||||
})?,
|
||||
Some(body.to.parse().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")
|
||||
})?),
|
||||
)
|
||||
.filter_map(std::result::Result::ok),
|
||||
);
|
||||
}
|
||||
Ok(get_key_changes::v3::Response {
|
||||
changed: device_list_updates.into_iter().collect(),
|
||||
left: Vec::new(), // TODO
|
||||
})
|
||||
for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok) {
|
||||
device_list_updates.extend(
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
room_id.as_ref(),
|
||||
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
|
||||
)
|
||||
.filter_map(std::result::Result::ok),
|
||||
);
|
||||
}
|
||||
Ok(get_key_changes::v3::Response {
|
||||
changed: device_list_updates.into_iter().collect(),
|
||||
left: Vec::new(), // TODO
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||
sender_user: Option<&UserId>,
|
||||
device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
|
||||
allowed_signatures: F,
|
||||
include_display_names: bool,
|
||||
sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F,
|
||||
include_display_names: bool,
|
||||
) -> Result<get_keys::v3::Response> {
|
||||
let mut master_keys = BTreeMap::new();
|
||||
let mut self_signing_keys = BTreeMap::new();
|
||||
let mut user_signing_keys = BTreeMap::new();
|
||||
let mut device_keys = BTreeMap::new();
|
||||
let mut master_keys = BTreeMap::new();
|
||||
let mut self_signing_keys = BTreeMap::new();
|
||||
let mut user_signing_keys = BTreeMap::new();
|
||||
let mut device_keys = BTreeMap::new();
|
||||
|
||||
let mut get_over_federation = HashMap::new();
|
||||
let mut get_over_federation = HashMap::new();
|
||||
|
||||
for (user_id, device_ids) in device_keys_input {
|
||||
let user_id: &UserId = user_id;
|
||||
for (user_id, device_ids) in device_keys_input {
|
||||
let user_id: &UserId = user_id;
|
||||
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation
|
||||
.entry(user_id.server_name())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((user_id, device_ids));
|
||||
continue;
|
||||
}
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids));
|
||||
continue;
|
||||
}
|
||||
|
||||
if device_ids.is_empty() {
|
||||
let mut container = BTreeMap::new();
|
||||
for device_id in services().users.all_device_ids(user_id) {
|
||||
let device_id = device_id?;
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? {
|
||||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, &device_id)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("all_device_keys contained nonexistent device.")
|
||||
})?;
|
||||
if device_ids.is_empty() {
|
||||
let mut container = BTreeMap::new();
|
||||
for device_id in services().users.all_device_ids(user_id) {
|
||||
let device_id = device_id?;
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? {
|
||||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, &device_id)?
|
||||
.ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?;
|
||||
|
||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
|
||||
container.insert(device_id, keys);
|
||||
}
|
||||
}
|
||||
device_keys.insert(user_id.to_owned(), container);
|
||||
} else {
|
||||
for device_id in device_ids {
|
||||
let mut container = BTreeMap::new();
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
|
||||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, device_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to get keys for nonexistent device.",
|
||||
))?;
|
||||
container.insert(device_id, keys);
|
||||
}
|
||||
}
|
||||
device_keys.insert(user_id.to_owned(), container);
|
||||
} else {
|
||||
for device_id in device_ids {
|
||||
let mut container = BTreeMap::new();
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
|
||||
let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or(
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."),
|
||||
)?;
|
||||
|
||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
container.insert(device_id.to_owned(), keys);
|
||||
}
|
||||
device_keys.insert(user_id.to_owned(), container);
|
||||
}
|
||||
}
|
||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
container.insert(device_id.to_owned(), keys);
|
||||
}
|
||||
device_keys.insert(user_id.to_owned(), container);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(master_key) =
|
||||
services()
|
||||
.users
|
||||
.get_master_key(sender_user, user_id, &allowed_signatures)?
|
||||
{
|
||||
master_keys.insert(user_id.to_owned(), master_key);
|
||||
}
|
||||
if let Some(self_signing_key) =
|
||||
services()
|
||||
.users
|
||||
.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
|
||||
{
|
||||
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
|
||||
}
|
||||
if Some(user_id) == sender_user {
|
||||
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
|
||||
user_signing_keys.insert(user_id.to_owned(), user_signing_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? {
|
||||
master_keys.insert(user_id.to_owned(), master_key);
|
||||
}
|
||||
if let Some(self_signing_key) =
|
||||
services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
|
||||
{
|
||||
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
|
||||
}
|
||||
if Some(user_id) == sender_user {
|
||||
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
|
||||
user_signing_keys.insert(user_id.to_owned(), user_signing_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut failures = BTreeMap::new();
|
||||
let mut failures = BTreeMap::new();
|
||||
|
||||
let back_off = |id| match services()
|
||||
.globals
|
||||
.bad_query_ratelimiter
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(id)
|
||||
{
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
}
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
};
|
||||
let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
};
|
||||
|
||||
let mut futures: FuturesUnordered<_> = get_over_federation
|
||||
.into_iter()
|
||||
.map(|(server, vec)| async move {
|
||||
if let Some((time, tries)) = services()
|
||||
.globals
|
||||
.bad_query_ratelimiter
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(server)
|
||||
{
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
|
||||
}
|
||||
let mut futures: FuturesUnordered<_> = get_over_federation
|
||||
.into_iter()
|
||||
.map(|(server, vec)| async move {
|
||||
if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
|
||||
}
|
||||
|
||||
if time.elapsed() < min_elapsed_duration {
|
||||
debug!("Backing off query from {:?}", server);
|
||||
return (
|
||||
server,
|
||||
Err(Error::BadServerResponse("bad query, still backing off")),
|
||||
);
|
||||
}
|
||||
}
|
||||
if time.elapsed() < min_elapsed_duration {
|
||||
debug!("Backing off query from {:?}", server);
|
||||
return (server, Err(Error::BadServerResponse("bad query, still backing off")));
|
||||
}
|
||||
}
|
||||
|
||||
let mut device_keys_input_fed = BTreeMap::new();
|
||||
for (user_id, keys) in vec {
|
||||
device_keys_input_fed.insert(user_id.to_owned(), keys.clone());
|
||||
}
|
||||
(
|
||||
server,
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(50),
|
||||
services().sending.send_federation_request(
|
||||
server,
|
||||
federation::keys::get_keys::v1::Request {
|
||||
device_keys: device_keys_input_fed,
|
||||
},
|
||||
),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("get_keys_helper query took too long: {}", e);
|
||||
Error::BadServerResponse("get_keys_helper query took too long")
|
||||
}),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let mut device_keys_input_fed = BTreeMap::new();
|
||||
for (user_id, keys) in vec {
|
||||
device_keys_input_fed.insert(user_id.to_owned(), keys.clone());
|
||||
}
|
||||
(
|
||||
server,
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(50),
|
||||
services().sending.send_federation_request(
|
||||
server,
|
||||
federation::keys::get_keys::v1::Request {
|
||||
device_keys: device_keys_input_fed,
|
||||
},
|
||||
),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("get_keys_helper query took too long: {}", e);
|
||||
Error::BadServerResponse("get_keys_helper query took too long")
|
||||
}),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
while let Some((server, response)) = futures.next().await {
|
||||
match response {
|
||||
Ok(Ok(response)) => {
|
||||
for (user, masterkey) in response.master_keys {
|
||||
let (master_key_id, mut master_key) =
|
||||
services().users.parse_master_key(&user, &masterkey)?;
|
||||
while let Some((server, response)) = futures.next().await {
|
||||
match response {
|
||||
Ok(Ok(response)) => {
|
||||
for (user, masterkey) in response.master_keys {
|
||||
let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?;
|
||||
|
||||
if let Some(our_master_key) = services().users.get_key(
|
||||
&master_key_id,
|
||||
sender_user,
|
||||
&user,
|
||||
&allowed_signatures,
|
||||
)? {
|
||||
let (_, our_master_key) =
|
||||
services().users.parse_master_key(&user, &our_master_key)?;
|
||||
master_key.signatures.extend(our_master_key.signatures);
|
||||
}
|
||||
let json = serde_json::to_value(master_key).expect("to_value always works");
|
||||
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
|
||||
services().users.add_cross_signing_keys(
|
||||
&user, &raw, &None, &None,
|
||||
false, // Dont notify. A notification would trigger another key request resulting in an endless loop
|
||||
)?;
|
||||
master_keys.insert(user, raw);
|
||||
}
|
||||
if let Some(our_master_key) =
|
||||
services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
|
||||
{
|
||||
let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?;
|
||||
master_key.signatures.extend(our_master_key.signatures);
|
||||
}
|
||||
let json = serde_json::to_value(master_key).expect("to_value always works");
|
||||
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
|
||||
services().users.add_cross_signing_keys(
|
||||
&user, &raw, &None, &None,
|
||||
false, /* Dont notify. A notification would trigger another key request resulting in an
|
||||
* endless loop */
|
||||
)?;
|
||||
master_keys.insert(user, raw);
|
||||
}
|
||||
|
||||
self_signing_keys.extend(response.self_signing_keys);
|
||||
device_keys.extend(response.device_keys);
|
||||
}
|
||||
_ => {
|
||||
back_off(server.to_owned());
|
||||
failures.insert(server.to_string(), json!({}));
|
||||
}
|
||||
}
|
||||
}
|
||||
self_signing_keys.extend(response.self_signing_keys);
|
||||
device_keys.extend(response.device_keys);
|
||||
},
|
||||
_ => {
|
||||
back_off(server.to_owned());
|
||||
failures.insert(server.to_string(), json!({}));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Ok(get_keys::v3::Response {
|
||||
master_keys,
|
||||
self_signing_keys,
|
||||
user_signing_keys,
|
||||
device_keys,
|
||||
failures,
|
||||
})
|
||||
Ok(get_keys::v3::Response {
|
||||
master_keys,
|
||||
self_signing_keys,
|
||||
user_signing_keys,
|
||||
device_keys,
|
||||
failures,
|
||||
})
|
||||
}
|
||||
|
||||
fn add_unsigned_device_display_name(
|
||||
keys: &mut Raw<ruma::encryption::DeviceKeys>,
|
||||
metadata: ruma::api::client::device::Device,
|
||||
include_display_names: bool,
|
||||
keys: &mut Raw<ruma::encryption::DeviceKeys>, metadata: ruma::api::client::device::Device,
|
||||
include_display_names: bool,
|
||||
) -> serde_json::Result<()> {
|
||||
if let Some(display_name) = metadata.display_name {
|
||||
let mut object = keys.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?;
|
||||
if let Some(display_name) = metadata.display_name {
|
||||
let mut object = keys.deserialize_as::<serde_json::Map<String, serde_json::Value>>()?;
|
||||
|
||||
let unsigned = object.entry("unsigned").or_insert_with(|| json!({}));
|
||||
if let serde_json::Value::Object(unsigned_object) = unsigned {
|
||||
if include_display_names {
|
||||
unsigned_object.insert("device_display_name".to_owned(), display_name.into());
|
||||
} else {
|
||||
unsigned_object.insert(
|
||||
"device_display_name".to_owned(),
|
||||
Some(metadata.device_id.as_str().to_owned()).into(),
|
||||
);
|
||||
}
|
||||
}
|
||||
let unsigned = object.entry("unsigned").or_insert_with(|| json!({}));
|
||||
if let serde_json::Value::Object(unsigned_object) = unsigned {
|
||||
if include_display_names {
|
||||
unsigned_object.insert("device_display_name".to_owned(), display_name.into());
|
||||
} else {
|
||||
unsigned_object.insert(
|
||||
"device_display_name".to_owned(),
|
||||
Some(metadata.device_id.as_str().to_owned()).into(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
*keys = Raw::from_json(serde_json::value::to_raw_value(&object)?);
|
||||
}
|
||||
*keys = Raw::from_json(serde_json::value::to_raw_value(&object)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn claim_keys_helper(
|
||||
one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>,
|
||||
one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>,
|
||||
) -> Result<claim_keys::v3::Response> {
|
||||
let mut one_time_keys = BTreeMap::new();
|
||||
let mut one_time_keys = BTreeMap::new();
|
||||
|
||||
let mut get_over_federation = BTreeMap::new();
|
||||
let mut get_over_federation = BTreeMap::new();
|
||||
|
||||
for (user_id, map) in one_time_keys_input {
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation
|
||||
.entry(user_id.server_name())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((user_id, map));
|
||||
}
|
||||
for (user_id, map) in one_time_keys_input {
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map));
|
||||
}
|
||||
|
||||
let mut container = BTreeMap::new();
|
||||
for (device_id, key_algorithm) in map {
|
||||
if let Some(one_time_keys) =
|
||||
services()
|
||||
.users
|
||||
.take_one_time_key(user_id, device_id, key_algorithm)?
|
||||
{
|
||||
let mut c = BTreeMap::new();
|
||||
c.insert(one_time_keys.0, one_time_keys.1);
|
||||
container.insert(device_id.clone(), c);
|
||||
}
|
||||
}
|
||||
one_time_keys.insert(user_id.clone(), container);
|
||||
}
|
||||
let mut container = BTreeMap::new();
|
||||
for (device_id, key_algorithm) in map {
|
||||
if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? {
|
||||
let mut c = BTreeMap::new();
|
||||
c.insert(one_time_keys.0, one_time_keys.1);
|
||||
container.insert(device_id.clone(), c);
|
||||
}
|
||||
}
|
||||
one_time_keys.insert(user_id.clone(), container);
|
||||
}
|
||||
|
||||
let mut failures = BTreeMap::new();
|
||||
let mut failures = BTreeMap::new();
|
||||
|
||||
let mut futures: FuturesUnordered<_> = get_over_federation
|
||||
.into_iter()
|
||||
.map(|(server, vec)| async move {
|
||||
let mut one_time_keys_input_fed = BTreeMap::new();
|
||||
for (user_id, keys) in vec {
|
||||
one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
|
||||
}
|
||||
(
|
||||
server,
|
||||
services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server,
|
||||
federation::keys::claim_keys::v1::Request {
|
||||
one_time_keys: one_time_keys_input_fed,
|
||||
},
|
||||
)
|
||||
.await,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let mut futures: FuturesUnordered<_> = get_over_federation
|
||||
.into_iter()
|
||||
.map(|(server, vec)| async move {
|
||||
let mut one_time_keys_input_fed = BTreeMap::new();
|
||||
for (user_id, keys) in vec {
|
||||
one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
|
||||
}
|
||||
(
|
||||
server,
|
||||
services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server,
|
||||
federation::keys::claim_keys::v1::Request {
|
||||
one_time_keys: one_time_keys_input_fed,
|
||||
},
|
||||
)
|
||||
.await,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
while let Some((server, response)) = futures.next().await {
|
||||
match response {
|
||||
Ok(keys) => {
|
||||
one_time_keys.extend(keys.one_time_keys);
|
||||
}
|
||||
Err(_e) => {
|
||||
failures.insert(server.to_string(), json!({}));
|
||||
}
|
||||
}
|
||||
}
|
||||
while let Some((server, response)) = futures.next().await {
|
||||
match response {
|
||||
Ok(keys) => {
|
||||
one_time_keys.extend(keys.one_time_keys);
|
||||
},
|
||||
Err(_e) => {
|
||||
failures.insert(server.to_string(), json!({}));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Ok(claim_keys::v3::Response {
|
||||
failures,
|
||||
one_time_keys,
|
||||
})
|
||||
Ok(claim_keys::v3::Response {
|
||||
failures,
|
||||
one_time_keys,
|
||||
})
|
||||
}
|
||||
|
||||
+370
-443
@@ -1,22 +1,22 @@
|
||||
use std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration};
|
||||
|
||||
use crate::{
|
||||
service::media::{FileMeta, UrlPreviewData},
|
||||
services, utils, Error, Result, Ruma,
|
||||
};
|
||||
use image::io::Reader as ImgReader;
|
||||
|
||||
use reqwest::Url;
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
media::{
|
||||
create_content, get_content, get_content_as_filename, get_content_thumbnail,
|
||||
get_media_config, get_media_preview,
|
||||
},
|
||||
error::ErrorKind,
|
||||
media::{
|
||||
create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
|
||||
get_media_preview,
|
||||
},
|
||||
};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use webpage::HTML;
|
||||
|
||||
use crate::{
|
||||
service::media::{FileMeta, UrlPreviewData},
|
||||
services, utils, Error, Result, Ruma,
|
||||
};
|
||||
|
||||
/// generated MXC ID (`media-id`) length
|
||||
const MXC_LENGTH: usize = 32;
|
||||
|
||||
@@ -24,48 +24,39 @@ const MXC_LENGTH: usize = 32;
|
||||
///
|
||||
/// Returns max upload size.
|
||||
pub async fn get_media_config_route(
|
||||
_body: Ruma<get_media_config::v3::Request>,
|
||||
_body: Ruma<get_media_config::v3::Request>,
|
||||
) -> Result<get_media_config::v3::Response> {
|
||||
Ok(get_media_config::v3::Response {
|
||||
upload_size: services().globals.max_request_size().into(),
|
||||
})
|
||||
Ok(get_media_config::v3::Response {
|
||||
upload_size: services().globals.max_request_size().into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/media/v3/preview_url`
|
||||
///
|
||||
/// Returns URL preview.
|
||||
pub async fn get_media_preview_route(
|
||||
body: Ruma<get_media_preview::v3::Request>,
|
||||
body: Ruma<get_media_preview::v3::Request>,
|
||||
) -> Result<get_media_preview::v3::Response> {
|
||||
let url = &body.url;
|
||||
if !url_preview_allowed(url) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"URL is not allowed to be previewed",
|
||||
));
|
||||
}
|
||||
let url = &body.url;
|
||||
if !url_preview_allowed(url) {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "URL is not allowed to be previewed"));
|
||||
}
|
||||
|
||||
if let Ok(preview) = get_url_preview(url).await {
|
||||
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
||||
error!(
|
||||
"Failed to convert UrlPreviewData into a serde json value: {}",
|
||||
e
|
||||
);
|
||||
Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unknown error occurred parsing URL preview",
|
||||
)
|
||||
})?;
|
||||
if let Ok(preview) = get_url_preview(url).await {
|
||||
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
||||
error!("Failed to convert UrlPreviewData into a serde json value: {}", e);
|
||||
Error::BadRequest(ErrorKind::Unknown, "Unknown error occurred parsing URL preview")
|
||||
})?;
|
||||
|
||||
return Ok(get_media_preview::v3::Response::from_raw_value(res));
|
||||
}
|
||||
return Ok(get_media_preview::v3::Response::from_raw_value(res));
|
||||
}
|
||||
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded {
|
||||
retry_after_ms: Some(Duration::from_secs(5)),
|
||||
},
|
||||
"Retry later",
|
||||
))
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::LimitExceeded {
|
||||
retry_after_ms: Some(Duration::from_secs(5)),
|
||||
},
|
||||
"Retry later",
|
||||
))
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/media/v3/upload`
|
||||
@@ -74,80 +65,70 @@ pub async fn get_media_preview_route(
|
||||
///
|
||||
/// - Some metadata will be saved in the database
|
||||
/// - Media will be saved in the media/ directory
|
||||
pub async fn create_content_route(
|
||||
body: Ruma<create_content::v3::Request>,
|
||||
) -> Result<create_content::v3::Response> {
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
pub async fn create_content_route(body: Ruma<create_content::v3::Request>) -> Result<create_content::v3::Response> {
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
|
||||
services()
|
||||
.media
|
||||
.create(
|
||||
mxc.clone(),
|
||||
body.filename
|
||||
.as_ref()
|
||||
.map(|filename| "inline; filename=".to_owned() + filename)
|
||||
.as_deref(),
|
||||
body.content_type.as_deref(),
|
||||
&body.file,
|
||||
)
|
||||
.await?;
|
||||
services()
|
||||
.media
|
||||
.create(
|
||||
mxc.clone(),
|
||||
body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(),
|
||||
body.content_type.as_deref(),
|
||||
&body.file,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let content_uri = mxc.into();
|
||||
let content_uri = mxc.into();
|
||||
|
||||
Ok(create_content::v3::Response {
|
||||
content_uri,
|
||||
blurhash: None,
|
||||
})
|
||||
Ok(create_content::v3::Response {
|
||||
content_uri,
|
||||
blurhash: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// helper method to fetch remote media from other servers over federation
|
||||
pub async fn get_remote_content(
|
||||
mxc: &str,
|
||||
server_name: &ruma::ServerName,
|
||||
media_id: String,
|
||||
allow_redirect: bool,
|
||||
timeout_ms: Duration,
|
||||
mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
|
||||
) -> Result<get_content::v3::Response, Error> {
|
||||
// we'll lie to the client and say the blocked server's media was not found and log.
|
||||
// the client has no way of telling anyways so this is a security bonus.
|
||||
if services()
|
||||
.globals
|
||||
.prevent_media_downloads_from()
|
||||
.contains(&server_name.to_owned())
|
||||
{
|
||||
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
// we'll lie to the client and say the blocked server's media was not found and
|
||||
// log. the client has no way of telling anyways so this is a security bonus.
|
||||
if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) {
|
||||
info!(
|
||||
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
|
||||
mxc
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
|
||||
let content_response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server_name,
|
||||
get_content::v3::Request {
|
||||
allow_remote: true,
|
||||
server_name: server_name.to_owned(),
|
||||
media_id,
|
||||
timeout_ms,
|
||||
allow_redirect,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let content_response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server_name,
|
||||
get_content::v3::Request {
|
||||
allow_remote: true,
|
||||
server_name: server_name.to_owned(),
|
||||
media_id,
|
||||
timeout_ms,
|
||||
allow_redirect,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
services()
|
||||
.media
|
||||
.create(
|
||||
mxc.to_owned(),
|
||||
content_response.content_disposition.as_deref(),
|
||||
content_response.content_type.as_deref(),
|
||||
&content_response.file,
|
||||
)
|
||||
.await?;
|
||||
services()
|
||||
.media
|
||||
.create(
|
||||
mxc.to_owned(),
|
||||
content_response.content_disposition.as_deref(),
|
||||
content_response.content_type.as_deref(),
|
||||
&content_response.file,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(content_response)
|
||||
Ok(content_response)
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}`
|
||||
@@ -156,37 +137,36 @@ pub async fn get_remote_content(
|
||||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
/// - Only redirects if `allow_redirect` is true
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
||||
pub async fn get_content_route(
|
||||
body: Ruma<get_content::v3::Request>,
|
||||
) -> Result<get_content::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||
/// seconds
|
||||
pub async fn get_content_route(body: Ruma<get_content::v3::Request>) -> Result<get_content::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let remote_content_response = get_remote_content(
|
||||
&mxc,
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.allow_redirect,
|
||||
body.timeout_ms,
|
||||
)
|
||||
.await?;
|
||||
Ok(remote_content_response)
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
}
|
||||
if let Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let remote_content_response = get_remote_content(
|
||||
&mxc,
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.allow_redirect,
|
||||
body.timeout_ms,
|
||||
)
|
||||
.await?;
|
||||
Ok(remote_content_response)
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
}
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}/{fileName}`
|
||||
@@ -195,41 +175,44 @@ pub async fn get_content_route(
|
||||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
/// - Only redirects if `allow_redirect` is true
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||
/// seconds
|
||||
pub async fn get_content_as_filename_route(
|
||||
body: Ruma<get_content_as_filename::v3::Request>,
|
||||
body: Ruma<get_content_as_filename::v3::Request>,
|
||||
) -> Result<get_content_as_filename::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
content_type, file, ..
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition: Some(format!("inline; filename={}", body.filename)),
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let remote_content_response = get_remote_content(
|
||||
&mxc,
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.allow_redirect,
|
||||
body.timeout_ms,
|
||||
)
|
||||
.await?;
|
||||
if let Some(FileMeta {
|
||||
content_type,
|
||||
file,
|
||||
..
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition: Some(format!("inline; filename={}", body.filename)),
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let remote_content_response = get_remote_content(
|
||||
&mxc,
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
body.allow_redirect,
|
||||
body.timeout_ms,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
content_disposition: Some(format!("inline: filename={}", body.filename)),
|
||||
content_type: remote_content_response.content_type,
|
||||
file: remote_content_response.file,
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
}
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
content_disposition: Some(format!("inline: filename={}", body.filename)),
|
||||
content_type: remote_content_response.content_type,
|
||||
file: remote_content_response.file,
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
}
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/media/v3/thumbnail/{serverName}/{mediaId}`
|
||||
@@ -238,157 +221,152 @@ pub async fn get_content_as_filename_route(
|
||||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
/// - Only redirects if `allow_redirect` is true
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||
/// seconds
|
||||
pub async fn get_content_thumbnail_route(
|
||||
body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
) -> Result<get_content_thumbnail::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
content_type, file, ..
|
||||
}) = services()
|
||||
.media
|
||||
.get_thumbnail(
|
||||
mxc.clone(),
|
||||
body.width
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||
body.height
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Ok(get_content_thumbnail::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
// we'll lie to the client and say the blocked server's media was not found and log.
|
||||
// the client has no way of telling anyways so this is a security bonus.
|
||||
if services()
|
||||
.globals
|
||||
.prevent_media_downloads_from()
|
||||
.contains(&body.server_name.clone())
|
||||
{
|
||||
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
if let Some(FileMeta {
|
||||
content_type,
|
||||
file,
|
||||
..
|
||||
}) = services()
|
||||
.media
|
||||
.get_thumbnail(
|
||||
mxc.clone(),
|
||||
body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||
body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Ok(get_content_thumbnail::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
// we'll lie to the client and say the blocked server's media was not found and
|
||||
// log. the client has no way of telling anyways so this is a security bonus.
|
||||
if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) {
|
||||
info!(
|
||||
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
|
||||
mxc
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
|
||||
let get_thumbnail_response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&body.server_name,
|
||||
get_content_thumbnail::v3::Request {
|
||||
allow_remote: body.allow_remote,
|
||||
height: body.height,
|
||||
width: body.width,
|
||||
method: body.method.clone(),
|
||||
server_name: body.server_name.clone(),
|
||||
media_id: body.media_id.clone(),
|
||||
timeout_ms: body.timeout_ms,
|
||||
allow_redirect: body.allow_redirect,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let get_thumbnail_response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&body.server_name,
|
||||
get_content_thumbnail::v3::Request {
|
||||
allow_remote: body.allow_remote,
|
||||
height: body.height,
|
||||
width: body.width,
|
||||
method: body.method.clone(),
|
||||
server_name: body.server_name.clone(),
|
||||
media_id: body.media_id.clone(),
|
||||
timeout_ms: body.timeout_ms,
|
||||
allow_redirect: body.allow_redirect,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
services()
|
||||
.media
|
||||
.upload_thumbnail(
|
||||
mxc,
|
||||
None,
|
||||
get_thumbnail_response.content_type.as_deref(),
|
||||
body.width.try_into().expect("all UInts are valid u32s"),
|
||||
body.height.try_into().expect("all UInts are valid u32s"),
|
||||
&get_thumbnail_response.file,
|
||||
)
|
||||
.await?;
|
||||
services()
|
||||
.media
|
||||
.upload_thumbnail(
|
||||
mxc,
|
||||
None,
|
||||
get_thumbnail_response.content_type.as_deref(),
|
||||
body.width.try_into().expect("all UInts are valid u32s"),
|
||||
body.height.try_into().expect("all UInts are valid u32s"),
|
||||
&get_thumbnail_response.file,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(get_thumbnail_response)
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
}
|
||||
Ok(get_thumbnail_response)
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
|
||||
let image = client.get(url).send().await?.bytes().await?;
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
let image = client.get(url).send().await?.bytes().await?;
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
|
||||
services()
|
||||
.media
|
||||
.create(mxc.clone(), None, None, &image)
|
||||
.await?;
|
||||
services().media.create(mxc.clone(), None, None, &image).await?;
|
||||
|
||||
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
|
||||
Err(_) => (None, None),
|
||||
Ok(reader) => match reader.into_dimensions() {
|
||||
Err(_) => (None, None),
|
||||
Ok((width, height)) => (Some(width), Some(height)),
|
||||
},
|
||||
};
|
||||
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
|
||||
Err(_) => (None, None),
|
||||
Ok(reader) => match reader.into_dimensions() {
|
||||
Err(_) => (None, None),
|
||||
Ok((width, height)) => (Some(width), Some(height)),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(UrlPreviewData {
|
||||
image: Some(mxc),
|
||||
image_size: Some(image.len()),
|
||||
image_width: width,
|
||||
image_height: height,
|
||||
..Default::default()
|
||||
})
|
||||
Ok(UrlPreviewData {
|
||||
image: Some(mxc),
|
||||
image_size: Some(image.len()),
|
||||
image_width: width,
|
||||
image_height: height,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
|
||||
let mut response = client.get(url).send().await?;
|
||||
let mut response = client.get(url).send().await?;
|
||||
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
while let Some(chunk) = response.chunk().await? {
|
||||
bytes.extend_from_slice(&chunk);
|
||||
if bytes.len() > services().globals.url_preview_max_spider_size() {
|
||||
debug!("Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the response body and assuming our necessary data is in this range.", url, services().globals.url_preview_max_spider_size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
let body = String::from_utf8_lossy(&bytes);
|
||||
let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) {
|
||||
Ok(html) => html,
|
||||
Err(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to parse HTML",
|
||||
))
|
||||
}
|
||||
};
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
while let Some(chunk) = response.chunk().await? {
|
||||
bytes.extend_from_slice(&chunk);
|
||||
if bytes.len() > services().globals.url_preview_max_spider_size() {
|
||||
debug!(
|
||||
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
|
||||
response body and assuming our necessary data is in this range.",
|
||||
url,
|
||||
services().globals.url_preview_max_spider_size()
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let body = String::from_utf8_lossy(&bytes);
|
||||
let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) {
|
||||
Ok(html) => html,
|
||||
Err(_) => return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML")),
|
||||
};
|
||||
|
||||
let mut data = match html.opengraph.images.first() {
|
||||
None => UrlPreviewData::default(),
|
||||
Some(obj) => download_image(client, &obj.url).await?,
|
||||
};
|
||||
let mut data = match html.opengraph.images.first() {
|
||||
None => UrlPreviewData::default(),
|
||||
Some(obj) => download_image(client, &obj.url).await?,
|
||||
};
|
||||
|
||||
let props = html.opengraph.properties;
|
||||
let props = html.opengraph.properties;
|
||||
|
||||
/* use OpenGraph title/description, but fall back to HTML if not available */
|
||||
data.title = props.get("title").cloned().or(html.title);
|
||||
data.description = props.get("description").cloned().or(html.description);
|
||||
/* use OpenGraph title/description, but fall back to HTML if not available */
|
||||
data.title = props.get("title").cloned().or(html.title);
|
||||
data.description = props.get("description").cloned().or(html.description);
|
||||
|
||||
Ok(data)
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn url_request_allowed(addr: &IpAddr) -> bool {
|
||||
// TODO: make this check ip_range_denylist
|
||||
// TODO: make this check ip_range_denylist
|
||||
|
||||
// could be implemented with reqwest when it supports IP filtering:
|
||||
// https://github.com/seanmonstar/reqwest/issues/1515
|
||||
// could be implemented with reqwest when it supports IP filtering:
|
||||
// https://github.com/seanmonstar/reqwest/issues/1515
|
||||
|
||||
// These checks have been taken from the Rust core/net/ipaddr.rs crate,
|
||||
// IpAddr::V4.is_global() and IpAddr::V6.is_global(), as .is_global is not
|
||||
// yet stabilized. TODO: Once this is stable, this match can be simplified.
|
||||
match addr {
|
||||
IpAddr::V4(ip4) => {
|
||||
!(ip4.octets()[0] == 0 // "This network"
|
||||
// These checks have been taken from the Rust core/net/ipaddr.rs crate,
|
||||
// IpAddr::V4.is_global() and IpAddr::V6.is_global(), as .is_global is not
|
||||
// yet stabilized. TODO: Once this is stable, this match can be simplified.
|
||||
match addr {
|
||||
IpAddr::V4(ip4) => {
|
||||
!(ip4.octets()[0] == 0 // "This network"
|
||||
|| ip4.is_private()
|
||||
|| (ip4.octets()[0] == 100 && (ip4.octets()[1] & 0b1100_0000 == 0b0100_0000)) // is_shared()
|
||||
|| ip4.is_loopback()
|
||||
@@ -399,9 +377,9 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|
||||
|| (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking()
|
||||
|| (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved()
|
||||
|| ip4.is_broadcast())
|
||||
}
|
||||
IpAddr::V6(ip6) => {
|
||||
!(ip6.is_unspecified()
|
||||
},
|
||||
IpAddr::V6(ip6) => {
|
||||
!(ip6.is_unspecified()
|
||||
|| ip6.is_loopback()
|
||||
// IPv4-mapped Address (`::ffff:0:0/96`)
|
||||
|| matches!(ip6.segments(), [0, 0, 0, 0, 0, 0xffff, _, _])
|
||||
@@ -426,178 +404,127 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|
||||
|| ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation()
|
||||
|| ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local()
|
||||
|| ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
|
||||
let client = services().globals.url_preview_client();
|
||||
let response = client.head(url).send().await?;
|
||||
let client = services().globals.url_preview_client();
|
||||
let response = client.head(url).send().await?;
|
||||
|
||||
if !response
|
||||
.remote_addr()
|
||||
.map_or(false, |a| url_request_allowed(&a.ip()))
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Requesting from this address is forbidden",
|
||||
));
|
||||
}
|
||||
if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Requesting from this address is forbidden",
|
||||
));
|
||||
}
|
||||
|
||||
let content_type = match response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|x| x.to_str().ok())
|
||||
{
|
||||
Some(ct) => ct,
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unknown Content-Type",
|
||||
))
|
||||
}
|
||||
};
|
||||
let data = match content_type {
|
||||
html if html.starts_with("text/html") => download_html(&client, url).await?,
|
||||
img if img.starts_with("image/") => download_image(&client, url).await?,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unsupported Content-Type",
|
||||
))
|
||||
}
|
||||
};
|
||||
let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) {
|
||||
Some(ct) => ct,
|
||||
None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")),
|
||||
};
|
||||
let data = match content_type {
|
||||
html if html.starts_with("text/html") => download_html(&client, url).await?,
|
||||
img if img.starts_with("image/") => download_image(&client, url).await?,
|
||||
_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
|
||||
};
|
||||
|
||||
services().media.set_url_preview(url, &data).await?;
|
||||
services().media.set_url_preview(url, &data).await?;
|
||||
|
||||
Ok(data)
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
|
||||
if let Some(preview) = services().media.get_url_preview(url).await {
|
||||
return Ok(preview);
|
||||
}
|
||||
if let Some(preview) = services().media.get_url_preview(url).await {
|
||||
return Ok(preview);
|
||||
}
|
||||
|
||||
// ensure that only one request is made per URL
|
||||
let mutex_request = Arc::clone(
|
||||
services()
|
||||
.media
|
||||
.url_preview_mutex
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(url.to_owned())
|
||||
.or_default(),
|
||||
);
|
||||
let _request_lock = mutex_request.lock().await;
|
||||
// ensure that only one request is made per URL
|
||||
let mutex_request =
|
||||
Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default());
|
||||
let _request_lock = mutex_request.lock().await;
|
||||
|
||||
match services().media.get_url_preview(url).await {
|
||||
Some(preview) => Ok(preview),
|
||||
None => request_url_preview(url).await,
|
||||
}
|
||||
match services().media.get_url_preview(url).await {
|
||||
Some(preview) => Ok(preview),
|
||||
None => request_url_preview(url).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn url_preview_allowed(url_str: &str) -> bool {
|
||||
let url: Url = match Url::parse(url_str) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse URL from a str: {}", e);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
let url: Url = match Url::parse(url_str) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse URL from a str: {}", e);
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if ["http", "https"]
|
||||
.iter()
|
||||
.all(|&scheme| scheme != url.scheme().to_lowercase())
|
||||
{
|
||||
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
|
||||
return false;
|
||||
}
|
||||
if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) {
|
||||
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
|
||||
return false;
|
||||
}
|
||||
|
||||
let host = match url.host_str() {
|
||||
None => {
|
||||
debug!(
|
||||
"Ignoring URL preview for a URL that does not have a host (?): {}",
|
||||
url
|
||||
);
|
||||
return false;
|
||||
}
|
||||
Some(h) => h.to_owned(),
|
||||
};
|
||||
let host = match url.host_str() {
|
||||
None => {
|
||||
debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
|
||||
return false;
|
||||
},
|
||||
Some(h) => h.to_owned(),
|
||||
};
|
||||
|
||||
let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist();
|
||||
let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist();
|
||||
let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist();
|
||||
let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist();
|
||||
let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist();
|
||||
let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist();
|
||||
|
||||
if allowlist_domain_contains.contains(&"*".to_owned())
|
||||
|| allowlist_domain_explicit.contains(&"*".to_owned())
|
||||
|| allowlist_url_contains.contains(&"*".to_owned())
|
||||
{
|
||||
debug!(
|
||||
"Config key contains * which is allowing all URL previews. Allowing URL {}",
|
||||
url
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if allowlist_domain_contains.contains(&"*".to_owned())
|
||||
|| allowlist_domain_explicit.contains(&"*".to_owned())
|
||||
|| allowlist_url_contains.contains(&"*".to_owned())
|
||||
{
|
||||
debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
|
||||
return true;
|
||||
}
|
||||
|
||||
if !host.is_empty() {
|
||||
if allowlist_domain_explicit.contains(&host) {
|
||||
debug!(
|
||||
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)",
|
||||
&host
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if !host.is_empty() {
|
||||
if allowlist_domain_explicit.contains(&host) {
|
||||
debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", &host);
|
||||
return true;
|
||||
}
|
||||
|
||||
if allowlist_domain_contains
|
||||
.iter()
|
||||
.any(|domain_s| domain_s.contains(&host.clone()))
|
||||
{
|
||||
debug!(
|
||||
"Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
||||
&host
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) {
|
||||
debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host);
|
||||
return true;
|
||||
}
|
||||
|
||||
if allowlist_url_contains
|
||||
.iter()
|
||||
.any(|url_s| url.to_string().contains(&url_s.to_string()))
|
||||
{
|
||||
debug!(
|
||||
"URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)",
|
||||
&host
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) {
|
||||
debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host);
|
||||
return true;
|
||||
}
|
||||
|
||||
// check root domain if available and if user has root domain checks
|
||||
if services().globals.url_preview_check_root_domain() {
|
||||
debug!("Checking root domain");
|
||||
match host.split_once('.') {
|
||||
None => return false,
|
||||
Some((_, root_domain)) => {
|
||||
if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
}
|
||||
// check root domain if available and if user has root domain checks
|
||||
if services().globals.url_preview_check_root_domain() {
|
||||
debug!("Checking root domain");
|
||||
match host.split_once('.') {
|
||||
None => return false,
|
||||
Some((_, root_domain)) => {
|
||||
if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
if allowlist_domain_contains
|
||||
.iter()
|
||||
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
|
||||
{
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) {
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
false
|
||||
}
|
||||
|
||||
+1216
-1471
File diff suppressed because it is too large
Load Diff
+228
-260
@@ -1,316 +1,284 @@
|
||||
use crate::{
|
||||
service::{pdu::PduBuilder, rooms::timeline::PduCount},
|
||||
services, utils, Error, Result, Ruma,
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
message::{get_message_events, send_message_event},
|
||||
},
|
||||
events::{StateEventType, TimelineEventType},
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
message::{get_message_events, send_message_event},
|
||||
},
|
||||
events::{StateEventType, TimelineEventType},
|
||||
};
|
||||
use serde_json::from_str;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
sync::Arc,
|
||||
|
||||
use crate::{
|
||||
service::{pdu::PduBuilder, rooms::timeline::PduCount},
|
||||
services, utils, Error, Result, Ruma,
|
||||
};
|
||||
|
||||
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
|
||||
///
|
||||
/// Send a message event into the room.
|
||||
///
|
||||
/// - Is a NOOP if the txn id was already used before and returns the same event id again
|
||||
/// - Is a NOOP if the txn id was already used before and returns the same event
|
||||
/// id again
|
||||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||
/// allowed
|
||||
pub async fn send_message_event_route(
|
||||
body: Ruma<send_message_event::v3::Request>,
|
||||
body: Ruma<send_message_event::v3::Request>,
|
||||
) -> Result<send_message_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(body.room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
// Forbid m.room.encrypted if encryption is disabled
|
||||
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into()
|
||||
&& !services().globals.allow_encryption()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Encryption has been disabled",
|
||||
));
|
||||
}
|
||||
// Forbid m.room.encrypted if encryption is disabled
|
||||
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() && !services().globals.allow_encryption()
|
||||
{
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
|
||||
}
|
||||
|
||||
// certain event types require certain fields to be valid in request bodies.
|
||||
// this helps prevent attempting to handle events that we can't deserialise later so don't waste resources on it.
|
||||
//
|
||||
// see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type.
|
||||
match body.event_type.to_string().into() {
|
||||
TimelineEventType::RoomMessage => {
|
||||
let body_field = body.body.body.get_field::<String>("body");
|
||||
let msgtype_field = body.body.body.get_field::<String>("msgtype");
|
||||
// certain event types require certain fields to be valid in request bodies.
|
||||
// this helps prevent attempting to handle events that we can't deserialise
|
||||
// later so don't waste resources on it.
|
||||
//
|
||||
// see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type.
|
||||
match body.event_type.to_string().into() {
|
||||
TimelineEventType::RoomMessage => {
|
||||
let body_field = body.body.body.get_field::<String>("body");
|
||||
let msgtype_field = body.body.body.get_field::<String>("msgtype");
|
||||
|
||||
if body_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'body' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
if body_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'body' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
|
||||
if msgtype_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'msgtype' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
TimelineEventType::RoomName => {
|
||||
let name_field = body.body.body.get_field::<String>("name");
|
||||
if msgtype_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'msgtype' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
},
|
||||
TimelineEventType::RoomName => {
|
||||
let name_field = body.body.body.get_field::<String>("name");
|
||||
|
||||
if name_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'name' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
TimelineEventType::RoomTopic => {
|
||||
let topic_field = body.body.body.get_field::<String>("topic");
|
||||
if name_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'name' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
},
|
||||
TimelineEventType::RoomTopic => {
|
||||
let topic_field = body.body.body.get_field::<String>("topic");
|
||||
|
||||
if topic_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'topic' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {} // event may be custom/experimental or can be empty don't do anything with it
|
||||
};
|
||||
if topic_field.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"'topic' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
},
|
||||
_ => {}, // event may be custom/experimental or can be empty don't do anything with it
|
||||
};
|
||||
|
||||
// Check if this is a new transaction id
|
||||
if let Some(response) =
|
||||
services()
|
||||
.transaction_ids
|
||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
||||
{
|
||||
// The client might have sent a txnid of the /sendToDevice endpoint
|
||||
// This txnid has no response associated with it
|
||||
if response.is_empty() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to use txn id already used for an incompatible endpoint.",
|
||||
));
|
||||
}
|
||||
// Check if this is a new transaction id
|
||||
if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? {
|
||||
// The client might have sent a txnid of the /sendToDevice endpoint
|
||||
// This txnid has no response associated with it
|
||||
if response.is_empty() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to use txn id already used for an incompatible endpoint.",
|
||||
));
|
||||
}
|
||||
|
||||
let event_id = utils::string_from_bytes(&response)
|
||||
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
|
||||
return Ok(send_message_event::v3::Response { event_id });
|
||||
}
|
||||
let event_id = utils::string_from_bytes(&response)
|
||||
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
|
||||
return Ok(send_message_event::v3::Response {
|
||||
event_id,
|
||||
});
|
||||
}
|
||||
|
||||
let mut unsigned = BTreeMap::new();
|
||||
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
|
||||
let mut unsigned = BTreeMap::new();
|
||||
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
|
||||
|
||||
let event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: body.event_type.to_string().into(),
|
||||
content: from_str(body.body.body.json().get())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
|
||||
unsigned: Some(unsigned),
|
||||
state_key: None,
|
||||
redacts: None,
|
||||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&state_lock,
|
||||
)
|
||||
.await?;
|
||||
let event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: body.event_type.to_string().into(),
|
||||
content: from_str(body.body.body.json().get())
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
|
||||
unsigned: Some(unsigned),
|
||||
state_key: None,
|
||||
redacts: None,
|
||||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&state_lock,
|
||||
)
|
||||
.await?;
|
||||
|
||||
services().transaction_ids.add_txnid(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.txn_id,
|
||||
event_id.as_bytes(),
|
||||
)?;
|
||||
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
|
||||
|
||||
drop(state_lock);
|
||||
drop(state_lock);
|
||||
|
||||
Ok(send_message_event::v3::Response::new(
|
||||
(*event_id).to_owned(),
|
||||
))
|
||||
Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
|
||||
///
|
||||
/// Allows paginating through room history.
|
||||
///
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events
|
||||
/// where the user was
|
||||
/// joined, depending on history_visibility)
|
||||
pub async fn get_message_events_route(
|
||||
body: Ruma<get_message_events::v3::Request>,
|
||||
body: Ruma<get_message_events::v3::Request>,
|
||||
) -> Result<get_message_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match body.dir {
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match body.dir {
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
from,
|
||||
)?;
|
||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
|
||||
|
||||
let limit = u64::from(body.limit).min(100) as usize;
|
||||
let limit = u64::from(body.limit).min(100) as usize;
|
||||
|
||||
let next_token;
|
||||
let next_token;
|
||||
|
||||
let mut resp = get_message_events::v3::Response::new();
|
||||
let mut resp = get_message_events::v3::Response::new();
|
||||
|
||||
let mut lazy_loaded = HashSet::new();
|
||||
let mut lazy_loaded = HashSet::new();
|
||||
|
||||
match body.dir {
|
||||
ruma::api::Direction::Forward => {
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_after(sender_user, &body.room_id, from)?
|
||||
.take(limit)
|
||||
.filter_map(std::result::Result::ok) // Filter out buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
|
||||
.collect();
|
||||
match body.dir {
|
||||
ruma::api::Direction::Forward => {
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_after(sender_user, &body.room_id, from)?
|
||||
.take(limit)
|
||||
.filter_map(std::result::Result::ok) // Filter out buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
|
||||
.collect();
|
||||
|
||||
for (_, event) in &events_after {
|
||||
/* TODO: Remove this when these are resolved:
|
||||
* https://github.com/vector-im/element-android/issues/3417
|
||||
* https://github.com/vector-im/element-web/issues/21034
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
&event.sender,
|
||||
)? {
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
*/
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
for (_, event) in &events_after {
|
||||
/* TODO: Remove this when these are resolved:
|
||||
* https://github.com/vector-im/element-android/issues/3417
|
||||
* https://github.com/vector-im/element-web/issues/21034
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
&event.sender,
|
||||
)? {
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
*/
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
|
||||
next_token = events_after.last().map(|(count, _)| count).copied();
|
||||
next_token = events_after.last().map(|(count, _)| count).copied();
|
||||
|
||||
let events_after: Vec<_> = events_after
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||
|
||||
resp.start = from.stringify();
|
||||
resp.end = next_token.map(|count| count.stringify());
|
||||
resp.chunk = events_after;
|
||||
}
|
||||
ruma::api::Direction::Backward => {
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.backfill_if_required(&body.room_id, from)
|
||||
.await?;
|
||||
let events_before: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(sender_user, &body.room_id, from)?
|
||||
.take(limit)
|
||||
.filter_map(std::result::Result::ok) // Filter out buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
|
||||
.collect();
|
||||
resp.start = from.stringify();
|
||||
resp.end = next_token.map(|count| count.stringify());
|
||||
resp.chunk = events_after;
|
||||
},
|
||||
ruma::api::Direction::Backward => {
|
||||
services().rooms.timeline.backfill_if_required(&body.room_id, from).await?;
|
||||
let events_before: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(sender_user, &body.room_id, from)?
|
||||
.take(limit)
|
||||
.filter_map(std::result::Result::ok) // Filter out buggy events
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
|
||||
.collect();
|
||||
|
||||
for (_, event) in &events_before {
|
||||
/* TODO: Remove this when these are resolved:
|
||||
* https://github.com/vector-im/element-android/issues/3417
|
||||
* https://github.com/vector-im/element-web/issues/21034
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
&event.sender,
|
||||
)? {
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
*/
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
for (_, event) in &events_before {
|
||||
/* TODO: Remove this when these are resolved:
|
||||
* https://github.com/vector-im/element-android/issues/3417
|
||||
* https://github.com/vector-im/element-web/issues/21034
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
&event.sender,
|
||||
)? {
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
*/
|
||||
lazy_loaded.insert(event.sender.clone());
|
||||
}
|
||||
|
||||
next_token = events_before.last().map(|(count, _)| count).copied();
|
||||
next_token = events_before.last().map(|(count, _)| count).copied();
|
||||
|
||||
let events_before: Vec<_> = events_before
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||
|
||||
resp.start = from.stringify();
|
||||
resp.end = next_token.map(|count| count.stringify());
|
||||
resp.chunk = events_before;
|
||||
}
|
||||
}
|
||||
resp.start = from.stringify();
|
||||
resp.end = next_token.map(|count| count.stringify());
|
||||
resp.chunk = events_before;
|
||||
},
|
||||
}
|
||||
|
||||
resp.state = Vec::new();
|
||||
for ll_id in &lazy_loaded {
|
||||
if let Some(member_event) = services().rooms.state_accessor.room_state_get(
|
||||
&body.room_id,
|
||||
&StateEventType::RoomMember,
|
||||
ll_id.as_str(),
|
||||
)? {
|
||||
resp.state.push(member_event.to_state_event());
|
||||
}
|
||||
}
|
||||
resp.state = Vec::new();
|
||||
for ll_id in &lazy_loaded {
|
||||
if let Some(member_event) = services().rooms.state_accessor.room_state_get(
|
||||
&body.room_id,
|
||||
&StateEventType::RoomMember,
|
||||
ll_id.as_str(),
|
||||
)? {
|
||||
resp.state.push(member_event.to_state_event());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: enable again when we are sure clients can handle it
|
||||
/*
|
||||
if let Some(next_token) = next_token {
|
||||
services().rooms.lazy_loading.lazy_load_mark_sent(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
lazy_loaded,
|
||||
next_token,
|
||||
);
|
||||
}
|
||||
*/
|
||||
// TODO: enable again when we are sure clients can handle it
|
||||
/*
|
||||
if let Some(next_token) = next_token {
|
||||
services().rooms.lazy_loading.lazy_load_mark_sent(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
lazy_loaded,
|
||||
next_token,
|
||||
);
|
||||
}
|
||||
*/
|
||||
|
||||
Ok(resp)
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
@@ -1,38 +1,35 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
presence::{get_presence, set_presence},
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
presence::{get_presence, set_presence},
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/presence/{userId}/status`
|
||||
///
|
||||
/// Sets the presence state of the sender user.
|
||||
pub async fn set_presence_route(
|
||||
body: Ruma<set_presence::v3::Request>,
|
||||
) -> Result<set_presence::v3::Response> {
|
||||
if !services().globals.allow_local_presence() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Presence is disabled on this server",
|
||||
));
|
||||
}
|
||||
pub async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> Result<set_presence::v3::Response> {
|
||||
if !services().globals.allow_local_presence() {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
|
||||
}
|
||||
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
for room_id in services().rooms.state_cache.rooms_joined(sender_user) {
|
||||
let room_id = room_id?;
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
for room_id in services().rooms.state_cache.rooms_joined(sender_user) {
|
||||
let room_id = room_id?;
|
||||
|
||||
services().rooms.edus.presence.set_presence(
|
||||
&room_id,
|
||||
sender_user,
|
||||
body.presence.clone(),
|
||||
None,
|
||||
None,
|
||||
body.status_msg.clone(),
|
||||
)?;
|
||||
}
|
||||
services().rooms.edus.presence.set_presence(
|
||||
&room_id,
|
||||
sender_user,
|
||||
body.presence.clone(),
|
||||
None,
|
||||
None,
|
||||
body.status_msg.clone(),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(set_presence::v3::Response {})
|
||||
Ok(set_presence::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/presence/{userId}/status`
|
||||
@@ -40,53 +37,36 @@ pub async fn set_presence_route(
|
||||
/// Gets the presence state of the given user.
|
||||
///
|
||||
/// - Only works if you share a room with the user
|
||||
pub async fn get_presence_route(
|
||||
body: Ruma<get_presence::v3::Request>,
|
||||
) -> Result<get_presence::v3::Response> {
|
||||
if !services().globals.allow_local_presence() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Presence is disabled on this server",
|
||||
));
|
||||
}
|
||||
pub async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result<get_presence::v3::Response> {
|
||||
if !services().globals.allow_local_presence() {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
|
||||
}
|
||||
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut presence_event = None;
|
||||
let mut presence_event = None;
|
||||
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
|
||||
{
|
||||
let room_id = room_id?;
|
||||
for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? {
|
||||
let room_id = room_id?;
|
||||
|
||||
if let Some(presence) = services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.get_presence(&room_id, sender_user)?
|
||||
{
|
||||
presence_event = Some(presence);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? {
|
||||
presence_event = Some(presence);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(presence) = presence_event {
|
||||
Ok(get_presence::v3::Response {
|
||||
// TODO: Should ruma just use the presenceeventcontent type here?
|
||||
status_msg: presence.content.status_msg,
|
||||
currently_active: presence.content.currently_active,
|
||||
last_active_ago: presence
|
||||
.content
|
||||
.last_active_ago
|
||||
.map(|millis| Duration::from_millis(millis.into())),
|
||||
presence: presence.content.presence,
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Presence state for this user was not found",
|
||||
))
|
||||
}
|
||||
if let Some(presence) = presence_event {
|
||||
Ok(get_presence::v3::Response {
|
||||
// TODO: Should ruma just use the presenceeventcontent type here?
|
||||
status_msg: presence.content.status_msg,
|
||||
currently_active: presence.content.currently_active,
|
||||
last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())),
|
||||
presence: presence.content.presence,
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Presence state for this user was not found",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
+213
-307
@@ -1,17 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
error::ErrorKind,
|
||||
profile::{
|
||||
get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name,
|
||||
},
|
||||
},
|
||||
federation,
|
||||
},
|
||||
events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType},
|
||||
presence::PresenceState,
|
||||
api::{
|
||||
client::{
|
||||
error::ErrorKind,
|
||||
profile::{get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name},
|
||||
},
|
||||
federation,
|
||||
},
|
||||
events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType},
|
||||
presence::PresenceState,
|
||||
};
|
||||
use serde_json::value::to_raw_value;
|
||||
|
||||
@@ -23,87 +21,62 @@ use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma};
|
||||
///
|
||||
/// - Also makes sure other users receive the update using presence EDUs
|
||||
pub async fn set_displayname_route(
|
||||
body: Ruma<set_display_name::v3::Request>,
|
||||
body: Ruma<set_display_name::v3::Request>,
|
||||
) -> Result<set_display_name::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(sender_user, body.displayname.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(sender_user, body.displayname.clone()).await?;
|
||||
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_rooms_joined: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(|room_id| {
|
||||
Ok::<_, Error>((
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomMember,
|
||||
content: to_raw_value(&RoomMemberEventContent {
|
||||
displayname: body.displayname.clone(),
|
||||
..serde_json::from_str(
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
sender_user.as_str(),
|
||||
)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database(
|
||||
"Tried to send displayname update for user not in the \
|
||||
room.",
|
||||
)
|
||||
})?
|
||||
.content
|
||||
.get(),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
|
||||
})
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some(sender_user.to_string()),
|
||||
redacts: None,
|
||||
},
|
||||
room_id,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect();
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_rooms_joined: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(|room_id| {
|
||||
Ok::<_, Error>((
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomMember,
|
||||
content: to_raw_value(&RoomMemberEventContent {
|
||||
displayname: body.displayname.clone(),
|
||||
..serde_json::from_str(
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("Tried to send displayname update for user not in the room.")
|
||||
})?
|
||||
.content
|
||||
.get(),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
|
||||
})
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some(sender_user.to_string()),
|
||||
redacts: None,
|
||||
},
|
||||
room_id,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect();
|
||||
|
||||
for (pdu_builder, room_id) in all_rooms_joined {
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
for (pdu_builder, room_id) in all_rooms_joined {
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let _ = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
||||
.await;
|
||||
}
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
|
||||
}
|
||||
|
||||
if services().globals.allow_local_presence() {
|
||||
// Presence update
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.ping_presence(sender_user, PresenceState::Online)?;
|
||||
}
|
||||
if services().globals.allow_local_presence() {
|
||||
// Presence update
|
||||
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
|
||||
}
|
||||
|
||||
Ok(set_display_name::v3::Response {})
|
||||
Ok(set_display_name::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/profile/{userId}/displayname`
|
||||
@@ -113,55 +86,44 @@ pub async fn set_displayname_route(
|
||||
/// - If user is on another server and we do not have a local copy already
|
||||
/// fetch displayname over federation
|
||||
pub async fn get_displayname_route(
|
||||
body: Ruma<get_display_name::v3::Request>,
|
||||
body: Ruma<get_display_name::v3::Request>,
|
||||
) -> Result<get_display_name::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: body.user_id.clone(),
|
||||
field: None, // we want the full user's profile to update locally too
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: body.user_id.clone(),
|
||||
field: None, // we want the full user's profile to update locally too
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&body.user_id, response.displayname.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||
|
||||
return Ok(get_display_name::v3::Response {
|
||||
displayname: response.displayname,
|
||||
});
|
||||
}
|
||||
}
|
||||
return Ok(get_display_name::v3::Response {
|
||||
displayname: response.displayname,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Profile was not found.",
|
||||
));
|
||||
}
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||
// federation
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||
}
|
||||
|
||||
Ok(get_display_name::v3::Response {
|
||||
displayname: services().users.displayname(&body.user_id)?,
|
||||
})
|
||||
Ok(get_display_name::v3::Response {
|
||||
displayname: services().users.displayname(&body.user_id)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/profile/{userId}/avatar_url`
|
||||
@@ -169,93 +131,63 @@ pub async fn get_displayname_route(
|
||||
/// Updates the avatar_url and blurhash.
|
||||
///
|
||||
/// - Also makes sure other users receive the update using presence EDUs
|
||||
pub async fn set_avatar_url_route(
|
||||
body: Ruma<set_avatar_url::v3::Request>,
|
||||
) -> Result<set_avatar_url::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Result<set_avatar_url::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(sender_user, body.avatar_url.clone())
|
||||
.await?;
|
||||
services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?;
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(sender_user, body.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_blurhash(sender_user, body.blurhash.clone()).await?;
|
||||
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_joined_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(|room_id| {
|
||||
Ok::<_, Error>((
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomMember,
|
||||
content: to_raw_value(&RoomMemberEventContent {
|
||||
avatar_url: body.avatar_url.clone(),
|
||||
..serde_json::from_str(
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
sender_user.as_str(),
|
||||
)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database(
|
||||
"Tried to send displayname update for user not in the \
|
||||
room.",
|
||||
)
|
||||
})?
|
||||
.content
|
||||
.get(),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
|
||||
})
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some(sender_user.to_string()),
|
||||
redacts: None,
|
||||
},
|
||||
room_id,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect();
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_joined_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(|room_id| {
|
||||
Ok::<_, Error>((
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomMember,
|
||||
content: to_raw_value(&RoomMemberEventContent {
|
||||
avatar_url: body.avatar_url.clone(),
|
||||
..serde_json::from_str(
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("Tried to send displayname update for user not in the room.")
|
||||
})?
|
||||
.content
|
||||
.get(),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
|
||||
})
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some(sender_user.to_string()),
|
||||
redacts: None,
|
||||
},
|
||||
room_id,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect();
|
||||
|
||||
for (pdu_builder, room_id) in all_joined_rooms {
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
for (pdu_builder, room_id) in all_joined_rooms {
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let _ = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
||||
.await;
|
||||
}
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
|
||||
}
|
||||
|
||||
if services().globals.allow_local_presence() {
|
||||
// Presence update
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.ping_presence(sender_user, PresenceState::Online)?;
|
||||
}
|
||||
if services().globals.allow_local_presence() {
|
||||
// Presence update
|
||||
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
|
||||
}
|
||||
|
||||
Ok(set_avatar_url::v3::Response {})
|
||||
Ok(set_avatar_url::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/profile/{userId}/avatar_url`
|
||||
@@ -264,58 +196,45 @@ pub async fn set_avatar_url_route(
|
||||
///
|
||||
/// - If user is on another server and we do not have a local copy already
|
||||
/// fetch avatar_url and blurhash over federation
|
||||
pub async fn get_avatar_url_route(
|
||||
body: Ruma<get_avatar_url::v3::Request>,
|
||||
) -> Result<get_avatar_url::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: body.user_id.clone(),
|
||||
field: None, // we want the full user's profile to update locally as well
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
pub async fn get_avatar_url_route(body: Ruma<get_avatar_url::v3::Request>) -> Result<get_avatar_url::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: body.user_id.clone(),
|
||||
field: None, // we want the full user's profile to update locally as well
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&body.user_id, response.displayname.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||
|
||||
return Ok(get_avatar_url::v3::Response {
|
||||
avatar_url: response.avatar_url,
|
||||
blurhash: response.blurhash,
|
||||
});
|
||||
}
|
||||
}
|
||||
return Ok(get_avatar_url::v3::Response {
|
||||
avatar_url: response.avatar_url,
|
||||
blurhash: response.blurhash,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Profile was not found.",
|
||||
));
|
||||
}
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||
// federation
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||
}
|
||||
|
||||
Ok(get_avatar_url::v3::Response {
|
||||
avatar_url: services().users.avatar_url(&body.user_id)?,
|
||||
blurhash: services().users.blurhash(&body.user_id)?,
|
||||
})
|
||||
Ok(get_avatar_url::v3::Response {
|
||||
avatar_url: services().users.avatar_url(&body.user_id)?,
|
||||
blurhash: services().users.blurhash(&body.user_id)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/profile/{userId}`
|
||||
@@ -324,58 +243,45 @@ pub async fn get_avatar_url_route(
|
||||
///
|
||||
/// - If user is on another server and we do not have a local copy already,
|
||||
/// fetch profile over federation.
|
||||
pub async fn get_profile_route(
|
||||
body: Ruma<get_profile::v3::Request>,
|
||||
) -> Result<get_profile::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: body.user_id.clone(),
|
||||
field: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
pub async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<get_profile::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: body.user_id.clone(),
|
||||
field: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&body.user_id, response.displayname.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||
|
||||
return Ok(get_profile::v3::Response {
|
||||
displayname: response.displayname,
|
||||
avatar_url: response.avatar_url,
|
||||
blurhash: response.blurhash,
|
||||
});
|
||||
}
|
||||
}
|
||||
return Ok(get_profile::v3::Response {
|
||||
displayname: response.displayname,
|
||||
avatar_url: response.avatar_url,
|
||||
blurhash: response.blurhash,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Profile was not found.",
|
||||
));
|
||||
}
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||
// federation
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||
}
|
||||
|
||||
Ok(get_profile::v3::Response {
|
||||
avatar_url: services().users.avatar_url(&body.user_id)?,
|
||||
blurhash: services().users.blurhash(&body.user_id)?,
|
||||
displayname: services().users.displayname(&body.user_id)?,
|
||||
})
|
||||
Ok(get_profile::v3::Response {
|
||||
avatar_url: services().users.avatar_url(&body.user_id)?,
|
||||
blurhash: services().users.blurhash(&body.user_id)?,
|
||||
displayname: services().users.displayname(&body.user_id)?,
|
||||
})
|
||||
}
|
||||
|
||||
+221
-322
@@ -1,417 +1,320 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
push::{
|
||||
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled,
|
||||
get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions,
|
||||
set_pushrule_enabled, RuleScope,
|
||||
},
|
||||
},
|
||||
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
|
||||
push::{InsertPushRuleError, RemovePushRuleError},
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
push::{
|
||||
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all,
|
||||
set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope,
|
||||
},
|
||||
},
|
||||
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
|
||||
push::{InsertPushRuleError, RemovePushRuleError},
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/pushrules`
|
||||
///
|
||||
/// Retrieves the push rules event for this user.
|
||||
pub async fn get_pushrules_all_route(
|
||||
body: Ruma<get_pushrules_all::v3::Request>,
|
||||
body: Ruma<get_pushrules_all::v3::Request>,
|
||||
) -> Result<get_pushrules_all::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
Ok(get_pushrules_all::v3::Response {
|
||||
global: account_data.global,
|
||||
})
|
||||
Ok(get_pushrules_all::v3::Response {
|
||||
global: account_data.global,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||
///
|
||||
/// Retrieves a single specified push rule for this user.
|
||||
pub async fn get_pushrule_route(
|
||||
body: Ruma<get_pushrule::v3::Request>,
|
||||
) -> Result<get_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result<get_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
let rule = account_data
|
||||
.global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(Into::into);
|
||||
let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into);
|
||||
|
||||
if let Some(rule) = rule {
|
||||
Ok(get_pushrule::v3::Response { rule })
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
))
|
||||
}
|
||||
if let Some(rule) = rule {
|
||||
Ok(get_pushrule::v3::Response {
|
||||
rule,
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))
|
||||
}
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||
///
|
||||
/// Creates a single specified push rule for this user.
|
||||
pub async fn set_pushrule_route(
|
||||
body: Ruma<set_pushrule::v3::Request>,
|
||||
) -> Result<set_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
pub async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result<set_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if let Err(error) = account_data.content.global.insert(
|
||||
body.rule.clone(),
|
||||
body.after.as_deref(),
|
||||
body.before.as_deref(),
|
||||
) {
|
||||
let err = match error {
|
||||
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule IDs starting with a dot are reserved for server-default rules.",
|
||||
),
|
||||
InsertPushRuleError::InvalidRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule ID containing invalid characters.",
|
||||
),
|
||||
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Can't place a push rule relatively to a server-default rule.",
|
||||
),
|
||||
InsertPushRuleError::UnknownRuleId => Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"The before or after rule could not be found.",
|
||||
),
|
||||
InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"The before rule has a higher priority than the after rule.",
|
||||
),
|
||||
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
if let Err(error) =
|
||||
account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref())
|
||||
{
|
||||
let err = match error {
|
||||
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule IDs starting with a dot are reserved for server-default rules.",
|
||||
),
|
||||
InsertPushRuleError::InvalidRuleId => {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Rule ID containing invalid characters.")
|
||||
},
|
||||
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Can't place a push rule relatively to a server-default rule.",
|
||||
),
|
||||
InsertPushRuleError::UnknownRuleId => {
|
||||
Error::BadRequest(ErrorKind::NotFound, "The before or after rule could not be found.")
|
||||
},
|
||||
InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"The before rule has a higher priority than the after rule.",
|
||||
),
|
||||
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(set_pushrule::v3::Response {})
|
||||
Ok(set_pushrule::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
|
||||
///
|
||||
/// Gets the actions of a single specified push rule for this user.
|
||||
pub async fn get_pushrule_actions_route(
|
||||
body: Ruma<get_pushrule_actions::v3::Request>,
|
||||
body: Ruma<get_pushrule_actions::v3::Request>,
|
||||
) -> Result<get_pushrule_actions::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
let global = account_data.global;
|
||||
let actions = global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(|rule| rule.actions().to_owned())
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
))?;
|
||||
let global = account_data.global;
|
||||
let actions = global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(|rule| rule.actions().to_owned())
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
|
||||
|
||||
Ok(get_pushrule_actions::v3::Response { actions })
|
||||
Ok(get_pushrule_actions::v3::Response {
|
||||
actions,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
|
||||
///
|
||||
/// Sets the actions of a single specified push rule for this user.
|
||||
pub async fn set_pushrule_actions_route(
|
||||
body: Ruma<set_pushrule_actions::v3::Request>,
|
||||
body: Ruma<set_pushrule_actions::v3::Request>,
|
||||
) -> Result<set_pushrule_actions::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if account_data
|
||||
.content
|
||||
.global
|
||||
.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone())
|
||||
.is_err()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
));
|
||||
}
|
||||
if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
|
||||
}
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(set_pushrule_actions::v3::Response {})
|
||||
Ok(set_pushrule_actions::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
|
||||
///
|
||||
/// Gets the enabled status of a single specified push rule for this user.
|
||||
pub async fn get_pushrule_enabled_route(
|
||||
body: Ruma<get_pushrule_enabled::v3::Request>,
|
||||
body: Ruma<get_pushrule_enabled::v3::Request>,
|
||||
) -> Result<get_pushrule_enabled::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
let global = account_data.content.global;
|
||||
let enabled = global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(ruma::push::AnyPushRuleRef::enabled)
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
))?;
|
||||
let global = account_data.content.global;
|
||||
let enabled = global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(ruma::push::AnyPushRuleRef::enabled)
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
|
||||
|
||||
Ok(get_pushrule_enabled::v3::Response { enabled })
|
||||
Ok(get_pushrule_enabled::v3::Response {
|
||||
enabled,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
|
||||
///
|
||||
/// Sets the enabled status of a single specified push rule for this user.
|
||||
pub async fn set_pushrule_enabled_route(
|
||||
body: Ruma<set_pushrule_enabled::v3::Request>,
|
||||
body: Ruma<set_pushrule_enabled::v3::Request>,
|
||||
) -> Result<set_pushrule_enabled::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if account_data
|
||||
.content
|
||||
.global
|
||||
.set_enabled(body.kind.clone(), &body.rule_id, body.enabled)
|
||||
.is_err()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
));
|
||||
}
|
||||
if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
|
||||
}
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(set_pushrule_enabled::v3::Response {})
|
||||
Ok(set_pushrule_enabled::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||
///
|
||||
/// Deletes a single specified push rule for this user.
|
||||
pub async fn delete_pushrule_route(
|
||||
body: Ruma<delete_pushrule::v3::Request>,
|
||||
) -> Result<delete_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn delete_pushrule_route(body: Ruma<delete_pushrule::v3::Request>) -> Result<delete_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
if body.scope != RuleScope::Global {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Scopes other than 'global' are not supported.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if let Err(error) = account_data
|
||||
.content
|
||||
.global
|
||||
.remove(body.kind.clone(), &body.rule_id)
|
||||
{
|
||||
let err = match error {
|
||||
RemovePushRuleError::ServerDefault => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Cannot delete a server-default pushrule.",
|
||||
),
|
||||
RemovePushRuleError::NotFound => {
|
||||
Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")
|
||||
}
|
||||
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
if let Err(error) = account_data.content.global.remove(body.kind.clone(), &body.rule_id) {
|
||||
let err = match error {
|
||||
RemovePushRuleError::ServerDefault => {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.")
|
||||
},
|
||||
RemovePushRuleError::NotFound => Error::BadRequest(ErrorKind::NotFound, "Push rule not found."),
|
||||
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(delete_pushrule::v3::Response {})
|
||||
Ok(delete_pushrule::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/pushers`
|
||||
///
|
||||
/// Gets all currently active pushers for the sender user.
|
||||
pub async fn get_pushers_route(
|
||||
body: Ruma<get_pushers::v3::Request>,
|
||||
) -> Result<get_pushers::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<get_pushers::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
Ok(get_pushers::v3::Response {
|
||||
pushers: services().pusher.get_pushers(sender_user)?,
|
||||
})
|
||||
Ok(get_pushers::v3::Response {
|
||||
pushers: services().pusher.get_pushers(sender_user)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/pushers/set`
|
||||
@@ -419,14 +322,10 @@ pub async fn get_pushers_route(
|
||||
/// Adds a pusher for the sender user.
|
||||
///
|
||||
/// - TODO: Handle `append`
|
||||
pub async fn set_pushers_route(
|
||||
body: Ruma<set_pusher::v3::Request>,
|
||||
) -> Result<set_pusher::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.pusher
|
||||
.set_pusher(sender_user, body.action.clone())?;
|
||||
services().pusher.set_pusher(sender_user, body.action.clone())?;
|
||||
|
||||
Ok(set_pusher::v3::Response::default())
|
||||
Ok(set_pusher::v3::Response::default())
|
||||
}
|
||||
|
||||
@@ -1,182 +1,161 @@
|
||||
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
|
||||
events::{
|
||||
receipt::{ReceiptThread, ReceiptType},
|
||||
RoomAccountDataEventType,
|
||||
},
|
||||
MilliSecondsSinceUnixEpoch,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
|
||||
events::{
|
||||
receipt::{ReceiptThread, ReceiptType},
|
||||
RoomAccountDataEventType,
|
||||
},
|
||||
MilliSecondsSinceUnixEpoch,
|
||||
};
|
||||
|
||||
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
|
||||
///
|
||||
/// Sets different types of read markers.
|
||||
///
|
||||
/// - Updates fully-read account data event to `fully_read`
|
||||
/// - If `read_receipt` is set: Update private marker and public read receipt EDU
|
||||
pub async fn set_read_marker_route(
|
||||
body: Ruma<set_read_marker::v3::Request>,
|
||||
) -> Result<set_read_marker::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
/// - If `read_receipt` is set: Update private marker and public read receipt
|
||||
/// EDU
|
||||
pub async fn set_read_marker_route(body: Ruma<set_read_marker::v3::Request>) -> Result<set_read_marker::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if let Some(fully_read) = &body.fully_read {
|
||||
let fully_read_event = ruma::events::fully_read::FullyReadEvent {
|
||||
content: ruma::events::fully_read::FullyReadEventContent {
|
||||
event_id: fully_read.clone(),
|
||||
},
|
||||
};
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::FullyRead,
|
||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||
)?;
|
||||
}
|
||||
if let Some(fully_read) = &body.fully_read {
|
||||
let fully_read_event = ruma::events::fully_read::FullyReadEvent {
|
||||
content: ruma::events::fully_read::FullyReadEventContent {
|
||||
event_id: fully_read.clone(),
|
||||
},
|
||||
};
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::FullyRead,
|
||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
}
|
||||
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
|
||||
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
}
|
||||
|
||||
if let Some(event) = &body.private_read_receipt {
|
||||
let count = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(event)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?;
|
||||
let count = match count {
|
||||
PduCount::Backfilled(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Read receipt is in backfilled timeline",
|
||||
))
|
||||
}
|
||||
PduCount::Normal(c) => c,
|
||||
};
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.read_receipt
|
||||
.private_read_set(&body.room_id, sender_user, count)?;
|
||||
}
|
||||
if let Some(event) = &body.private_read_receipt {
|
||||
let count = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(event)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
|
||||
let count = match count {
|
||||
PduCount::Backfilled(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Read receipt is in backfilled timeline",
|
||||
))
|
||||
},
|
||||
PduCount::Normal(c) => c,
|
||||
};
|
||||
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
|
||||
}
|
||||
|
||||
if let Some(event) = &body.read_receipt {
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
user_receipts.insert(
|
||||
sender_user.clone(),
|
||||
ruma::events::receipt::Receipt {
|
||||
ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
thread: ReceiptThread::Unthreaded,
|
||||
},
|
||||
);
|
||||
if let Some(event) = &body.read_receipt {
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
user_receipts.insert(
|
||||
sender_user.clone(),
|
||||
ruma::events::receipt::Receipt {
|
||||
ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
thread: ReceiptThread::Unthreaded,
|
||||
},
|
||||
);
|
||||
|
||||
let mut receipts = BTreeMap::new();
|
||||
receipts.insert(ReceiptType::Read, user_receipts);
|
||||
let mut receipts = BTreeMap::new();
|
||||
receipts.insert(ReceiptType::Read, user_receipts);
|
||||
|
||||
let mut receipt_content = BTreeMap::new();
|
||||
receipt_content.insert(event.to_owned(), receipts);
|
||||
let mut receipt_content = BTreeMap::new();
|
||||
receipt_content.insert(event.to_owned(), receipts);
|
||||
|
||||
services().rooms.edus.read_receipt.readreceipt_update(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
ruma::events::receipt::ReceiptEvent {
|
||||
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
|
||||
room_id: body.room_id.clone(),
|
||||
},
|
||||
)?;
|
||||
}
|
||||
services().rooms.edus.read_receipt.readreceipt_update(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
ruma::events::receipt::ReceiptEvent {
|
||||
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
|
||||
room_id: body.room_id.clone(),
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(set_read_marker::v3::Response {})
|
||||
Ok(set_read_marker::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}`
|
||||
///
|
||||
/// Sets private read marker and public read receipt EDU.
|
||||
pub async fn create_receipt_route(
|
||||
body: Ruma<create_receipt::v3::Request>,
|
||||
) -> Result<create_receipt::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn create_receipt_route(body: Ruma<create_receipt::v3::Request>) -> Result<create_receipt::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if matches!(
|
||||
&body.receipt_type,
|
||||
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
|
||||
) {
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
}
|
||||
if matches!(
|
||||
&body.receipt_type,
|
||||
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
|
||||
) {
|
||||
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
}
|
||||
|
||||
match body.receipt_type {
|
||||
create_receipt::v3::ReceiptType::FullyRead => {
|
||||
let fully_read_event = ruma::events::fully_read::FullyReadEvent {
|
||||
content: ruma::events::fully_read::FullyReadEventContent {
|
||||
event_id: body.event_id.clone(),
|
||||
},
|
||||
};
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::FullyRead,
|
||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||
)?;
|
||||
}
|
||||
create_receipt::v3::ReceiptType::Read => {
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
user_receipts.insert(
|
||||
sender_user.clone(),
|
||||
ruma::events::receipt::Receipt {
|
||||
ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
thread: ReceiptThread::Unthreaded,
|
||||
},
|
||||
);
|
||||
let mut receipts = BTreeMap::new();
|
||||
receipts.insert(ReceiptType::Read, user_receipts);
|
||||
match body.receipt_type {
|
||||
create_receipt::v3::ReceiptType::FullyRead => {
|
||||
let fully_read_event = ruma::events::fully_read::FullyReadEvent {
|
||||
content: ruma::events::fully_read::FullyReadEventContent {
|
||||
event_id: body.event_id.clone(),
|
||||
},
|
||||
};
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::FullyRead,
|
||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||
)?;
|
||||
},
|
||||
create_receipt::v3::ReceiptType::Read => {
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
user_receipts.insert(
|
||||
sender_user.clone(),
|
||||
ruma::events::receipt::Receipt {
|
||||
ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
thread: ReceiptThread::Unthreaded,
|
||||
},
|
||||
);
|
||||
let mut receipts = BTreeMap::new();
|
||||
receipts.insert(ReceiptType::Read, user_receipts);
|
||||
|
||||
let mut receipt_content = BTreeMap::new();
|
||||
receipt_content.insert(body.event_id.clone(), receipts);
|
||||
let mut receipt_content = BTreeMap::new();
|
||||
receipt_content.insert(body.event_id.clone(), receipts);
|
||||
|
||||
services().rooms.edus.read_receipt.readreceipt_update(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
ruma::events::receipt::ReceiptEvent {
|
||||
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
|
||||
room_id: body.room_id.clone(),
|
||||
},
|
||||
)?;
|
||||
}
|
||||
create_receipt::v3::ReceiptType::ReadPrivate => {
|
||||
let count = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?;
|
||||
let count = match count {
|
||||
PduCount::Backfilled(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Read receipt is in backfilled timeline",
|
||||
))
|
||||
}
|
||||
PduCount::Normal(c) => c,
|
||||
};
|
||||
services().rooms.edus.read_receipt.private_read_set(
|
||||
&body.room_id,
|
||||
sender_user,
|
||||
count,
|
||||
)?;
|
||||
}
|
||||
_ => return Err(Error::bad_database("Unsupported receipt type")),
|
||||
}
|
||||
services().rooms.edus.read_receipt.readreceipt_update(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
ruma::events::receipt::ReceiptEvent {
|
||||
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
|
||||
room_id: body.room_id.clone(),
|
||||
},
|
||||
)?;
|
||||
},
|
||||
create_receipt::v3::ReceiptType::ReadPrivate => {
|
||||
let count = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
|
||||
let count = match count {
|
||||
PduCount::Backfilled(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Read receipt is in backfilled timeline",
|
||||
))
|
||||
},
|
||||
PduCount::Normal(c) => c,
|
||||
};
|
||||
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
|
||||
},
|
||||
_ => return Err(Error::bad_database("Unsupported receipt type")),
|
||||
}
|
||||
|
||||
Ok(create_receipt::v3::Response {})
|
||||
Ok(create_receipt::v3::Response {})
|
||||
}
|
||||
|
||||
@@ -1,58 +1,51 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::redact::redact_event,
|
||||
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
|
||||
api::client::redact::redact_event,
|
||||
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
|
||||
};
|
||||
|
||||
use serde_json::value::to_raw_value;
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
|
||||
///
|
||||
/// Tries to send a redaction event into the room.
|
||||
///
|
||||
/// - TODO: Handle txn id
|
||||
pub async fn redact_event_route(
|
||||
body: Ruma<redact_event::v3::Request>,
|
||||
) -> Result<redact_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
pub async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result<redact_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(body.room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomRedaction,
|
||||
content: to_raw_value(&RoomRedactionEventContent {
|
||||
redacts: Some(body.event_id.clone()),
|
||||
reason: body.reason.clone(),
|
||||
})
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: None,
|
||||
redacts: Some(body.event_id.into()),
|
||||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&state_lock,
|
||||
)
|
||||
.await?;
|
||||
let event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomRedaction,
|
||||
content: to_raw_value(&RoomRedactionEventContent {
|
||||
redacts: Some(body.event_id.clone()),
|
||||
reason: body.reason.clone(),
|
||||
})
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: None,
|
||||
redacts: Some(body.event_id.into()),
|
||||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&state_lock,
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(state_lock);
|
||||
drop(state_lock);
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(redact_event::v3::Response { event_id })
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(redact_event::v3::Response {
|
||||
event_id,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,146 +1,113 @@
|
||||
use ruma::api::client::relations::{
|
||||
get_relating_events, get_relating_events_with_rel_type,
|
||||
get_relating_events_with_rel_type_and_event_type,
|
||||
get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
|
||||
};
|
||||
|
||||
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
|
||||
pub async fn get_relating_events_with_rel_type_and_event_type_route(
|
||||
body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
|
||||
body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
|
||||
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match ruma::api::Direction::Backward {
|
||||
// TODO: fix ruma so `body.dir` exists
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match ruma::api::Direction::Backward {
|
||||
// TODO: fix ruma so `body.dir` exists
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|u| u32::try_from(u).ok())
|
||||
.map_or(10_usize, |u| u as usize)
|
||||
.min(100);
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||
|
||||
let res = services()
|
||||
.rooms
|
||||
.pdu_metadata
|
||||
.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
Some(body.event_type.clone()),
|
||||
Some(body.rel_type.clone()),
|
||||
from,
|
||||
to,
|
||||
limit,
|
||||
)?;
|
||||
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
Some(body.event_type.clone()),
|
||||
Some(body.rel_type.clone()),
|
||||
from,
|
||||
to,
|
||||
limit,
|
||||
)?;
|
||||
|
||||
Ok(
|
||||
get_relating_events_with_rel_type_and_event_type::v1::Response {
|
||||
chunk: res.chunk,
|
||||
next_batch: res.next_batch,
|
||||
prev_batch: res.prev_batch,
|
||||
},
|
||||
)
|
||||
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
|
||||
chunk: res.chunk,
|
||||
next_batch: res.next_batch,
|
||||
prev_batch: res.prev_batch,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
|
||||
pub async fn get_relating_events_with_rel_type_route(
|
||||
body: Ruma<get_relating_events_with_rel_type::v1::Request>,
|
||||
body: Ruma<get_relating_events_with_rel_type::v1::Request>,
|
||||
) -> Result<get_relating_events_with_rel_type::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match ruma::api::Direction::Backward {
|
||||
// TODO: fix ruma so `body.dir` exists
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match ruma::api::Direction::Backward {
|
||||
// TODO: fix ruma so `body.dir` exists
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|u| u32::try_from(u).ok())
|
||||
.map_or(10_usize, |u| u as usize)
|
||||
.min(100);
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||
|
||||
let res = services()
|
||||
.rooms
|
||||
.pdu_metadata
|
||||
.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
None,
|
||||
Some(body.rel_type.clone()),
|
||||
from,
|
||||
to,
|
||||
limit,
|
||||
)?;
|
||||
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
None,
|
||||
Some(body.rel_type.clone()),
|
||||
from,
|
||||
to,
|
||||
limit,
|
||||
)?;
|
||||
|
||||
Ok(get_relating_events_with_rel_type::v1::Response {
|
||||
chunk: res.chunk,
|
||||
next_batch: res.next_batch,
|
||||
prev_batch: res.prev_batch,
|
||||
})
|
||||
Ok(get_relating_events_with_rel_type::v1::Response {
|
||||
chunk: res.chunk,
|
||||
next_batch: res.next_batch,
|
||||
prev_batch: res.prev_batch,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}`
|
||||
pub async fn get_relating_events_route(
|
||||
body: Ruma<get_relating_events::v1::Request>,
|
||||
body: Ruma<get_relating_events::v1::Request>,
|
||||
) -> Result<get_relating_events::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match ruma::api::Direction::Backward {
|
||||
// TODO: fix ruma so `body.dir` exists
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
let from = match body.from.clone() {
|
||||
Some(from) => PduCount::try_from_string(&from)?,
|
||||
None => match ruma::api::Direction::Backward {
|
||||
// TODO: fix ruma so `body.dir` exists
|
||||
ruma::api::Direction::Forward => PduCount::min(),
|
||||
ruma::api::Direction::Backward => PduCount::max(),
|
||||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|u| u32::try_from(u).ok())
|
||||
.map_or(10_usize, |u| u as usize)
|
||||
.min(100);
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||
|
||||
services()
|
||||
.rooms
|
||||
.pdu_metadata
|
||||
.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
None,
|
||||
None,
|
||||
from,
|
||||
to,
|
||||
limit,
|
||||
)
|
||||
services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
None,
|
||||
None,
|
||||
from,
|
||||
to,
|
||||
limit,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,118 +1,112 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
|
||||
use rand::Rng;
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, room::report_content},
|
||||
events::room::message,
|
||||
int,
|
||||
api::client::{error::ErrorKind, room::report_content},
|
||||
events::room::message,
|
||||
int,
|
||||
};
|
||||
use tokio::time::sleep;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}`
|
||||
///
|
||||
/// Reports an inappropriate event to homeserver admins
|
||||
///
|
||||
pub async fn report_event_route(
|
||||
body: Ruma<report_content::v3::Request>,
|
||||
) -> Result<report_content::v3::Response> {
|
||||
// user authentication
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn report_event_route(body: Ruma<report_content::v3::Request>) -> Result<report_content::v3::Response> {
|
||||
// user authentication
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
info!("Received /report request by user {}", sender_user);
|
||||
info!("Received /report request by user {}", sender_user);
|
||||
|
||||
// check if we know about the reported event ID or if it's invalid
|
||||
let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? {
|
||||
Some(pdu) => pdu,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Event ID is not known to us or Event ID is invalid",
|
||||
))
|
||||
}
|
||||
};
|
||||
// check if we know about the reported event ID or if it's invalid
|
||||
let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? {
|
||||
Some(pdu) => pdu,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Event ID is not known to us or Event ID is invalid",
|
||||
))
|
||||
},
|
||||
};
|
||||
|
||||
// check if the room ID from the URI matches the PDU's room ID
|
||||
if body.room_id != pdu.room_id {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Event ID does not belong to the reported room",
|
||||
));
|
||||
}
|
||||
// check if the room ID from the URI matches the PDU's room ID
|
||||
if body.room_id != pdu.room_id {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Event ID does not belong to the reported room",
|
||||
));
|
||||
}
|
||||
|
||||
// check if reporting user is in the reporting room
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_members(&pdu.room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.any(|user_id| user_id == *sender_user)
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"You are not in the room you are reporting.",
|
||||
));
|
||||
}
|
||||
// check if reporting user is in the reporting room
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_members(&pdu.room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.any(|user_id| user_id == *sender_user)
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"You are not in the room you are reporting.",
|
||||
));
|
||||
}
|
||||
|
||||
// check if score is in valid range
|
||||
if let Some(true) = body.score.map(|s| s > int!(0) || s < int!(-100)) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid score, must be within 0 to -100",
|
||||
));
|
||||
};
|
||||
// check if score is in valid range
|
||||
if let Some(true) = body.score.map(|s| s > int!(0) || s < int!(-100)) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid score, must be within 0 to -100",
|
||||
));
|
||||
};
|
||||
|
||||
// check if report reasoning is less than or equal to 750 characters
|
||||
if let Some(true) = body.reason.clone().map(|s| s.chars().count() >= 750) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Reason too long, should be 750 characters or fewer",
|
||||
));
|
||||
};
|
||||
// check if report reasoning is less than or equal to 750 characters
|
||||
if let Some(true) = body.reason.clone().map(|s| s.chars().count() >= 750) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Reason too long, should be 750 characters or fewer",
|
||||
));
|
||||
};
|
||||
|
||||
// send admin room message that we received the report with an @room ping for urgency
|
||||
services()
|
||||
.admin
|
||||
.send_message(message::RoomMessageEventContent::text_html(
|
||||
format!(
|
||||
"@room Report received from: {}\n\n\
|
||||
Event ID: {}\n\
|
||||
Room ID: {}\n\
|
||||
Sent By: {}\n\n\
|
||||
Report Score: {}\n\
|
||||
Report Reason: {}",
|
||||
sender_user.to_owned(),
|
||||
pdu.event_id,
|
||||
pdu.room_id,
|
||||
pdu.sender.clone(),
|
||||
body.score.unwrap_or_else(|| ruma::Int::from(0)),
|
||||
body.reason.as_deref().unwrap_or("")
|
||||
),
|
||||
format!(
|
||||
"<details><summary>@room Report received from: <a href=\"https://matrix.to/#/{0}\">{0}\
|
||||
// send admin room message that we received the report with an @room ping for
|
||||
// urgency
|
||||
services().admin.send_message(message::RoomMessageEventContent::text_html(
|
||||
format!(
|
||||
"@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \
|
||||
Reason: {}",
|
||||
sender_user.to_owned(),
|
||||
pdu.event_id,
|
||||
pdu.room_id,
|
||||
pdu.sender.clone(),
|
||||
body.score.unwrap_or_else(|| ruma::Int::from(0)),
|
||||
body.reason.as_deref().unwrap_or("")
|
||||
),
|
||||
format!(
|
||||
"<details><summary>@room Report received from: <a href=\"https://matrix.to/#/{0}\">{0}\
|
||||
</a></summary><ul><li>Event Info<ul><li>Event ID: <code>{1}</code>\
|
||||
<a href=\"https://matrix.to/#/{2}/{1}\">🔗</a></li><li>Room ID: <code>{2}</code>\
|
||||
</li><li>Sent By: <a href=\"https://matrix.to/#/{3}\">{3}</a></li></ul></li><li>\
|
||||
Report Info<ul><li>Report Score: {4}</li><li>Report Reason: {5}</li></ul></li>\
|
||||
</ul></details>",
|
||||
sender_user.to_owned(),
|
||||
pdu.event_id.clone(),
|
||||
pdu.room_id.clone(),
|
||||
pdu.sender.clone(),
|
||||
body.score.unwrap_or_else(|| ruma::Int::from(0)),
|
||||
HtmlEscape(body.reason.as_deref().unwrap_or(""))
|
||||
),
|
||||
));
|
||||
sender_user.to_owned(),
|
||||
pdu.event_id.clone(),
|
||||
pdu.room_id.clone(),
|
||||
pdu.sender.clone(),
|
||||
body.score.unwrap_or_else(|| ruma::Int::from(0)),
|
||||
HtmlEscape(body.reason.as_deref().unwrap_or(""))
|
||||
),
|
||||
));
|
||||
|
||||
// even though this is kinda security by obscurity, let's still make a small random delay sending a successful response
|
||||
// per spec suggestion regarding enumerating for potential events existing in our server.
|
||||
let time_to_wait = rand::thread_rng().gen_range(8..21);
|
||||
debug!(
|
||||
"Got successful /report request, waiting {} seconds before sending successful response.",
|
||||
time_to_wait
|
||||
);
|
||||
sleep(Duration::from_secs(time_to_wait)).await;
|
||||
// even though this is kinda security by obscurity, let's still make a small
|
||||
// random delay sending a successful response per spec suggestion regarding
|
||||
// enumerating for potential events existing in our server.
|
||||
let time_to_wait = rand::thread_rng().gen_range(8..21);
|
||||
debug!(
|
||||
"Got successful /report request, waiting {} seconds before sending successful response.",
|
||||
time_to_wait
|
||||
);
|
||||
sleep(Duration::from_secs(time_to_wait)).await;
|
||||
|
||||
Ok(report_content::v3::Response {})
|
||||
Ok(report_content::v3::Response {})
|
||||
}
|
||||
|
||||
+730
-862
File diff suppressed because it is too large
Load Diff
+100
-118
@@ -1,138 +1,120 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
search::search_events::{
|
||||
self,
|
||||
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
|
||||
},
|
||||
error::ErrorKind,
|
||||
search::search_events::{
|
||||
self,
|
||||
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
|
||||
},
|
||||
};
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/search`
|
||||
///
|
||||
/// Searches rooms for messages.
|
||||
///
|
||||
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
|
||||
pub async fn search_events_route(
|
||||
body: Ruma<search_events::v3::Request>,
|
||||
) -> Result<search_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
/// - Only works if the user is currently joined to the room (TODO: Respect
|
||||
/// history visibility)
|
||||
pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Result<search_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let search_criteria = body.search_categories.room_events.as_ref().unwrap();
|
||||
let filter = &search_criteria.filter;
|
||||
let search_criteria = body.search_categories.room_events.as_ref().unwrap();
|
||||
let filter = &search_criteria.filter;
|
||||
|
||||
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect()
|
||||
});
|
||||
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
|
||||
services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok).collect()
|
||||
});
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = filter.limit.map_or(10, u64::from).min(100) as usize;
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = filter.limit.map_or(10, u64::from).min(100) as usize;
|
||||
|
||||
let mut searches = Vec::new();
|
||||
let mut searches = Vec::new();
|
||||
|
||||
for room_id in room_ids {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
));
|
||||
}
|
||||
for room_id in room_ids {
|
||||
if !services().rooms.state_cache.is_joined(sender_user, &room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(search) = services()
|
||||
.rooms
|
||||
.search
|
||||
.search_pdus(&room_id, &search_criteria.search_term)?
|
||||
{
|
||||
searches.push(search.0.peekable());
|
||||
}
|
||||
}
|
||||
if let Some(search) = services().rooms.search.search_pdus(&room_id, &search_criteria.search_term)? {
|
||||
searches.push(search.0.peekable());
|
||||
}
|
||||
}
|
||||
|
||||
let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
|
||||
Some(Ok(s)) => s,
|
||||
Some(Err(_)) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid next_batch token.",
|
||||
))
|
||||
}
|
||||
None => 0, // Default to the start
|
||||
};
|
||||
let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
|
||||
Some(Ok(s)) => s,
|
||||
Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
|
||||
None => 0, // Default to the start
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..skip + limit {
|
||||
if let Some(s) = searches
|
||||
.iter_mut()
|
||||
.map(|s| (s.peek().cloned(), s))
|
||||
.max_by_key(|(peek, _)| peek.clone())
|
||||
.and_then(|(_, i)| i.next())
|
||||
{
|
||||
results.push(s);
|
||||
}
|
||||
}
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..skip + limit {
|
||||
if let Some(s) = searches
|
||||
.iter_mut()
|
||||
.map(|s| (s.peek().cloned(), s))
|
||||
.max_by_key(|(peek, _)| peek.clone())
|
||||
.and_then(|(_, i)| i.next())
|
||||
{
|
||||
results.push(s);
|
||||
}
|
||||
}
|
||||
|
||||
let results: Vec<_> = results
|
||||
.iter()
|
||||
.filter_map(|result| {
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(result)
|
||||
.ok()?
|
||||
.filter(|pdu| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.map(|pdu| pdu.to_room_event())
|
||||
})
|
||||
.map(|result| {
|
||||
Ok::<_, Error>(SearchResult {
|
||||
context: EventContextResult {
|
||||
end: None,
|
||||
events_after: Vec::new(),
|
||||
events_before: Vec::new(),
|
||||
profile_info: BTreeMap::new(),
|
||||
start: None,
|
||||
},
|
||||
rank: None,
|
||||
result: Some(result),
|
||||
})
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.skip(skip)
|
||||
.take(limit)
|
||||
.collect();
|
||||
let results: Vec<_> = results
|
||||
.iter()
|
||||
.filter_map(|result| {
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(result)
|
||||
.ok()?
|
||||
.filter(|pdu| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.map(|pdu| pdu.to_room_event())
|
||||
})
|
||||
.map(|result| {
|
||||
Ok::<_, Error>(SearchResult {
|
||||
context: EventContextResult {
|
||||
end: None,
|
||||
events_after: Vec::new(),
|
||||
events_before: Vec::new(),
|
||||
profile_info: BTreeMap::new(),
|
||||
start: None,
|
||||
},
|
||||
rank: None,
|
||||
result: Some(result),
|
||||
})
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.skip(skip)
|
||||
.take(limit)
|
||||
.collect();
|
||||
|
||||
let next_batch = if results.len() < limit {
|
||||
None
|
||||
} else {
|
||||
Some((skip + limit).to_string())
|
||||
};
|
||||
let next_batch = if results.len() < limit {
|
||||
None
|
||||
} else {
|
||||
Some((skip + limit).to_string())
|
||||
};
|
||||
|
||||
Ok(search_events::v3::Response::new(ResultCategories {
|
||||
room_events: ResultRoomEvents {
|
||||
count: Some((results.len() as u32).into()), // TODO: set this to none. Element shouldn't depend on it
|
||||
groups: BTreeMap::new(), // TODO
|
||||
next_batch,
|
||||
results,
|
||||
state: BTreeMap::new(), // TODO
|
||||
highlights: search_criteria
|
||||
.search_term
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.map(str::to_lowercase)
|
||||
.collect(),
|
||||
},
|
||||
}))
|
||||
Ok(search_events::v3::Response::new(ResultCategories {
|
||||
room_events: ResultRoomEvents {
|
||||
count: Some((results.len() as u32).into()), // TODO: set this to none. Element shouldn't depend on it
|
||||
groups: BTreeMap::new(), // TODO
|
||||
next_batch,
|
||||
results,
|
||||
state: BTreeMap::new(), // TODO
|
||||
highlights: search_criteria
|
||||
.search_term
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.map(str::to_lowercase)
|
||||
.collect(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
+199
-223
@@ -1,246 +1,221 @@
|
||||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use argon2::{PasswordHash, PasswordVerifier};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
session::{
|
||||
get_login_types::{
|
||||
self,
|
||||
v3::{ApplicationServiceLoginType, PasswordLoginType},
|
||||
},
|
||||
login::{
|
||||
self,
|
||||
v3::{DiscoveryInfo, HomeserverInfo},
|
||||
},
|
||||
logout, logout_all,
|
||||
},
|
||||
uiaa::UserIdentifier,
|
||||
},
|
||||
UserId,
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
session::{
|
||||
get_login_types::{
|
||||
self,
|
||||
v3::{ApplicationServiceLoginType, PasswordLoginType},
|
||||
},
|
||||
login::{
|
||||
self,
|
||||
v3::{DiscoveryInfo, HomeserverInfo},
|
||||
},
|
||||
logout, logout_all,
|
||||
},
|
||||
uiaa::UserIdentifier,
|
||||
},
|
||||
UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
//exp: usize,
|
||||
sub: String,
|
||||
//exp: usize,
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/login`
|
||||
///
|
||||
/// Get the supported login types of this server. One of these should be used as the `type` field
|
||||
/// when logging in.
|
||||
pub async fn get_login_types_route(
|
||||
_body: Ruma<get_login_types::v3::Request>,
|
||||
) -> Result<get_login_types::v3::Response> {
|
||||
Ok(get_login_types::v3::Response::new(vec![
|
||||
get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
|
||||
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
|
||||
]))
|
||||
/// Get the supported login types of this server. One of these should be used as
|
||||
/// the `type` field when logging in.
|
||||
pub async fn get_login_types_route(_body: Ruma<get_login_types::v3::Request>) -> Result<get_login_types::v3::Response> {
|
||||
Ok(get_login_types::v3::Response::new(vec![
|
||||
get_login_types::v3::LoginType::Password(PasswordLoginType::default()),
|
||||
get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()),
|
||||
]))
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/v3/login`
|
||||
///
|
||||
/// Authenticates the user and returns an access token it can use in subsequent requests.
|
||||
/// Authenticates the user and returns an access token it can use in subsequent
|
||||
/// requests.
|
||||
///
|
||||
/// - The user needs to authenticate using their password (or if enabled using a json web token)
|
||||
/// - The user needs to authenticate using their password (or if enabled using a
|
||||
/// json web token)
|
||||
/// - If `device_id` is known: invalidates old access token of that device
|
||||
/// - If `device_id` is unknown: creates a new device
|
||||
/// - Returns access token that is associated with the user and device
|
||||
///
|
||||
/// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
|
||||
/// Note: You can use [`GET
|
||||
/// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
|
||||
/// supported login types.
|
||||
pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> {
|
||||
// Validate login method
|
||||
// TODO: Other login methods
|
||||
let user_id = match &body.login_info {
|
||||
#[allow(deprecated)]
|
||||
login::v3::LoginInfo::Password(login::v3::Password {
|
||||
identifier,
|
||||
password,
|
||||
user,
|
||||
..
|
||||
}) => {
|
||||
debug!("Got password login type");
|
||||
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
||||
debug!("Using username from identifier field");
|
||||
user_id.to_lowercase()
|
||||
} else if let Some(user_id) = user {
|
||||
warn!("User \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id);
|
||||
user_id.to_lowercase()
|
||||
} else {
|
||||
warn!("Bad login type: {:?}", &body.login_info);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
// Validate login method
|
||||
// TODO: Other login methods
|
||||
let user_id = match &body.login_info {
|
||||
#[allow(deprecated)]
|
||||
login::v3::LoginInfo::Password(login::v3::Password {
|
||||
identifier,
|
||||
password,
|
||||
user,
|
||||
..
|
||||
}) => {
|
||||
debug!("Got password login type");
|
||||
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
||||
debug!("Using username from identifier field");
|
||||
user_id.to_lowercase()
|
||||
} else if let Some(user_id) = user {
|
||||
warn!(
|
||||
"User \"{}\" is attempting to login with the deprecated \"user\" field at \
|
||||
\"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
|
||||
destined to be removed in a future Matrix release.",
|
||||
user_id
|
||||
);
|
||||
user_id.to_lowercase()
|
||||
} else {
|
||||
warn!("Bad login type: {:?}", &body.login_info);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username, services().globals.server_name())
|
||||
.map_err(|e| {
|
||||
warn!("Failed to parse username from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?;
|
||||
let user_id = UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||
warn!("Failed to parse username from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?;
|
||||
|
||||
let hash = services()
|
||||
.users
|
||||
.password_hash(&user_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Wrong username or password.",
|
||||
))?;
|
||||
let hash = services()
|
||||
.users
|
||||
.password_hash(&user_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."))?;
|
||||
|
||||
if hash.is_empty() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserDeactivated,
|
||||
"The user has been deactivated",
|
||||
));
|
||||
}
|
||||
if hash.is_empty() {
|
||||
return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated"));
|
||||
}
|
||||
|
||||
let Ok(parsed_hash) = PasswordHash::new(&hash) else {
|
||||
error!("error while hashing user {}", user_id);
|
||||
return Err(Error::BadServerResponse("could not hash"));
|
||||
};
|
||||
let Ok(parsed_hash) = PasswordHash::new(&hash) else {
|
||||
error!("error while hashing user {}", user_id);
|
||||
return Err(Error::BadServerResponse("could not hash"));
|
||||
};
|
||||
|
||||
let hash_matches = services()
|
||||
.globals
|
||||
.argon
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok();
|
||||
let hash_matches = services().globals.argon.verify_password(password.as_bytes(), &parsed_hash).is_ok();
|
||||
|
||||
if !hash_matches {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Wrong username or password.",
|
||||
));
|
||||
}
|
||||
if !hash_matches {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."));
|
||||
}
|
||||
|
||||
user_id
|
||||
}
|
||||
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
|
||||
debug!("Got token login type");
|
||||
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
||||
let token = jsonwebtoken::decode::<Claims>(
|
||||
token,
|
||||
jwt_decoding_key,
|
||||
&jsonwebtoken::Validation::default(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to parse JWT token from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
||||
})?;
|
||||
user_id
|
||||
},
|
||||
login::v3::LoginInfo::Token(login::v3::Token {
|
||||
token,
|
||||
}) => {
|
||||
debug!("Got token login type");
|
||||
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
||||
let token =
|
||||
jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default())
|
||||
.map_err(|e| {
|
||||
warn!("Failed to parse JWT token from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
||||
})?;
|
||||
|
||||
let username = token.claims.sub.to_lowercase();
|
||||
let username = token.claims.sub.to_lowercase();
|
||||
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
||||
|e| {
|
||||
warn!("Failed to parse username from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
},
|
||||
)?
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Token login is not supported (server has no jwt decoding key).",
|
||||
));
|
||||
}
|
||||
}
|
||||
#[allow(deprecated)]
|
||||
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
|
||||
identifier,
|
||||
user,
|
||||
}) => {
|
||||
debug!("Got appservice login type");
|
||||
if !body.from_appservice {
|
||||
info!("User tried logging in as an appservice, but request body is not from a known/registered appservice");
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Forbidden login type.",
|
||||
));
|
||||
};
|
||||
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
||||
user_id.to_lowercase()
|
||||
} else if let Some(user_id) = user {
|
||||
warn!("Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id);
|
||||
user_id.to_lowercase()
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||
warn!("Failed to parse username from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Token login is not supported (server has no jwt decoding key).",
|
||||
));
|
||||
}
|
||||
},
|
||||
#[allow(deprecated)]
|
||||
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
|
||||
identifier,
|
||||
user,
|
||||
}) => {
|
||||
debug!("Got appservice login type");
|
||||
if !body.from_appservice {
|
||||
info!(
|
||||
"User tried logging in as an appservice, but request body is not from a known/registered \
|
||||
appservice"
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Forbidden login type."));
|
||||
};
|
||||
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
||||
user_id.to_lowercase()
|
||||
} else if let Some(user_id) = user {
|
||||
warn!(
|
||||
"Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \
|
||||
\"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
|
||||
destined to be removed in a future Matrix release.",
|
||||
user_id
|
||||
);
|
||||
user_id.to_lowercase()
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
||||
|e| {
|
||||
warn!("Failed to parse username from appservice logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
},
|
||||
)?
|
||||
}
|
||||
_ => {
|
||||
warn!("Unsupported or unknown login type: {:?}", &body.login_info);
|
||||
debug!("JSON body: {:?}", &body.json_body);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unsupported or unknown login type.",
|
||||
));
|
||||
}
|
||||
};
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||
warn!("Failed to parse username from appservice logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?
|
||||
},
|
||||
_ => {
|
||||
warn!("Unsupported or unknown login type: {:?}", &body.login_info);
|
||||
debug!("JSON body: {:?}", &body.json_body);
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported or unknown login type."));
|
||||
},
|
||||
};
|
||||
|
||||
// Generate new device id if the user didn't specify one
|
||||
let device_id = body
|
||||
.device_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
// Generate new device id if the user didn't specify one
|
||||
let device_id = body.device_id.clone().unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
|
||||
// Generate a new token for the device
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
// Generate a new token for the device
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
// Determine if device_id was provided and exists in the db for this user
|
||||
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
||||
services()
|
||||
.users
|
||||
.all_device_ids(&user_id)
|
||||
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
||||
});
|
||||
// Determine if device_id was provided and exists in the db for this user
|
||||
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
||||
services().users.all_device_ids(&user_id).any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
||||
});
|
||||
|
||||
if device_exists {
|
||||
services().users.set_token(&user_id, &device_id, &token)?;
|
||||
} else {
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
body.initial_device_display_name.clone(),
|
||||
)?;
|
||||
}
|
||||
if device_exists {
|
||||
services().users.set_token(&user_id, &device_id, &token)?;
|
||||
} else {
|
||||
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
|
||||
}
|
||||
|
||||
// send client well-known if specified so the client knows to reconfigure itself
|
||||
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
|
||||
services()
|
||||
.globals
|
||||
.well_known_client()
|
||||
.to_owned()
|
||||
.unwrap_or_else(|| "".to_owned()),
|
||||
));
|
||||
// send client well-known if specified so the client knows to reconfigure itself
|
||||
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
|
||||
services().globals.well_known_client().to_owned().unwrap_or_else(|| "".to_owned()),
|
||||
));
|
||||
|
||||
info!("{} logged in", user_id);
|
||||
info!("{} logged in", user_id);
|
||||
|
||||
// home_server is deprecated but apparently must still be sent despite it being deprecated over 6 years ago.
|
||||
// initially i thought this macro was unnecessary, but ruma uses this same macro for the same reason so...
|
||||
#[allow(deprecated)]
|
||||
Ok(login::v3::Response {
|
||||
user_id,
|
||||
access_token: token,
|
||||
device_id,
|
||||
well_known: {
|
||||
if client_discovery_info.homeserver.base_url.as_str() == "" {
|
||||
None
|
||||
} else {
|
||||
Some(client_discovery_info)
|
||||
}
|
||||
},
|
||||
expires_in: None,
|
||||
home_server: Some(services().globals.server_name().to_owned()),
|
||||
refresh_token: None,
|
||||
})
|
||||
// home_server is deprecated but apparently must still be sent despite it being
|
||||
// deprecated over 6 years ago. initially i thought this macro was unnecessary,
|
||||
// but ruma uses this same macro for the same reason so...
|
||||
#[allow(deprecated)]
|
||||
Ok(login::v3::Response {
|
||||
user_id,
|
||||
access_token: token,
|
||||
device_id,
|
||||
well_known: {
|
||||
if client_discovery_info.homeserver.base_url.as_str() == "" {
|
||||
None
|
||||
} else {
|
||||
Some(client_discovery_info)
|
||||
}
|
||||
},
|
||||
expires_in: None,
|
||||
home_server: Some(services().globals.server_name().to_owned()),
|
||||
refresh_token: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/v3/logout`
|
||||
@@ -248,19 +223,20 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||
/// Log out the current device.
|
||||
///
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
services().users.remove_device(sender_user, sender_device)?;
|
||||
services().users.remove_device(sender_user, sender_device)?;
|
||||
|
||||
// send device list update for user after logout
|
||||
services().users.mark_device_key_update(sender_user)?;
|
||||
// send device list update for user after logout
|
||||
services().users.mark_device_key_update(sender_user)?;
|
||||
|
||||
Ok(logout::v3::Response::new())
|
||||
Ok(logout::v3::Response::new())
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/logout/all`
|
||||
@@ -268,23 +244,23 @@ pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3:
|
||||
/// Log out all devices of this user.
|
||||
///
|
||||
/// - Invalidates all access tokens
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets all to-device events
|
||||
/// - Triggers device list updates
|
||||
///
|
||||
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html)
|
||||
/// from each device of this user.
|
||||
pub async fn logout_all_route(
|
||||
body: Ruma<logout_all::v3::Request>,
|
||||
) -> Result<logout_all::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
/// Note: This is equivalent to calling [`GET
|
||||
/// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this
|
||||
/// user.
|
||||
pub async fn logout_all_route(body: Ruma<logout_all::v3::Request>) -> Result<logout_all::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
for device_id in services().users.all_device_ids(sender_user).flatten() {
|
||||
services().users.remove_device(sender_user, &device_id)?;
|
||||
}
|
||||
for device_id in services().users.all_device_ids(sender_user).flatten() {
|
||||
services().users.remove_device(sender_user, &device_id)?;
|
||||
}
|
||||
|
||||
// send device list update for user after logout
|
||||
services().users.mark_device_key_update(sender_user)?;
|
||||
// send device list update for user after logout
|
||||
services().users.mark_device_key_update(sender_user)?;
|
||||
|
||||
Ok(logout_all::v3::Response::new())
|
||||
Ok(logout_all::v3::Response::new())
|
||||
}
|
||||
|
||||
@@ -1,34 +1,19 @@
|
||||
use crate::{services, Result, Ruma};
|
||||
use ruma::api::client::space::get_hierarchy;
|
||||
|
||||
use crate::{services, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy``
|
||||
///
|
||||
/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space.
|
||||
pub async fn get_hierarchy_route(
|
||||
body: Ruma<get_hierarchy::v1::Request>,
|
||||
) -> Result<get_hierarchy::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
/// Paginates over the space tree in a depth-first manner to locate child rooms
|
||||
/// of a given space.
|
||||
pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let skip = body
|
||||
.from
|
||||
.as_ref()
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
let skip = body.from.as_ref().and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);
|
||||
|
||||
let limit = body.limit.map_or(10, u64::from).min(100) as usize;
|
||||
let limit = body.limit.map_or(10, u64::from).min(100) as usize;
|
||||
|
||||
let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself
|
||||
let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself
|
||||
|
||||
services()
|
||||
.rooms
|
||||
.spaces
|
||||
.get_hierarchy(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
limit,
|
||||
skip,
|
||||
max_depth,
|
||||
body.suggested_only,
|
||||
)
|
||||
.await
|
||||
services().rooms.spaces.get_hierarchy(sender_user, &body.room_id, limit, skip, max_depth, body.suggested_only).await
|
||||
}
|
||||
|
||||
+183
-221
@@ -1,42 +1,44 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
state::{get_state_events, get_state_events_for_key, send_state_event},
|
||||
},
|
||||
events::{
|
||||
room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType,
|
||||
},
|
||||
serde::Raw,
|
||||
EventId, RoomId, UserId,
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
state::{get_state_events, get_state_events_for_key, send_state_event},
|
||||
},
|
||||
events::{room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType},
|
||||
serde::Raw,
|
||||
EventId, RoomId, UserId,
|
||||
};
|
||||
use tracing::{error, log::warn};
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
|
||||
///
|
||||
/// Sends a state event into the room.
|
||||
///
|
||||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||
/// allowed
|
||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||
pub async fn send_state_event_for_key_route(
|
||||
body: Ruma<send_state_event::v3::Request>,
|
||||
body: Ruma<send_state_event::v3::Request>,
|
||||
) -> Result<send_state_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_type,
|
||||
&body.body.body, // Yes, I hate it too
|
||||
body.state_key.clone(),
|
||||
)
|
||||
.await?;
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_type,
|
||||
&body.body.body, // Yes, I hate it too
|
||||
body.state_key.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response { event_id })
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response {
|
||||
event_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
|
||||
@@ -44,249 +46,209 @@ pub async fn send_state_event_for_key_route(
|
||||
/// Sends a state event into the room.
|
||||
///
|
||||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||
/// allowed
|
||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||
pub async fn send_state_event_for_empty_key_route(
|
||||
body: Ruma<send_state_event::v3::Request>,
|
||||
body: Ruma<send_state_event::v3::Request>,
|
||||
) -> Result<RumaResponse<send_state_event::v3::Response>> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
// Forbid m.room.encryption if encryption is disabled
|
||||
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Encryption has been disabled",
|
||||
));
|
||||
}
|
||||
// Forbid m.room.encryption if encryption is disabled
|
||||
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
|
||||
}
|
||||
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_type.to_string().into(),
|
||||
&body.body.body,
|
||||
body.state_key.clone(),
|
||||
)
|
||||
.await?;
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_type.to_string().into(),
|
||||
&body.body.body,
|
||||
body.state_key.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response { event_id }.into())
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response {
|
||||
event_id,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomid}/state`
|
||||
///
|
||||
/// Get all state events for a room.
|
||||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
/// - If not joined: Only works if current room history visibility is world
|
||||
/// readable
|
||||
pub async fn get_state_events_route(
|
||||
body: Ruma<get_state_events::v3::Request>,
|
||||
body: Ruma<get_state_events::v3::Request>,
|
||||
) -> Result<get_state_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(get_state_events::v3::Response {
|
||||
room_state: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_full(&body.room_id)
|
||||
.await?
|
||||
.values()
|
||||
.map(|pdu| pdu.to_state_event())
|
||||
.collect(),
|
||||
})
|
||||
Ok(get_state_events::v3::Response {
|
||||
room_state: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_full(&body.room_id)
|
||||
.await?
|
||||
.values()
|
||||
.map(|pdu| pdu.to_state_event())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}`
|
||||
///
|
||||
/// Get single state event of a room with the specified state key.
|
||||
/// The optional query parameter `?format=event|content` allows returning the full room state event
|
||||
/// or just the state event's content (default behaviour)
|
||||
/// The optional query parameter `?format=event|content` allows returning the
|
||||
/// full room state event or just the state event's content (default behaviour)
|
||||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
/// - If not joined: Only works if current room history visibility is world
|
||||
/// readable
|
||||
pub async fn get_state_events_for_key_route(
|
||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||
) -> Result<get_state_events_for_key::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &body.event_type, &body.state_key)?
|
||||
.ok_or_else(|| {
|
||||
warn!(
|
||||
"State event {:?} not found in room {:?}",
|
||||
&body.event_type, &body.room_id
|
||||
);
|
||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||
})?;
|
||||
if body
|
||||
.format
|
||||
.as_ref()
|
||||
.is_some_and(|f| f.to_lowercase().eq("event"))
|
||||
{
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: None,
|
||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||
error!("Invalid room state event in database: {}", e);
|
||||
Error::bad_database("Invalid room state event in database")
|
||||
})?,
|
||||
})
|
||||
} else {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
|
||||
error!("Invalid room state event content in database: {}", e);
|
||||
Error::bad_database("Invalid room state event content in database")
|
||||
})?),
|
||||
event: None,
|
||||
})
|
||||
}
|
||||
let event =
|
||||
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else(
|
||||
|| {
|
||||
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
|
||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||
},
|
||||
)?;
|
||||
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: None,
|
||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||
error!("Invalid room state event in database: {}", e);
|
||||
Error::bad_database("Invalid room state event in database")
|
||||
})?,
|
||||
})
|
||||
} else {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
|
||||
error!("Invalid room state event content in database: {}", e);
|
||||
Error::bad_database("Invalid room state event content in database")
|
||||
})?),
|
||||
event: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}`
|
||||
///
|
||||
/// Get single state event of a room.
|
||||
/// The optional query parameter `?format=event|content` allows returning the full room state event
|
||||
/// or just the state event's content (default behaviour)
|
||||
/// The optional query parameter `?format=event|content` allows returning the
|
||||
/// full room state event or just the state event's content (default behaviour)
|
||||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
/// - If not joined: Only works if current room history visibility is world
|
||||
/// readable
|
||||
pub async fn get_state_events_for_empty_key_route(
|
||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &body.event_type, "")?
|
||||
.ok_or_else(|| {
|
||||
warn!(
|
||||
"State event {:?} not found in room {:?}",
|
||||
&body.event_type, &body.room_id
|
||||
);
|
||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||
})?;
|
||||
let event =
|
||||
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| {
|
||||
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
|
||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||
})?;
|
||||
|
||||
if body
|
||||
.format
|
||||
.as_ref()
|
||||
.is_some_and(|f| f.to_lowercase().eq("event"))
|
||||
{
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: None,
|
||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||
error!("Invalid room state event in database: {}", e);
|
||||
Error::bad_database("Invalid room state event in database")
|
||||
})?,
|
||||
}
|
||||
.into())
|
||||
} else {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
|
||||
error!("Invalid room state event content in database: {}", e);
|
||||
Error::bad_database("Invalid room state event content in database")
|
||||
})?),
|
||||
event: None,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: None,
|
||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||
error!("Invalid room state event in database: {}", e);
|
||||
Error::bad_database("Invalid room state event in database")
|
||||
})?,
|
||||
}
|
||||
.into())
|
||||
} else {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: Some(serde_json::from_str(event.content.get()).map_err(|e| {
|
||||
error!("Invalid room state event content in database: {}", e);
|
||||
Error::bad_database("Invalid room state event content in database")
|
||||
})?),
|
||||
event: None,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_state_event_for_key_helper(
|
||||
sender: &UserId,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
json: &Raw<AnyStateEventContent>,
|
||||
state_key: String,
|
||||
sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String,
|
||||
) -> Result<Arc<EventId>> {
|
||||
let sender_user = sender;
|
||||
let sender_user = sender;
|
||||
|
||||
// TODO: Review this check, error if event is unparsable, use event type, allow alias if it
|
||||
// previously existed
|
||||
if let Ok(canonical_alias) =
|
||||
serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get())
|
||||
{
|
||||
let mut aliases = canonical_alias.alt_aliases.clone();
|
||||
// TODO: Review this check, error if event is unparsable, use event type, allow
|
||||
// alias if it previously existed
|
||||
if let Ok(canonical_alias) = serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get()) {
|
||||
let mut aliases = canonical_alias.alt_aliases.clone();
|
||||
|
||||
if let Some(alias) = canonical_alias.alias {
|
||||
aliases.push(alias);
|
||||
}
|
||||
if let Some(alias) = canonical_alias.alias {
|
||||
aliases.push(alias);
|
||||
}
|
||||
|
||||
for alias in aliases {
|
||||
if alias.server_name() != services().globals.server_name()
|
||||
|| services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&alias)?
|
||||
.filter(|room| room == room_id) // Make sure it's the right room
|
||||
.is_none()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You are only allowed to send canonical_alias \
|
||||
events when it's aliases already exists",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
for alias in aliases {
|
||||
if alias.server_name() != services().globals.server_name()
|
||||
|| services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&alias)?
|
||||
.filter(|room| room == room_id) // Make sure it's the right room
|
||||
.is_none()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You are only allowed to send canonical_alias events when it's aliases already exists",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
.or_default(),
|
||||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: event_type.to_string().into(),
|
||||
content: serde_json::from_str(json.json().get()).expect("content is valid json"),
|
||||
unsigned: None,
|
||||
state_key: Some(state_key),
|
||||
redacts: None,
|
||||
},
|
||||
sender_user,
|
||||
room_id,
|
||||
&state_lock,
|
||||
)
|
||||
.await?;
|
||||
let event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: event_type.to_string().into(),
|
||||
content: serde_json::from_str(json.json().get()).expect("content is valid json"),
|
||||
unsigned: None,
|
||||
state_key: Some(state_key),
|
||||
redacts: None,
|
||||
},
|
||||
sender_user,
|
||||
room_id,
|
||||
&state_lock,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(event_id)
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
+1223
-1555
File diff suppressed because it is too large
Load Diff
@@ -1,55 +1,45 @@
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::tag::{create_tag, delete_tag, get_tags},
|
||||
events::{
|
||||
tag::{TagEvent, TagEventContent},
|
||||
RoomAccountDataEventType,
|
||||
},
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::tag::{create_tag, delete_tag, get_tags},
|
||||
events::{
|
||||
tag::{TagEvent, TagEventContent},
|
||||
RoomAccountDataEventType,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
|
||||
///
|
||||
/// Adds a tag to the room.
|
||||
///
|
||||
/// - Inserts the tag into the tag event of the room account data.
|
||||
pub async fn update_tag_route(
|
||||
body: Ruma<create_tag::v3::Request>,
|
||||
) -> Result<create_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||
|
||||
let mut tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
let mut tags_event = event
|
||||
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
|
||||
tags_event
|
||||
.content
|
||||
.tags
|
||||
.insert(body.tag.clone().into(), body.tag_info.clone());
|
||||
tags_event.content.tags.insert(body.tag.clone().into(), body.tag_info.clone());
|
||||
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&serde_json::to_value(tags_event).expect("to json value always works"),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&serde_json::to_value(tags_event).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(create_tag::v3::Response {})
|
||||
Ok(create_tag::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
|
||||
@@ -57,40 +47,31 @@ pub async fn update_tag_route(
|
||||
/// Deletes a tag from the room.
|
||||
///
|
||||
/// - Removes the tag from the tag event of the room account data.
|
||||
pub async fn delete_tag_route(
|
||||
body: Ruma<delete_tag::v3::Request>,
|
||||
) -> Result<delete_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||
|
||||
let mut tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
let mut tags_event = event
|
||||
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
|
||||
tags_event.content.tags.remove(&body.tag.clone().into());
|
||||
tags_event.content.tags.remove(&body.tag.clone().into());
|
||||
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&serde_json::to_value(tags_event).expect("to json value always works"),
|
||||
)?;
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&serde_json::to_value(tags_event).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(delete_tag::v3::Response {})
|
||||
Ok(delete_tag::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags`
|
||||
@@ -99,28 +80,21 @@ pub async fn delete_tag_route(
|
||||
///
|
||||
/// - Gets the tag event of the room account data.
|
||||
pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||
|
||||
let tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
let tags_event = event
|
||||
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
|
||||
Ok(get_tags::v3::Response {
|
||||
tags: tags_event.content.tags,
|
||||
})
|
||||
Ok(get_tags::v3::Response {
|
||||
tags: tags_event.content.tags,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
use crate::{Result, Ruma};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::api::client::thirdparty::get_protocols;
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use crate::{Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/thirdparty/protocols`
|
||||
///
|
||||
/// TODO: Fetches all metadata about protocols supported by the homeserver.
|
||||
pub async fn get_protocols_route(
|
||||
_body: Ruma<get_protocols::v3::Request>,
|
||||
) -> Result<get_protocols::v3::Response> {
|
||||
// TODO
|
||||
Ok(get_protocols::v3::Response {
|
||||
protocols: BTreeMap::new(),
|
||||
})
|
||||
pub async fn get_protocols_route(_body: Ruma<get_protocols::v3::Request>) -> Result<get_protocols::v3::Response> {
|
||||
// TODO
|
||||
Ok(get_protocols::v3::Response {
|
||||
protocols: BTreeMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,47 +3,37 @@ use ruma::api::client::{error::ErrorKind, threads::get_threads};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/threads`
|
||||
pub async fn get_threads_route(
|
||||
body: Ruma<get_threads::v1::Request>,
|
||||
) -> Result<get_threads::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
pub async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<get_threads::v1::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|l| l.try_into().ok())
|
||||
.unwrap_or(10)
|
||||
.min(100);
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100);
|
||||
|
||||
let from = if let Some(from) = &body.from {
|
||||
from.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
|
||||
} else {
|
||||
u64::MAX
|
||||
};
|
||||
let from = if let Some(from) = &body.from {
|
||||
from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
|
||||
} else {
|
||||
u64::MAX
|
||||
};
|
||||
|
||||
let threads = services()
|
||||
.rooms
|
||||
.threads
|
||||
.threads_until(sender_user, &body.room_id, from, &body.include)?
|
||||
.take(limit)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let threads = services()
|
||||
.rooms
|
||||
.threads
|
||||
.threads_until(sender_user, &body.room_id, from, &body.include)?
|
||||
.take(limit)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.filter(|(_, pdu)| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let next_batch = threads.last().map(|(count, _)| count.to_string());
|
||||
let next_batch = threads.last().map(|(count, _)| count.to_string());
|
||||
|
||||
Ok(get_threads::v1::Response {
|
||||
chunk: threads
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect(),
|
||||
next_batch,
|
||||
})
|
||||
Ok(get_threads::v1::Response {
|
||||
chunk: threads.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(),
|
||||
next_batch,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,92 +1,85 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{error::ErrorKind, to_device::send_event_to_device},
|
||||
federation::{self, transactions::edu::DirectDeviceContent},
|
||||
},
|
||||
to_device::DeviceIdOrAllDevices,
|
||||
api::{
|
||||
client::{error::ErrorKind, to_device::send_event_to_device},
|
||||
federation::{self, transactions::edu::DirectDeviceContent},
|
||||
},
|
||||
to_device::DeviceIdOrAllDevices,
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
|
||||
///
|
||||
/// Send a to-device event to a set of client devices.
|
||||
pub async fn send_event_to_device_route(
|
||||
body: Ruma<send_event_to_device::v3::Request>,
|
||||
body: Ruma<send_event_to_device::v3::Request>,
|
||||
) -> Result<send_event_to_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
|
||||
// Check if this is a new transaction id
|
||||
if services()
|
||||
.transaction_ids
|
||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
||||
.is_some()
|
||||
{
|
||||
return Ok(send_event_to_device::v3::Response {});
|
||||
}
|
||||
// Check if this is a new transaction id
|
||||
if services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)?.is_some() {
|
||||
return Ok(send_event_to_device::v3::Response {});
|
||||
}
|
||||
|
||||
for (target_user_id, map) in &body.messages {
|
||||
for (target_device_id_maybe, event) in map {
|
||||
if target_user_id.server_name() != services().globals.server_name() {
|
||||
let mut map = BTreeMap::new();
|
||||
map.insert(target_device_id_maybe.clone(), event.clone());
|
||||
let mut messages = BTreeMap::new();
|
||||
messages.insert(target_user_id.clone(), map);
|
||||
let count = services().globals.next_count()?;
|
||||
for (target_user_id, map) in &body.messages {
|
||||
for (target_device_id_maybe, event) in map {
|
||||
if target_user_id.server_name() != services().globals.server_name() {
|
||||
let mut map = BTreeMap::new();
|
||||
map.insert(target_device_id_maybe.clone(), event.clone());
|
||||
let mut messages = BTreeMap::new();
|
||||
messages.insert(target_user_id.clone(), map);
|
||||
let count = services().globals.next_count()?;
|
||||
|
||||
services().sending.send_reliable_edu(
|
||||
target_user_id.server_name(),
|
||||
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
|
||||
DirectDeviceContent {
|
||||
sender: sender_user.clone(),
|
||||
ev_type: body.event_type.clone(),
|
||||
message_id: count.to_string().into(),
|
||||
messages,
|
||||
},
|
||||
))
|
||||
.expect("DirectToDevice EDU can be serialized"),
|
||||
count,
|
||||
)?;
|
||||
services().sending.send_reliable_edu(
|
||||
target_user_id.server_name(),
|
||||
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
|
||||
sender: sender_user.clone(),
|
||||
ev_type: body.event_type.clone(),
|
||||
message_id: count.to_string().into(),
|
||||
messages,
|
||||
}))
|
||||
.expect("DirectToDevice EDU can be serialized"),
|
||||
count,
|
||||
)?;
|
||||
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
match target_device_id_maybe {
|
||||
DeviceIdOrAllDevices::DeviceId(target_device_id) => {
|
||||
services().users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
target_device_id,
|
||||
&body.event_type.to_string(),
|
||||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
)?;
|
||||
}
|
||||
match target_device_id_maybe {
|
||||
DeviceIdOrAllDevices::DeviceId(target_device_id) => {
|
||||
services().users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
target_device_id,
|
||||
&body.event_type.to_string(),
|
||||
event
|
||||
.deserialize_as()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
|
||||
)?;
|
||||
},
|
||||
|
||||
DeviceIdOrAllDevices::AllDevices => {
|
||||
for target_device_id in services().users.all_device_ids(target_user_id) {
|
||||
services().users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
&target_device_id?,
|
||||
&body.event_type.to_string(),
|
||||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
DeviceIdOrAllDevices::AllDevices => {
|
||||
for target_device_id in services().users.all_device_ids(target_user_id) {
|
||||
services().users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
&target_device_id?,
|
||||
&body.event_type.to_string(),
|
||||
event
|
||||
.deserialize_as()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
|
||||
)?;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save transaction id with empty data
|
||||
services()
|
||||
.transaction_ids
|
||||
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
||||
// Save transaction id with empty data
|
||||
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
||||
|
||||
Ok(send_event_to_device::v3::Response {})
|
||||
Ok(send_event_to_device::v3::Response {})
|
||||
}
|
||||
|
||||
@@ -1,40 +1,30 @@
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
|
||||
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
|
||||
///
|
||||
/// Sets the typing state of the sender user.
|
||||
pub async fn create_typing_event_route(
|
||||
body: Ruma<create_typing_event::v3::Request>,
|
||||
body: Ruma<create_typing_event::v3::Request>,
|
||||
) -> Result<create_typing_event::v3::Response> {
|
||||
use create_typing_event::v3::Typing;
|
||||
use create_typing_event::v3::Typing;
|
||||
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You are not in this room.",
|
||||
));
|
||||
}
|
||||
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room."));
|
||||
}
|
||||
|
||||
if let Typing::Yes(duration) = body.state {
|
||||
services().rooms.edus.typing.typing_add(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
|
||||
)?;
|
||||
} else {
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.typing
|
||||
.typing_remove(sender_user, &body.room_id)?;
|
||||
}
|
||||
if let Typing::Yes(duration) = body.state {
|
||||
services().rooms.edus.typing.typing_add(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
|
||||
)?;
|
||||
} else {
|
||||
services().rooms.edus.typing.typing_remove(sender_user, &body.room_id)?;
|
||||
}
|
||||
|
||||
Ok(create_typing_event::v3::Response {})
|
||||
Ok(create_typing_event::v3::Response {})
|
||||
}
|
||||
|
||||
@@ -7,72 +7,74 @@ use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/versions`
|
||||
///
|
||||
/// Get the versions of the specification and unstable features supported by this server.
|
||||
/// Get the versions of the specification and unstable features supported by
|
||||
/// this server.
|
||||
///
|
||||
/// - Versions take the form MAJOR.MINOR.PATCH
|
||||
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
|
||||
/// - Unstable features are namespaced and may include version information in their name
|
||||
/// - Unstable features are namespaced and may include version information in
|
||||
/// their name
|
||||
///
|
||||
/// Note: Unstable features are used while developing new features. Clients should avoid using
|
||||
/// unstable features in their stable releases
|
||||
/// Note: Unstable features are used while developing new features. Clients
|
||||
/// should avoid using unstable features in their stable releases
|
||||
pub async fn get_supported_versions_route(
|
||||
_body: Ruma<get_supported_versions::Request>,
|
||||
_body: Ruma<get_supported_versions::Request>,
|
||||
) -> Result<get_supported_versions::Response> {
|
||||
let resp = get_supported_versions::Response {
|
||||
versions: vec![
|
||||
"r0.0.1".to_owned(),
|
||||
"r0.1.0".to_owned(),
|
||||
"r0.2.0".to_owned(),
|
||||
"r0.3.0".to_owned(),
|
||||
"r0.4.0".to_owned(),
|
||||
"r0.5.0".to_owned(),
|
||||
"r0.6.0".to_owned(),
|
||||
"r0.6.1".to_owned(),
|
||||
"v1.1".to_owned(),
|
||||
"v1.2".to_owned(),
|
||||
"v1.3".to_owned(),
|
||||
"v1.4".to_owned(),
|
||||
"v1.5".to_owned(),
|
||||
],
|
||||
unstable_features: BTreeMap::from_iter([
|
||||
("org.matrix.e2e_cross_signing".to_owned(), true),
|
||||
("org.matrix.msc2836".to_owned(), true),
|
||||
("org.matrix.msc3827".to_owned(), true),
|
||||
("org.matrix.msc2946".to_owned(), true),
|
||||
]),
|
||||
};
|
||||
let resp = get_supported_versions::Response {
|
||||
versions: vec![
|
||||
"r0.0.1".to_owned(),
|
||||
"r0.1.0".to_owned(),
|
||||
"r0.2.0".to_owned(),
|
||||
"r0.3.0".to_owned(),
|
||||
"r0.4.0".to_owned(),
|
||||
"r0.5.0".to_owned(),
|
||||
"r0.6.0".to_owned(),
|
||||
"r0.6.1".to_owned(),
|
||||
"v1.1".to_owned(),
|
||||
"v1.2".to_owned(),
|
||||
"v1.3".to_owned(),
|
||||
"v1.4".to_owned(),
|
||||
"v1.5".to_owned(),
|
||||
],
|
||||
unstable_features: BTreeMap::from_iter([
|
||||
("org.matrix.e2e_cross_signing".to_owned(), true),
|
||||
("org.matrix.msc2836".to_owned(), true),
|
||||
("org.matrix.msc3827".to_owned(), true),
|
||||
("org.matrix.msc2946".to_owned(), true),
|
||||
]),
|
||||
};
|
||||
|
||||
Ok(resp)
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
/// # `GET /.well-known/matrix/client`
|
||||
pub async fn well_known_client_route() -> Result<impl IntoResponse> {
|
||||
let client_url = match services().globals.well_known_client() {
|
||||
Some(url) => url.clone(),
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
};
|
||||
let client_url = match services().globals.well_known_client() {
|
||||
Some(url) => url.clone(),
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
};
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"m.homeserver": {"base_url": client_url},
|
||||
"org.matrix.msc3575.proxy": {"url": client_url}
|
||||
})))
|
||||
Ok(Json(serde_json::json!({
|
||||
"m.homeserver": {"base_url": client_url},
|
||||
"org.matrix.msc3575.proxy": {"url": client_url}
|
||||
})))
|
||||
}
|
||||
|
||||
/// # `GET /client/server.json`
|
||||
///
|
||||
/// Endpoint provided by sliding sync proxy used by some clients such as Element Web
|
||||
/// as a non-standard health check.
|
||||
/// Endpoint provided by sliding sync proxy used by some clients such as Element
|
||||
/// Web as a non-standard health check.
|
||||
pub async fn syncv3_client_server_json() -> Result<impl IntoResponse> {
|
||||
let server_url = match services().globals.well_known_client() {
|
||||
Some(url) => url.clone(),
|
||||
None => match services().globals.well_known_server() {
|
||||
Some(url) => url.clone(),
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
},
|
||||
};
|
||||
let server_url = match services().globals.well_known_client() {
|
||||
Some(url) => url.clone(),
|
||||
None => match services().globals.well_known_server() {
|
||||
Some(url) => url.clone(),
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"server": server_url,
|
||||
"version": format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))
|
||||
})))
|
||||
Ok(Json(serde_json::json!({
|
||||
"server": server_url,
|
||||
"version": format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))
|
||||
})))
|
||||
}
|
||||
|
||||
@@ -1,94 +1,78 @@
|
||||
use crate::{services, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::user_directory::search_users,
|
||||
events::{
|
||||
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
StateEventType,
|
||||
},
|
||||
api::client::user_directory::search_users,
|
||||
events::{
|
||||
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
StateEventType,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{services, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/user_directory/search`
|
||||
///
|
||||
/// Searches all known users for a match.
|
||||
///
|
||||
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public)
|
||||
/// - Hides any local users that aren't in any public rooms (i.e. those that
|
||||
/// have the join rule set to public)
|
||||
/// and don't share a room with the sender
|
||||
pub async fn search_users_route(
|
||||
body: Ruma<search_users::v3::Request>,
|
||||
) -> Result<search_users::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let limit = u64::from(body.limit) as usize;
|
||||
pub async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result<search_users::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let limit = u64::from(body.limit) as usize;
|
||||
|
||||
let mut users = services().users.iter().filter_map(|user_id| {
|
||||
// Filter out buggy users (they should not exist, but you never know...)
|
||||
let user_id = user_id.ok()?;
|
||||
let mut users = services().users.iter().filter_map(|user_id| {
|
||||
// Filter out buggy users (they should not exist, but you never know...)
|
||||
let user_id = user_id.ok()?;
|
||||
|
||||
let user = search_users::v3::User {
|
||||
user_id: user_id.clone(),
|
||||
display_name: services().users.displayname(&user_id).ok()?,
|
||||
avatar_url: services().users.avatar_url(&user_id).ok()?,
|
||||
};
|
||||
let user = search_users::v3::User {
|
||||
user_id: user_id.clone(),
|
||||
display_name: services().users.displayname(&user_id).ok()?,
|
||||
avatar_url: services().users.avatar_url(&user_id).ok()?,
|
||||
};
|
||||
|
||||
let user_id_matches = user
|
||||
.user_id
|
||||
.to_string()
|
||||
.to_lowercase()
|
||||
.contains(&body.search_term.to_lowercase());
|
||||
let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase());
|
||||
|
||||
let user_displayname_matches = user
|
||||
.display_name
|
||||
.as_ref()
|
||||
.filter(|name| {
|
||||
name.to_lowercase()
|
||||
.contains(&body.search_term.to_lowercase())
|
||||
})
|
||||
.is_some();
|
||||
let user_displayname_matches = user
|
||||
.display_name
|
||||
.as_ref()
|
||||
.filter(|name| name.to_lowercase().contains(&body.search_term.to_lowercase()))
|
||||
.is_some();
|
||||
|
||||
if !user_id_matches && !user_displayname_matches {
|
||||
return None;
|
||||
}
|
||||
if !user_id_matches && !user_displayname_matches {
|
||||
return None;
|
||||
}
|
||||
|
||||
let user_is_in_public_rooms = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(&user_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.any(|room| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
|
||||
.map_or(false, |event| {
|
||||
event.map_or(false, |event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
.map_or(false, |r: RoomJoinRulesEventContent| {
|
||||
r.join_rule == JoinRule::Public
|
||||
})
|
||||
})
|
||||
})
|
||||
});
|
||||
let user_is_in_public_rooms =
|
||||
services().rooms.state_cache.rooms_joined(&user_id).filter_map(std::result::Result::ok).any(|room| {
|
||||
services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or(
|
||||
false,
|
||||
|event| {
|
||||
event.map_or(false, |event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
|
||||
})
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
if user_is_in_public_rooms {
|
||||
return Some(user);
|
||||
}
|
||||
if user_is_in_public_rooms {
|
||||
return Some(user);
|
||||
}
|
||||
|
||||
let user_is_in_shared_rooms = services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.clone(), user_id])
|
||||
.ok()?
|
||||
.next()
|
||||
.is_some();
|
||||
let user_is_in_shared_rooms =
|
||||
services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some();
|
||||
|
||||
if user_is_in_shared_rooms {
|
||||
return Some(user);
|
||||
}
|
||||
if user_is_in_shared_rooms {
|
||||
return Some(user);
|
||||
}
|
||||
|
||||
None
|
||||
});
|
||||
None
|
||||
});
|
||||
|
||||
let results = users.by_ref().take(limit).collect();
|
||||
let limited = users.next().is_some();
|
||||
let results = users.by_ref().take(limit).collect();
|
||||
let limited = users.next().is_some();
|
||||
|
||||
Ok(search_users::v3::Response { results, limited })
|
||||
Ok(search_users::v3::Response {
|
||||
results,
|
||||
limited,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::{services, Result, Ruma};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use hmac::{Hmac, Mac};
|
||||
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
|
||||
use sha1::Sha1;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use crate::{services, Result, Ruma};
|
||||
|
||||
type HmacSha1 = Hmac<Sha1>;
|
||||
|
||||
@@ -11,38 +13,37 @@ type HmacSha1 = Hmac<Sha1>;
|
||||
///
|
||||
/// TODO: Returns information about the recommended turn server.
|
||||
pub async fn turn_server_route(
|
||||
body: Ruma<get_turn_server_info::v3::Request>,
|
||||
body: Ruma<get_turn_server_info::v3::Request>,
|
||||
) -> Result<get_turn_server_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let turn_secret = services().globals.turn_secret().clone();
|
||||
let turn_secret = services().globals.turn_secret().clone();
|
||||
|
||||
let (username, password) = if !turn_secret.is_empty() {
|
||||
let expiry = SecondsSinceUnixEpoch::from_system_time(
|
||||
SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()),
|
||||
)
|
||||
.expect("time is valid");
|
||||
let (username, password) = if !turn_secret.is_empty() {
|
||||
let expiry = SecondsSinceUnixEpoch::from_system_time(
|
||||
SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()),
|
||||
)
|
||||
.expect("time is valid");
|
||||
|
||||
let username: String = format!("{}:{}", expiry.get(), sender_user);
|
||||
let username: String = format!("{}:{}", expiry.get(), sender_user);
|
||||
|
||||
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes())
|
||||
.expect("HMAC can take key of any size");
|
||||
mac.update(username.as_bytes());
|
||||
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size");
|
||||
mac.update(username.as_bytes());
|
||||
|
||||
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
|
||||
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
|
||||
|
||||
(username, password)
|
||||
} else {
|
||||
(
|
||||
services().globals.turn_username().clone(),
|
||||
services().globals.turn_password().clone(),
|
||||
)
|
||||
};
|
||||
(username, password)
|
||||
} else {
|
||||
(
|
||||
services().globals.turn_username().clone(),
|
||||
services().globals.turn_password().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
Ok(get_turn_server_info::v3::Response {
|
||||
username,
|
||||
password,
|
||||
uris: services().globals.turn_uris().to_vec(),
|
||||
ttl: Duration::from_secs(services().globals.turn_ttl()),
|
||||
})
|
||||
Ok(get_turn_server_info::v3::Response {
|
||||
username,
|
||||
password,
|
||||
uris: services().globals.turn_uris().to_vec(),
|
||||
ttl: Duration::from_secs(services().globals.turn_ttl()),
|
||||
})
|
||||
}
|
||||
|
||||
+303
-370
@@ -1,21 +1,21 @@
|
||||
use std::{collections::BTreeMap, str};
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
body::{Full, HttpBody},
|
||||
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
|
||||
headers::{
|
||||
authorization::{Bearer, Credentials},
|
||||
Authorization,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
BoxError, RequestExt, RequestPartsExt,
|
||||
async_trait,
|
||||
body::{Full, HttpBody},
|
||||
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
|
||||
headers::{
|
||||
authorization::{Bearer, Credentials},
|
||||
Authorization,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
BoxError, RequestExt, RequestPartsExt,
|
||||
};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use http::{Request, StatusCode};
|
||||
use ruma::{
|
||||
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
|
||||
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
|
||||
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
|
||||
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tracing::{debug, error, warn};
|
||||
@@ -25,400 +25,333 @@ use crate::{services, Error, Result};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QueryParams {
|
||||
access_token: Option<String>,
|
||||
user_id: Option<String>,
|
||||
access_token: Option<String>,
|
||||
user_id: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, S, B> FromRequest<S, B> for Ruma<T>
|
||||
where
|
||||
T: IncomingRequest,
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
T: IncomingRequest,
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
{
|
||||
type Rejection = Error;
|
||||
type Rejection = Error;
|
||||
|
||||
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let (mut parts, mut body) = match req.with_limited_body() {
|
||||
Ok(limited_req) => {
|
||||
let (parts, body) = limited_req.into_parts();
|
||||
let body = to_bytes(body)
|
||||
.await
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
(parts, body)
|
||||
}
|
||||
Err(original_req) => {
|
||||
let (parts, body) = original_req.into_parts();
|
||||
let body = to_bytes(body)
|
||||
.await
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
(parts, body)
|
||||
}
|
||||
};
|
||||
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let (mut parts, mut body) = match req.with_limited_body() {
|
||||
Ok(limited_req) => {
|
||||
let (parts, body) = limited_req.into_parts();
|
||||
let body =
|
||||
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
(parts, body)
|
||||
},
|
||||
Err(original_req) => {
|
||||
let (parts, body) = original_req.into_parts();
|
||||
let body =
|
||||
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
(parts, body)
|
||||
},
|
||||
};
|
||||
|
||||
let metadata = T::METADATA;
|
||||
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
|
||||
let path_params: Path<Vec<String>> = parts.extract().await?;
|
||||
let metadata = T::METADATA;
|
||||
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
|
||||
let path_params: Path<Vec<String>> = parts.extract().await?;
|
||||
|
||||
let query = parts.uri.query().unwrap_or_default();
|
||||
let query_params: QueryParams = match serde_html_form::from_str(query) {
|
||||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
error!(%query, "Failed to deserialize query parameters: {}", e);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to read query parameters",
|
||||
));
|
||||
}
|
||||
};
|
||||
let query = parts.uri.query().unwrap_or_default();
|
||||
let query_params: QueryParams = match serde_html_form::from_str(query) {
|
||||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
error!(%query, "Failed to deserialize query parameters: {}", e);
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"));
|
||||
},
|
||||
};
|
||||
|
||||
let token = match &auth_header {
|
||||
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
|
||||
None => query_params.access_token.as_deref(),
|
||||
};
|
||||
let token = match &auth_header {
|
||||
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
|
||||
None => query_params.access_token.as_deref(),
|
||||
};
|
||||
|
||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let appservices = services().appservice.all().unwrap();
|
||||
let appservice_registration = appservices
|
||||
.iter()
|
||||
.find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||
let appservices = services().appservice.all().unwrap();
|
||||
let appservice_registration =
|
||||
appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) =
|
||||
if let Some((_id, registration)) = appservice_registration {
|
||||
match metadata.authentication {
|
||||
AuthScheme::AccessToken => {
|
||||
let user_id = query_params.user_id.map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
registration.sender_localpart.as_str(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
},
|
||||
|s| UserId::parse(s).unwrap(),
|
||||
);
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
|
||||
appservice_registration
|
||||
{
|
||||
match metadata.authentication {
|
||||
AuthScheme::AccessToken => {
|
||||
let user_id = query_params.user_id.map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
registration.sender_localpart.as_str(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
},
|
||||
|s| UserId::parse(s).unwrap(),
|
||||
);
|
||||
|
||||
if !services().users.exists(&user_id).unwrap() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"User does not exist.",
|
||||
));
|
||||
}
|
||||
if !services().users.exists(&user_id).unwrap() {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist."));
|
||||
}
|
||||
|
||||
// TODO: Check if appservice is allowed to be that user
|
||||
(Some(user_id), None, None, true)
|
||||
}
|
||||
AuthScheme::ServerSignatures => (None, None, None, true),
|
||||
AuthScheme::None => (None, None, None, true),
|
||||
}
|
||||
} else {
|
||||
match metadata.authentication {
|
||||
AuthScheme::AccessToken => {
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Missing access token.",
|
||||
))
|
||||
}
|
||||
};
|
||||
// TODO: Check if appservice is allowed to be that user
|
||||
(Some(user_id), None, None, true)
|
||||
},
|
||||
AuthScheme::ServerSignatures => (None, None, None, true),
|
||||
AuthScheme::None => (None, None, None, true),
|
||||
}
|
||||
} else {
|
||||
match metadata.authentication {
|
||||
AuthScheme::AccessToken => {
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||
};
|
||||
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
"Unknown access token.",
|
||||
))
|
||||
}
|
||||
Some((user_id, device_id)) => (
|
||||
Some(user_id),
|
||||
Some(OwnedDeviceId::from(device_id)),
|
||||
None,
|
||||
false,
|
||||
),
|
||||
}
|
||||
}
|
||||
AuthScheme::ServerSignatures => {
|
||||
let TypedHeader(Authorization(x_matrix)) = parts
|
||||
.extract::<TypedHeader<Authorization<XMatrix>>>()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("Missing or invalid Authorization header: {}", e);
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken {
|
||||
soft_logout: false,
|
||||
},
|
||||
"Unknown access token.",
|
||||
))
|
||||
},
|
||||
Some((user_id, device_id)) => {
|
||||
(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
|
||||
},
|
||||
}
|
||||
},
|
||||
AuthScheme::ServerSignatures => {
|
||||
let TypedHeader(Authorization(x_matrix)) =
|
||||
parts.extract::<TypedHeader<Authorization<XMatrix>>>().await.map_err(|e| {
|
||||
warn!("Missing or invalid Authorization header: {}", e);
|
||||
|
||||
let msg = match e.reason() {
|
||||
TypedHeaderRejectionReason::Missing => {
|
||||
"Missing Authorization header."
|
||||
}
|
||||
TypedHeaderRejectionReason::Error(_) => {
|
||||
"Invalid X-Matrix signatures."
|
||||
}
|
||||
_ => "Unknown header-related error",
|
||||
};
|
||||
let msg = match e.reason() {
|
||||
TypedHeaderRejectionReason::Missing => "Missing Authorization header.",
|
||||
TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.",
|
||||
_ => "Unknown header-related error",
|
||||
};
|
||||
|
||||
Error::BadRequest(ErrorKind::Forbidden, msg)
|
||||
})?;
|
||||
Error::BadRequest(ErrorKind::Forbidden, msg)
|
||||
})?;
|
||||
|
||||
let origin_signatures = BTreeMap::from_iter([(
|
||||
x_matrix.key.clone(),
|
||||
CanonicalJsonValue::String(x_matrix.sig),
|
||||
)]);
|
||||
let origin_signatures =
|
||||
BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]);
|
||||
|
||||
let signatures = BTreeMap::from_iter([(
|
||||
x_matrix.origin.as_str().to_owned(),
|
||||
CanonicalJsonValue::Object(origin_signatures),
|
||||
)]);
|
||||
let signatures = BTreeMap::from_iter([(
|
||||
x_matrix.origin.as_str().to_owned(),
|
||||
CanonicalJsonValue::Object(origin_signatures),
|
||||
)]);
|
||||
|
||||
let server_destination =
|
||||
services().globals.server_name().as_str().to_owned();
|
||||
let server_destination = services().globals.server_name().as_str().to_owned();
|
||||
|
||||
if let Some(destination) = x_matrix.destination.as_ref() {
|
||||
if destination != &server_destination {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Invalid authorization.",
|
||||
));
|
||||
}
|
||||
}
|
||||
if let Some(destination) = x_matrix.destination.as_ref() {
|
||||
if destination != &server_destination {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Invalid authorization."));
|
||||
}
|
||||
}
|
||||
|
||||
let mut request_map = BTreeMap::from_iter([
|
||||
(
|
||||
"method".to_owned(),
|
||||
CanonicalJsonValue::String(parts.method.to_string()),
|
||||
),
|
||||
(
|
||||
"uri".to_owned(),
|
||||
CanonicalJsonValue::String(parts.uri.to_string()),
|
||||
),
|
||||
(
|
||||
"origin".to_owned(),
|
||||
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
|
||||
),
|
||||
(
|
||||
"destination".to_owned(),
|
||||
CanonicalJsonValue::String(server_destination),
|
||||
),
|
||||
(
|
||||
"signatures".to_owned(),
|
||||
CanonicalJsonValue::Object(signatures),
|
||||
),
|
||||
]);
|
||||
let mut request_map = BTreeMap::from_iter([
|
||||
("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())),
|
||||
("uri".to_owned(), CanonicalJsonValue::String(parts.uri.to_string())),
|
||||
(
|
||||
"origin".to_owned(),
|
||||
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
|
||||
),
|
||||
("destination".to_owned(), CanonicalJsonValue::String(server_destination)),
|
||||
("signatures".to_owned(), CanonicalJsonValue::Object(signatures)),
|
||||
]);
|
||||
|
||||
if let Some(json_body) = &json_body {
|
||||
request_map.insert("content".to_owned(), json_body.clone());
|
||||
};
|
||||
if let Some(json_body) = &json_body {
|
||||
request_map.insert("content".to_owned(), json_body.clone());
|
||||
};
|
||||
|
||||
let keys_result = services()
|
||||
.rooms
|
||||
.event_handler
|
||||
.fetch_signing_keys_for_server(
|
||||
&x_matrix.origin,
|
||||
vec![x_matrix.key.clone()],
|
||||
)
|
||||
.await;
|
||||
let keys_result = services()
|
||||
.rooms
|
||||
.event_handler
|
||||
.fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()])
|
||||
.await;
|
||||
|
||||
let keys = match keys_result {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch signing keys: {}", e);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Failed to fetch signing keys.",
|
||||
));
|
||||
}
|
||||
};
|
||||
let keys = match keys_result {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch signing keys: {}", e);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Failed to fetch signing keys."));
|
||||
},
|
||||
};
|
||||
|
||||
let pub_key_map =
|
||||
BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
|
||||
let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
|
||||
|
||||
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
|
||||
Ok(()) => (None, None, Some(x_matrix.origin), false),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to verify json request from {}: {}\n{:?}",
|
||||
x_matrix.origin, e, request_map
|
||||
);
|
||||
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
|
||||
Ok(()) => (None, None, Some(x_matrix.origin), false),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to verify json request from {}: {}\n{:?}",
|
||||
x_matrix.origin, e, request_map
|
||||
);
|
||||
|
||||
if parts.uri.to_string().contains('@') {
|
||||
warn!(
|
||||
"Request uri contained '@' character. Make sure your \
|
||||
reverse proxy gives Conduit the raw uri (apache: use \
|
||||
nocanon)"
|
||||
);
|
||||
}
|
||||
if parts.uri.to_string().contains('@') {
|
||||
warn!(
|
||||
"Request uri contained '@' character. Make sure your reverse proxy gives Conduit \
|
||||
the raw uri (apache: use nocanon)"
|
||||
);
|
||||
}
|
||||
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Failed to verify X-Matrix signatures.",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
AuthScheme::None => match parts.uri.path() {
|
||||
// allow_public_room_directory_without_auth
|
||||
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
|
||||
if !services()
|
||||
.globals
|
||||
.config
|
||||
.allow_public_room_directory_without_auth
|
||||
{
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Missing access token.",
|
||||
))
|
||||
}
|
||||
};
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Failed to verify X-Matrix signatures.",
|
||||
));
|
||||
},
|
||||
}
|
||||
},
|
||||
AuthScheme::None => match parts.uri.path() {
|
||||
// allow_public_room_directory_without_auth
|
||||
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
|
||||
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||
};
|
||||
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
"Unknown access token.",
|
||||
))
|
||||
}
|
||||
Some((user_id, device_id)) => (
|
||||
Some(user_id),
|
||||
Some(OwnedDeviceId::from(device_id)),
|
||||
None,
|
||||
false,
|
||||
),
|
||||
}
|
||||
} else {
|
||||
(None, None, None, false)
|
||||
}
|
||||
}
|
||||
_ => (None, None, None, false),
|
||||
},
|
||||
}
|
||||
};
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken {
|
||||
soft_logout: false,
|
||||
},
|
||||
"Unknown access token.",
|
||||
))
|
||||
},
|
||||
Some((user_id, device_id)) => {
|
||||
(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
|
||||
},
|
||||
}
|
||||
} else {
|
||||
(None, None, None, false)
|
||||
}
|
||||
},
|
||||
_ => (None, None, None, false),
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
|
||||
*http_request.headers_mut().unwrap() = parts.headers;
|
||||
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
|
||||
*http_request.headers_mut().unwrap() = parts.headers;
|
||||
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||
UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid")
|
||||
});
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||
UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid")
|
||||
});
|
||||
|
||||
let uiaa_request = json_body
|
||||
.get("auth")
|
||||
.and_then(|auth| auth.as_object())
|
||||
.and_then(|auth| auth.get("session"))
|
||||
.and_then(|session| session.as_str())
|
||||
.and_then(|session| {
|
||||
services().uiaa.get_uiaa_request(
|
||||
&user_id,
|
||||
&sender_device.clone().unwrap_or_else(|| "".into()),
|
||||
session,
|
||||
)
|
||||
});
|
||||
let uiaa_request = json_body
|
||||
.get("auth")
|
||||
.and_then(|auth| auth.as_object())
|
||||
.and_then(|auth| auth.get("session"))
|
||||
.and_then(|session| session.as_str())
|
||||
.and_then(|session| {
|
||||
services().uiaa.get_uiaa_request(
|
||||
&user_id,
|
||||
&sender_device.clone().unwrap_or_else(|| "".into()),
|
||||
session,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
|
||||
for (key, value) in initial_request {
|
||||
json_body.entry(key).or_insert(value);
|
||||
}
|
||||
}
|
||||
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
|
||||
for (key, value) in initial_request {
|
||||
json_body.entry(key).or_insert(value);
|
||||
}
|
||||
}
|
||||
|
||||
let mut buf = BytesMut::new().writer();
|
||||
serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail");
|
||||
body = buf.into_inner().freeze();
|
||||
}
|
||||
let mut buf = BytesMut::new().writer();
|
||||
serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail");
|
||||
body = buf.into_inner().freeze();
|
||||
}
|
||||
|
||||
let http_request = http_request.body(&*body).unwrap();
|
||||
let http_request = http_request.body(&*body).unwrap();
|
||||
|
||||
debug!("{:?}", http_request);
|
||||
debug!("{:?}", http_request);
|
||||
|
||||
let body = T::try_from_http_request(http_request, &path_params).map_err(|e| {
|
||||
warn!("try_from_http_request failed: {:?}", e);
|
||||
debug!("JSON body: {:?}", json_body);
|
||||
Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.")
|
||||
})?;
|
||||
let body = T::try_from_http_request(http_request, &path_params).map_err(|e| {
|
||||
warn!("try_from_http_request failed: {:?}", e);
|
||||
debug!("JSON body: {:?}", json_body);
|
||||
Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.")
|
||||
})?;
|
||||
|
||||
Ok(Ruma {
|
||||
body,
|
||||
sender_user,
|
||||
sender_device,
|
||||
sender_servername,
|
||||
from_appservice,
|
||||
json_body,
|
||||
})
|
||||
}
|
||||
Ok(Ruma {
|
||||
body,
|
||||
sender_user,
|
||||
sender_device,
|
||||
sender_servername,
|
||||
from_appservice,
|
||||
json_body,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct XMatrix {
|
||||
origin: OwnedServerName,
|
||||
destination: Option<String>,
|
||||
key: String, // KeyName?
|
||||
sig: String,
|
||||
origin: OwnedServerName,
|
||||
destination: Option<String>,
|
||||
key: String, // KeyName?
|
||||
sig: String,
|
||||
}
|
||||
|
||||
impl Credentials for XMatrix {
|
||||
const SCHEME: &'static str = "X-Matrix";
|
||||
const SCHEME: &'static str = "X-Matrix";
|
||||
|
||||
fn decode(value: &http::HeaderValue) -> Option<Self> {
|
||||
debug_assert!(
|
||||
value.as_bytes().starts_with(b"X-Matrix "),
|
||||
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
|
||||
);
|
||||
fn decode(value: &http::HeaderValue) -> Option<Self> {
|
||||
debug_assert!(
|
||||
value.as_bytes().starts_with(b"X-Matrix "),
|
||||
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
|
||||
);
|
||||
|
||||
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
|
||||
.ok()?
|
||||
.trim_start();
|
||||
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start();
|
||||
|
||||
let mut origin = None;
|
||||
let mut destination = None;
|
||||
let mut key = None;
|
||||
let mut sig = None;
|
||||
let mut origin = None;
|
||||
let mut destination = None;
|
||||
let mut key = None;
|
||||
let mut sig = None;
|
||||
|
||||
for entry in parameters.split_terminator(',') {
|
||||
let (name, value) = entry.split_once('=')?;
|
||||
for entry in parameters.split_terminator(',') {
|
||||
let (name, value) = entry.split_once('=')?;
|
||||
|
||||
// It's not at all clear why some fields are quoted and others not in the spec,
|
||||
// let's simply accept either form for every field.
|
||||
let value = value
|
||||
.strip_prefix('"')
|
||||
.and_then(|rest| rest.strip_suffix('"'))
|
||||
.unwrap_or(value);
|
||||
// It's not at all clear why some fields are quoted and others not in the spec,
|
||||
// let's simply accept either form for every field.
|
||||
let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value);
|
||||
|
||||
// FIXME: Catch multiple fields of the same name
|
||||
match name {
|
||||
"origin" => origin = Some(value.try_into().ok()?),
|
||||
"key" => key = Some(value.to_owned()),
|
||||
"sig" => sig = Some(value.to_owned()),
|
||||
"destination" => destination = Some(value.to_owned()),
|
||||
_ => debug!(
|
||||
"Unexpected field `{}` in X-Matrix Authorization header",
|
||||
name
|
||||
),
|
||||
}
|
||||
}
|
||||
// FIXME: Catch multiple fields of the same name
|
||||
match name {
|
||||
"origin" => origin = Some(value.try_into().ok()?),
|
||||
"key" => key = Some(value.to_owned()),
|
||||
"sig" => sig = Some(value.to_owned()),
|
||||
"destination" => destination = Some(value.to_owned()),
|
||||
_ => debug!("Unexpected field `{}` in X-Matrix Authorization header", name),
|
||||
}
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
origin: origin?,
|
||||
key: key?,
|
||||
sig: sig?,
|
||||
destination,
|
||||
})
|
||||
}
|
||||
Some(Self {
|
||||
origin: origin?,
|
||||
key: key?,
|
||||
sig: sig?,
|
||||
destination,
|
||||
})
|
||||
}
|
||||
|
||||
fn encode(&self) -> http::HeaderValue {
|
||||
todo!()
|
||||
}
|
||||
fn encode(&self) -> http::HeaderValue { todo!() }
|
||||
}
|
||||
|
||||
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
|
||||
fn into_response(self) -> Response {
|
||||
match self.0.try_into_http_response::<BytesMut>() {
|
||||
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
|
||||
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
}
|
||||
}
|
||||
fn into_response(self) -> Response {
|
||||
match self.0.try_into_http_response::<BytesMut>() {
|
||||
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
|
||||
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copied from hyper under the following license:
|
||||
@@ -443,32 +376,32 @@ impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
|
||||
// THE SOFTWARE.
|
||||
pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
|
||||
where
|
||||
T: HttpBody,
|
||||
T: HttpBody,
|
||||
{
|
||||
futures_util::pin_mut!(body);
|
||||
futures_util::pin_mut!(body);
|
||||
|
||||
// If there's only 1 chunk, we can just return Buf::to_bytes()
|
||||
let mut first = if let Some(buf) = body.data().await {
|
||||
buf?
|
||||
} else {
|
||||
return Ok(Bytes::new());
|
||||
};
|
||||
// If there's only 1 chunk, we can just return Buf::to_bytes()
|
||||
let mut first = if let Some(buf) = body.data().await {
|
||||
buf?
|
||||
} else {
|
||||
return Ok(Bytes::new());
|
||||
};
|
||||
|
||||
let second = if let Some(buf) = body.data().await {
|
||||
buf?
|
||||
} else {
|
||||
return Ok(first.copy_to_bytes(first.remaining()));
|
||||
};
|
||||
let second = if let Some(buf) = body.data().await {
|
||||
buf?
|
||||
} else {
|
||||
return Ok(first.copy_to_bytes(first.remaining()));
|
||||
};
|
||||
|
||||
// With more than 1 buf, we gotta flatten into a Vec first.
|
||||
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
|
||||
let mut vec = Vec::with_capacity(cap);
|
||||
vec.put(first);
|
||||
vec.put(second);
|
||||
// With more than 1 buf, we gotta flatten into a Vec first.
|
||||
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
|
||||
let mut vec = Vec::with_capacity(cap);
|
||||
vec.put(first);
|
||||
vec.put(second);
|
||||
|
||||
while let Some(buf) = body.data().await {
|
||||
vec.put(buf?);
|
||||
}
|
||||
while let Some(buf) = body.data().await {
|
||||
vec.put(buf?);
|
||||
}
|
||||
|
||||
Ok(vec.into())
|
||||
Ok(vec.into())
|
||||
}
|
||||
|
||||
+15
-22
@@ -1,43 +1,36 @@
|
||||
use crate::Error;
|
||||
use ruma::{
|
||||
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
|
||||
OwnedUserId,
|
||||
};
|
||||
use std::ops::Deref;
|
||||
|
||||
use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId};
|
||||
|
||||
use crate::Error;
|
||||
|
||||
#[cfg(feature = "conduit_bin")]
|
||||
mod axum;
|
||||
|
||||
/// Extractor for Ruma request structs
|
||||
pub struct Ruma<T> {
|
||||
pub body: T,
|
||||
pub sender_user: Option<OwnedUserId>,
|
||||
pub sender_device: Option<OwnedDeviceId>,
|
||||
pub sender_servername: Option<OwnedServerName>,
|
||||
// This is None when body is not a valid string
|
||||
pub json_body: Option<CanonicalJsonValue>,
|
||||
pub from_appservice: bool,
|
||||
pub body: T,
|
||||
pub sender_user: Option<OwnedUserId>,
|
||||
pub sender_device: Option<OwnedDeviceId>,
|
||||
pub sender_servername: Option<OwnedServerName>,
|
||||
// This is None when body is not a valid string
|
||||
pub json_body: Option<CanonicalJsonValue>,
|
||||
pub from_appservice: bool,
|
||||
}
|
||||
|
||||
impl<T> Deref for Ruma<T> {
|
||||
type Target = T;
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.body
|
||||
}
|
||||
fn deref(&self) -> &Self::Target { &self.body }
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RumaResponse<T>(pub T);
|
||||
|
||||
impl<T> From<T> for RumaResponse<T> {
|
||||
fn from(t: T) -> Self {
|
||||
Self(t)
|
||||
}
|
||||
fn from(t: T) -> Self { Self(t) }
|
||||
}
|
||||
|
||||
impl From<Error> for RumaResponse<UiaaResponse> {
|
||||
fn from(t: Error) -> Self {
|
||||
t.to_response()
|
||||
}
|
||||
fn from(t: Error) -> Self { t.to_response() }
|
||||
}
|
||||
|
||||
+1427
-1810
File diff suppressed because it is too large
Load Diff
+388
-463
@@ -1,9 +1,9 @@
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
fmt,
|
||||
fmt::Write as _,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
path::PathBuf,
|
||||
collections::BTreeMap,
|
||||
fmt,
|
||||
fmt::Write as _,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
use either::Either;
|
||||
@@ -21,539 +21,464 @@ mod proxy;
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
#[serde(transparent)]
|
||||
pub struct ListeningPort {
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
pub ports: Either<u16, Vec<u16>>,
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
pub ports: Either<u16, Vec<u16>>,
|
||||
}
|
||||
|
||||
/// all the config options for conduwuit
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Config {
|
||||
/// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6)
|
||||
#[serde(default = "default_address")]
|
||||
pub address: IpAddr,
|
||||
/// default TCP port(s) conduwuit will listen on
|
||||
#[serde(default = "default_port")]
|
||||
pub port: ListeningPort,
|
||||
pub tls: Option<TlsConfig>,
|
||||
pub unix_socket_path: Option<PathBuf>,
|
||||
#[serde(default = "default_unix_socket_perms")]
|
||||
pub unix_socket_perms: u32,
|
||||
pub server_name: OwnedServerName,
|
||||
#[serde(default = "default_database_backend")]
|
||||
pub database_backend: String,
|
||||
pub database_path: String,
|
||||
#[serde(default = "default_db_cache_capacity_mb")]
|
||||
pub db_cache_capacity_mb: f64,
|
||||
#[serde(default = "default_new_user_displayname_suffix")]
|
||||
pub new_user_displayname_suffix: String,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_check_for_updates: bool,
|
||||
#[serde(default = "default_conduit_cache_capacity_modifier")]
|
||||
pub conduit_cache_capacity_modifier: f64,
|
||||
#[serde(default = "default_pdu_cache_capacity")]
|
||||
pub pdu_cache_capacity: u32,
|
||||
#[serde(default = "default_cleanup_second_interval")]
|
||||
pub cleanup_second_interval: u32,
|
||||
#[serde(default = "default_max_request_size")]
|
||||
pub max_request_size: u32,
|
||||
#[serde(default = "default_max_concurrent_requests")]
|
||||
pub max_concurrent_requests: u16,
|
||||
#[serde(default = "default_max_fetch_prev_events")]
|
||||
pub max_fetch_prev_events: u16,
|
||||
#[serde(default)]
|
||||
pub allow_registration: bool,
|
||||
#[serde(default)]
|
||||
pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool,
|
||||
pub registration_token: Option<String>,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_encryption: bool,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_federation: bool,
|
||||
#[serde(default)]
|
||||
pub allow_public_room_directory_over_federation: bool,
|
||||
#[serde(default)]
|
||||
pub allow_public_room_directory_without_auth: bool,
|
||||
#[serde(default)]
|
||||
pub allow_device_name_federation: bool,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_room_creation: bool,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_unstable_room_versions: bool,
|
||||
#[serde(default = "default_default_room_version")]
|
||||
pub default_room_version: RoomVersionId,
|
||||
pub well_known_client: Option<String>,
|
||||
pub well_known_server: Option<String>,
|
||||
#[serde(default)]
|
||||
pub allow_jaeger: bool,
|
||||
#[serde(default)]
|
||||
pub tracing_flame: bool,
|
||||
#[serde(default)]
|
||||
pub proxy: ProxyConfig,
|
||||
pub jwt_secret: Option<String>,
|
||||
#[serde(default = "default_trusted_servers")]
|
||||
pub trusted_servers: Vec<OwnedServerName>,
|
||||
#[serde(default = "true_fn")]
|
||||
pub query_trusted_key_servers_first: bool,
|
||||
#[serde(default = "default_log")]
|
||||
pub log: String,
|
||||
#[serde(default)]
|
||||
pub turn_username: String,
|
||||
#[serde(default)]
|
||||
pub turn_password: String,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub turn_uris: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub turn_secret: String,
|
||||
#[serde(default = "default_turn_ttl")]
|
||||
pub turn_ttl: u64,
|
||||
/// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6)
|
||||
#[serde(default = "default_address")]
|
||||
pub address: IpAddr,
|
||||
/// default TCP port(s) conduwuit will listen on
|
||||
#[serde(default = "default_port")]
|
||||
pub port: ListeningPort,
|
||||
pub tls: Option<TlsConfig>,
|
||||
pub unix_socket_path: Option<PathBuf>,
|
||||
#[serde(default = "default_unix_socket_perms")]
|
||||
pub unix_socket_perms: u32,
|
||||
pub server_name: OwnedServerName,
|
||||
#[serde(default = "default_database_backend")]
|
||||
pub database_backend: String,
|
||||
pub database_path: String,
|
||||
#[serde(default = "default_db_cache_capacity_mb")]
|
||||
pub db_cache_capacity_mb: f64,
|
||||
#[serde(default = "default_new_user_displayname_suffix")]
|
||||
pub new_user_displayname_suffix: String,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_check_for_updates: bool,
|
||||
#[serde(default = "default_conduit_cache_capacity_modifier")]
|
||||
pub conduit_cache_capacity_modifier: f64,
|
||||
#[serde(default = "default_pdu_cache_capacity")]
|
||||
pub pdu_cache_capacity: u32,
|
||||
#[serde(default = "default_cleanup_second_interval")]
|
||||
pub cleanup_second_interval: u32,
|
||||
#[serde(default = "default_max_request_size")]
|
||||
pub max_request_size: u32,
|
||||
#[serde(default = "default_max_concurrent_requests")]
|
||||
pub max_concurrent_requests: u16,
|
||||
#[serde(default = "default_max_fetch_prev_events")]
|
||||
pub max_fetch_prev_events: u16,
|
||||
#[serde(default)]
|
||||
pub allow_registration: bool,
|
||||
#[serde(default)]
|
||||
pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool,
|
||||
pub registration_token: Option<String>,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_encryption: bool,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_federation: bool,
|
||||
#[serde(default)]
|
||||
pub allow_public_room_directory_over_federation: bool,
|
||||
#[serde(default)]
|
||||
pub allow_public_room_directory_without_auth: bool,
|
||||
#[serde(default)]
|
||||
pub allow_device_name_federation: bool,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_room_creation: bool,
|
||||
#[serde(default = "true_fn")]
|
||||
pub allow_unstable_room_versions: bool,
|
||||
#[serde(default = "default_default_room_version")]
|
||||
pub default_room_version: RoomVersionId,
|
||||
pub well_known_client: Option<String>,
|
||||
pub well_known_server: Option<String>,
|
||||
#[serde(default)]
|
||||
pub allow_jaeger: bool,
|
||||
#[serde(default)]
|
||||
pub tracing_flame: bool,
|
||||
#[serde(default)]
|
||||
pub proxy: ProxyConfig,
|
||||
pub jwt_secret: Option<String>,
|
||||
#[serde(default = "default_trusted_servers")]
|
||||
pub trusted_servers: Vec<OwnedServerName>,
|
||||
#[serde(default = "true_fn")]
|
||||
pub query_trusted_key_servers_first: bool,
|
||||
#[serde(default = "default_log")]
|
||||
pub log: String,
|
||||
#[serde(default)]
|
||||
pub turn_username: String,
|
||||
#[serde(default)]
|
||||
pub turn_password: String,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub turn_uris: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub turn_secret: String,
|
||||
#[serde(default = "default_turn_ttl")]
|
||||
pub turn_ttl: u64,
|
||||
|
||||
#[serde(default = "default_rocksdb_log_level")]
|
||||
pub rocksdb_log_level: String,
|
||||
#[serde(default = "default_rocksdb_max_log_file_size")]
|
||||
pub rocksdb_max_log_file_size: usize,
|
||||
#[serde(default = "default_rocksdb_log_time_to_roll")]
|
||||
pub rocksdb_log_time_to_roll: usize,
|
||||
#[serde(default)]
|
||||
pub rocksdb_optimize_for_spinning_disks: bool,
|
||||
#[serde(default = "default_rocksdb_parallelism_threads")]
|
||||
pub rocksdb_parallelism_threads: usize,
|
||||
#[serde(default = "default_rocksdb_log_level")]
|
||||
pub rocksdb_log_level: String,
|
||||
#[serde(default = "default_rocksdb_max_log_file_size")]
|
||||
pub rocksdb_max_log_file_size: usize,
|
||||
#[serde(default = "default_rocksdb_log_time_to_roll")]
|
||||
pub rocksdb_log_time_to_roll: usize,
|
||||
#[serde(default)]
|
||||
pub rocksdb_optimize_for_spinning_disks: bool,
|
||||
#[serde(default = "default_rocksdb_parallelism_threads")]
|
||||
pub rocksdb_parallelism_threads: usize,
|
||||
|
||||
pub emergency_password: Option<String>,
|
||||
pub emergency_password: Option<String>,
|
||||
|
||||
#[serde(default = "default_notification_push_path")]
|
||||
pub notification_push_path: String,
|
||||
#[serde(default = "default_notification_push_path")]
|
||||
pub notification_push_path: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub allow_local_presence: bool,
|
||||
#[serde(default)]
|
||||
pub allow_incoming_presence: bool,
|
||||
#[serde(default)]
|
||||
pub allow_outgoing_presence: bool,
|
||||
#[serde(default = "default_presence_idle_timeout_s")]
|
||||
pub presence_idle_timeout_s: u64,
|
||||
#[serde(default = "default_presence_offline_timeout_s")]
|
||||
pub presence_offline_timeout_s: u64,
|
||||
#[serde(default)]
|
||||
pub allow_local_presence: bool,
|
||||
#[serde(default)]
|
||||
pub allow_incoming_presence: bool,
|
||||
#[serde(default)]
|
||||
pub allow_outgoing_presence: bool,
|
||||
#[serde(default = "default_presence_idle_timeout_s")]
|
||||
pub presence_idle_timeout_s: u64,
|
||||
#[serde(default = "default_presence_offline_timeout_s")]
|
||||
pub presence_offline_timeout_s: u64,
|
||||
|
||||
#[serde(default)]
|
||||
pub zstd_compression: bool,
|
||||
#[serde(default)]
|
||||
pub zstd_compression: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub allow_guest_registration: bool,
|
||||
#[serde(default)]
|
||||
pub allow_guest_registration: bool,
|
||||
|
||||
#[serde(default = "Vec::new")]
|
||||
pub prevent_media_downloads_from: Vec<OwnedServerName>,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub prevent_media_downloads_from: Vec<OwnedServerName>,
|
||||
|
||||
#[serde(default = "default_ip_range_denylist")]
|
||||
pub ip_range_denylist: Vec<String>,
|
||||
#[serde(default = "default_ip_range_denylist")]
|
||||
pub ip_range_denylist: Vec<String>,
|
||||
|
||||
#[serde(default = "Vec::new")]
|
||||
pub url_preview_domain_contains_allowlist: Vec<String>,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub url_preview_domain_explicit_allowlist: Vec<String>,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub url_preview_url_contains_allowlist: Vec<String>,
|
||||
#[serde(default = "default_url_preview_max_spider_size")]
|
||||
pub url_preview_max_spider_size: usize,
|
||||
#[serde(default)]
|
||||
pub url_preview_check_root_domain: bool,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub url_preview_domain_contains_allowlist: Vec<String>,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub url_preview_domain_explicit_allowlist: Vec<String>,
|
||||
#[serde(default = "Vec::new")]
|
||||
pub url_preview_url_contains_allowlist: Vec<String>,
|
||||
#[serde(default = "default_url_preview_max_spider_size")]
|
||||
pub url_preview_max_spider_size: usize,
|
||||
#[serde(default)]
|
||||
pub url_preview_check_root_domain: bool,
|
||||
|
||||
#[serde(default = "RegexSet::empty")]
|
||||
#[serde(with = "serde_regex")]
|
||||
pub forbidden_room_names: RegexSet,
|
||||
#[serde(default = "RegexSet::empty")]
|
||||
#[serde(with = "serde_regex")]
|
||||
pub forbidden_room_names: RegexSet,
|
||||
|
||||
#[serde(default = "RegexSet::empty")]
|
||||
#[serde(with = "serde_regex")]
|
||||
pub forbidden_usernames: RegexSet,
|
||||
#[serde(default = "RegexSet::empty")]
|
||||
#[serde(with = "serde_regex")]
|
||||
pub forbidden_usernames: RegexSet,
|
||||
|
||||
#[serde(default)]
|
||||
pub block_non_admin_invites: bool,
|
||||
#[serde(default)]
|
||||
pub block_non_admin_invites: bool,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub catchall: BTreeMap<String, IgnoredAny>,
|
||||
#[serde(flatten)]
|
||||
pub catchall: BTreeMap<String, IgnoredAny>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct TlsConfig {
|
||||
pub certs: String,
|
||||
pub key: String,
|
||||
#[serde(default)]
|
||||
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
|
||||
/// Only works / does something if the `axum_dual_protocol` feature flag was built
|
||||
pub dual_protocol: bool,
|
||||
pub certs: String,
|
||||
pub key: String,
|
||||
#[serde(default)]
|
||||
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
|
||||
/// Only works / does something if the `axum_dual_protocol` feature flag was
|
||||
/// built
|
||||
pub dual_protocol: bool,
|
||||
}
|
||||
|
||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||
|
||||
impl Config {
|
||||
/// Iterates over all the keys in the config file and warns if there is a deprecated key specified
|
||||
pub fn warn_deprecated(&self) {
|
||||
debug!("Checking for deprecated config keys");
|
||||
let mut was_deprecated = false;
|
||||
for key in self
|
||||
.catchall
|
||||
.keys()
|
||||
.filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key))
|
||||
{
|
||||
warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
|
||||
was_deprecated = true;
|
||||
}
|
||||
/// Iterates over all the keys in the config file and warns if there is a
|
||||
/// deprecated key specified
|
||||
pub fn warn_deprecated(&self) {
|
||||
debug!("Checking for deprecated config keys");
|
||||
let mut was_deprecated = false;
|
||||
for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) {
|
||||
warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
|
||||
was_deprecated = true;
|
||||
}
|
||||
|
||||
if was_deprecated {
|
||||
warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted");
|
||||
}
|
||||
}
|
||||
if was_deprecated {
|
||||
warn!(
|
||||
"Read conduit documentation and check your configuration if any new configuration parameters should \
|
||||
be adjusted"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// iterates over all the catchall keys (unknown config options) and warns if there are any.
|
||||
pub fn warn_unknown_key(&self) {
|
||||
debug!("Checking for unknown config keys");
|
||||
for key in self.catchall.keys().filter(
|
||||
|key| "config".to_owned().ne(key.to_owned()), /* "config" is expected */
|
||||
) {
|
||||
warn!(
|
||||
"Config parameter \"{}\" is unknown to conduwuit, ignoring.",
|
||||
key
|
||||
);
|
||||
}
|
||||
}
|
||||
/// iterates over all the catchall keys (unknown config options) and warns
|
||||
/// if there are any.
|
||||
pub fn warn_unknown_key(&self) {
|
||||
debug!("Checking for unknown config keys");
|
||||
for key in
|
||||
self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */)
|
||||
{
|
||||
warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks the presence of the `address` and `unix_socket_path` keys in the raw_config, exiting the process if both keys were detected.
|
||||
pub fn is_dual_listening(&self, raw_config: Figment) -> bool {
|
||||
let check_address = raw_config.find_value("address");
|
||||
let check_unix_socket = raw_config.find_value("unix_socket_path");
|
||||
/// Checks the presence of the `address` and `unix_socket_path` keys in the
|
||||
/// raw_config, exiting the process if both keys were detected.
|
||||
pub fn is_dual_listening(&self, raw_config: Figment) -> bool {
|
||||
let check_address = raw_config.find_value("address");
|
||||
let check_unix_socket = raw_config.find_value("unix_socket_path");
|
||||
|
||||
// are the check_address and check_unix_socket keys both Ok (specified) at the same time?
|
||||
if check_address.is_ok() && check_unix_socket.is_ok() {
|
||||
error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option.");
|
||||
return true;
|
||||
}
|
||||
// are the check_address and check_unix_socket keys both Ok (specified) at the
|
||||
// same time?
|
||||
if check_address.is_ok() && check_unix_socket.is_ok() {
|
||||
error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option.");
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Config {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// Prepare a list of config values to show
|
||||
let lines = [
|
||||
("Server name", self.server_name.host()),
|
||||
("Database backend", &self.database_backend),
|
||||
("Database path", &self.database_path),
|
||||
(
|
||||
"Database cache capacity (MB)",
|
||||
&self.db_cache_capacity_mb.to_string(),
|
||||
),
|
||||
(
|
||||
"Cache capacity modifier",
|
||||
&self.conduit_cache_capacity_modifier.to_string(),
|
||||
),
|
||||
("PDU cache capacity", &self.pdu_cache_capacity.to_string()),
|
||||
(
|
||||
"Cleanup interval in seconds",
|
||||
&self.cleanup_second_interval.to_string(),
|
||||
),
|
||||
("Maximum request size (bytes)", &self.max_request_size.to_string()),
|
||||
(
|
||||
"Maximum concurrent requests",
|
||||
&self.max_concurrent_requests.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow registration",
|
||||
&self.allow_registration.to_string(),
|
||||
),
|
||||
(
|
||||
"Registration token",
|
||||
match self.registration_token {
|
||||
Some(_) => "set",
|
||||
None => "not set (open registration!)",
|
||||
},
|
||||
),
|
||||
(
|
||||
"Allow guest registration (inherently false if allow registration is false)",
|
||||
&self.allow_guest_registration.to_string(),
|
||||
),
|
||||
(
|
||||
"New user display name suffix",
|
||||
&self.new_user_displayname_suffix,
|
||||
),
|
||||
("Allow encryption", &self.allow_encryption.to_string()),
|
||||
("Allow federation", &self.allow_federation.to_string()),
|
||||
(
|
||||
"Allow incoming federated presence requests (updates)",
|
||||
&self.allow_incoming_presence.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow outgoing federated presence requests (updates)",
|
||||
&self.allow_outgoing_presence.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow local presence requests (updates)",
|
||||
&self.allow_local_presence.to_string(),
|
||||
),
|
||||
(
|
||||
"Block non-admin room invites (local and remote, admins can still send and receive invites)",
|
||||
&self.block_non_admin_invites.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow device name federation",
|
||||
&self.allow_device_name_federation.to_string(),
|
||||
),
|
||||
("Notification push path", &self.notification_push_path),
|
||||
("Allow room creation", &self.allow_room_creation.to_string()),
|
||||
(
|
||||
"Allow public room directory over federation",
|
||||
&self.allow_public_room_directory_over_federation.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow public room directory without authentication",
|
||||
&self.allow_public_room_directory_without_auth.to_string(),
|
||||
),
|
||||
(
|
||||
"JWT secret",
|
||||
match self.jwt_secret {
|
||||
Some(_) => "set",
|
||||
None => "not set",
|
||||
},
|
||||
),
|
||||
("Trusted servers", {
|
||||
let mut lst = vec![];
|
||||
for server in &self.trusted_servers {
|
||||
lst.push(server.host());
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
(
|
||||
"Query Trusted Key Servers First",
|
||||
&self.query_trusted_key_servers_first.to_string(),
|
||||
),
|
||||
(
|
||||
"TURN username",
|
||||
if self.turn_username.is_empty() {
|
||||
"not set"
|
||||
} else {
|
||||
&self.turn_username
|
||||
},
|
||||
),
|
||||
("TURN password", {
|
||||
if self.turn_password.is_empty() {
|
||||
"not set"
|
||||
} else {
|
||||
"set"
|
||||
}
|
||||
}),
|
||||
("TURN secret", {
|
||||
if self.turn_secret.is_empty() {
|
||||
"not set"
|
||||
} else {
|
||||
"set"
|
||||
}
|
||||
}),
|
||||
("Turn TTL", &self.turn_ttl.to_string()),
|
||||
("Turn URIs", {
|
||||
let mut lst = vec![];
|
||||
for item in self.turn_uris.iter().cloned().enumerate() {
|
||||
let (_, uri): (usize, String) = item;
|
||||
lst.push(uri);
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
(
|
||||
"zstd Response Body Compression",
|
||||
&self.zstd_compression.to_string(),
|
||||
),
|
||||
("RocksDB database log level", &self.rocksdb_log_level),
|
||||
(
|
||||
"RocksDB database log time-to-roll",
|
||||
&self.rocksdb_log_time_to_roll.to_string(),
|
||||
),
|
||||
(
|
||||
"RocksDB database max log file size",
|
||||
&self.rocksdb_max_log_file_size.to_string(),
|
||||
),
|
||||
(
|
||||
"RocksDB database optimize for spinning disks",
|
||||
&self.rocksdb_optimize_for_spinning_disks.to_string(),
|
||||
),
|
||||
(
|
||||
"RocksDB Parallelism Threads",
|
||||
&self.rocksdb_parallelism_threads.to_string(),
|
||||
),
|
||||
("Prevent Media Downloads From", {
|
||||
let mut lst = vec![];
|
||||
for domain in &self.prevent_media_downloads_from {
|
||||
lst.push(domain.host());
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
("Outbound Request IP Range Denylist", {
|
||||
let mut lst = vec![];
|
||||
for item in self.ip_range_denylist.iter().cloned().enumerate() {
|
||||
let (_, ip): (usize, String) = item;
|
||||
lst.push(ip);
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
("Forbidden usernames", {
|
||||
&self.forbidden_usernames.patterns().iter().join(", ")
|
||||
}),
|
||||
("Forbidden room names", {
|
||||
&self.forbidden_room_names.patterns().iter().join(", ")
|
||||
}),
|
||||
(
|
||||
"URL preview domain contains allowlist",
|
||||
&self.url_preview_domain_contains_allowlist.join(", "),
|
||||
),
|
||||
(
|
||||
"URL preview domain explicit allowlist",
|
||||
&self.url_preview_domain_explicit_allowlist.join(", "),
|
||||
),
|
||||
(
|
||||
"URL preview URL contains allowlist",
|
||||
&self.url_preview_url_contains_allowlist.join(", "),
|
||||
),
|
||||
(
|
||||
"URL preview maximum spider size",
|
||||
&self.url_preview_max_spider_size.to_string(),
|
||||
),
|
||||
(
|
||||
"URL preview check root domain",
|
||||
&self.url_preview_check_root_domain.to_string(),
|
||||
),
|
||||
];
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// Prepare a list of config values to show
|
||||
let lines = [
|
||||
("Server name", self.server_name.host()),
|
||||
("Database backend", &self.database_backend),
|
||||
("Database path", &self.database_path),
|
||||
("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()),
|
||||
("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()),
|
||||
("PDU cache capacity", &self.pdu_cache_capacity.to_string()),
|
||||
("Cleanup interval in seconds", &self.cleanup_second_interval.to_string()),
|
||||
("Maximum request size (bytes)", &self.max_request_size.to_string()),
|
||||
("Maximum concurrent requests", &self.max_concurrent_requests.to_string()),
|
||||
("Allow registration", &self.allow_registration.to_string()),
|
||||
(
|
||||
"Registration token",
|
||||
match self.registration_token {
|
||||
Some(_) => "set",
|
||||
None => "not set (open registration!)",
|
||||
},
|
||||
),
|
||||
(
|
||||
"Allow guest registration (inherently false if allow registration is false)",
|
||||
&self.allow_guest_registration.to_string(),
|
||||
),
|
||||
("New user display name suffix", &self.new_user_displayname_suffix),
|
||||
("Allow encryption", &self.allow_encryption.to_string()),
|
||||
("Allow federation", &self.allow_federation.to_string()),
|
||||
(
|
||||
"Allow incoming federated presence requests (updates)",
|
||||
&self.allow_incoming_presence.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow outgoing federated presence requests (updates)",
|
||||
&self.allow_outgoing_presence.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow local presence requests (updates)",
|
||||
&self.allow_local_presence.to_string(),
|
||||
),
|
||||
(
|
||||
"Block non-admin room invites (local and remote, admins can still send and receive invites)",
|
||||
&self.block_non_admin_invites.to_string(),
|
||||
),
|
||||
("Allow device name federation", &self.allow_device_name_federation.to_string()),
|
||||
("Notification push path", &self.notification_push_path),
|
||||
("Allow room creation", &self.allow_room_creation.to_string()),
|
||||
(
|
||||
"Allow public room directory over federation",
|
||||
&self.allow_public_room_directory_over_federation.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow public room directory without authentication",
|
||||
&self.allow_public_room_directory_without_auth.to_string(),
|
||||
),
|
||||
(
|
||||
"JWT secret",
|
||||
match self.jwt_secret {
|
||||
Some(_) => "set",
|
||||
None => "not set",
|
||||
},
|
||||
),
|
||||
("Trusted servers", {
|
||||
let mut lst = vec![];
|
||||
for server in &self.trusted_servers {
|
||||
lst.push(server.host());
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
(
|
||||
"Query Trusted Key Servers First",
|
||||
&self.query_trusted_key_servers_first.to_string(),
|
||||
),
|
||||
(
|
||||
"TURN username",
|
||||
if self.turn_username.is_empty() {
|
||||
"not set"
|
||||
} else {
|
||||
&self.turn_username
|
||||
},
|
||||
),
|
||||
("TURN password", {
|
||||
if self.turn_password.is_empty() {
|
||||
"not set"
|
||||
} else {
|
||||
"set"
|
||||
}
|
||||
}),
|
||||
("TURN secret", {
|
||||
if self.turn_secret.is_empty() {
|
||||
"not set"
|
||||
} else {
|
||||
"set"
|
||||
}
|
||||
}),
|
||||
("Turn TTL", &self.turn_ttl.to_string()),
|
||||
("Turn URIs", {
|
||||
let mut lst = vec![];
|
||||
for item in self.turn_uris.iter().cloned().enumerate() {
|
||||
let (_, uri): (usize, String) = item;
|
||||
lst.push(uri);
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
("zstd Response Body Compression", &self.zstd_compression.to_string()),
|
||||
("RocksDB database log level", &self.rocksdb_log_level),
|
||||
("RocksDB database log time-to-roll", &self.rocksdb_log_time_to_roll.to_string()),
|
||||
(
|
||||
"RocksDB database max log file size",
|
||||
&self.rocksdb_max_log_file_size.to_string(),
|
||||
),
|
||||
(
|
||||
"RocksDB database optimize for spinning disks",
|
||||
&self.rocksdb_optimize_for_spinning_disks.to_string(),
|
||||
),
|
||||
("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()),
|
||||
("Prevent Media Downloads From", {
|
||||
let mut lst = vec![];
|
||||
for domain in &self.prevent_media_downloads_from {
|
||||
lst.push(domain.host());
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
("Outbound Request IP Range Denylist", {
|
||||
let mut lst = vec![];
|
||||
for item in self.ip_range_denylist.iter().cloned().enumerate() {
|
||||
let (_, ip): (usize, String) = item;
|
||||
lst.push(ip);
|
||||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
("Forbidden usernames", {
|
||||
&self.forbidden_usernames.patterns().iter().join(", ")
|
||||
}),
|
||||
("Forbidden room names", {
|
||||
&self.forbidden_room_names.patterns().iter().join(", ")
|
||||
}),
|
||||
(
|
||||
"URL preview domain contains allowlist",
|
||||
&self.url_preview_domain_contains_allowlist.join(", "),
|
||||
),
|
||||
(
|
||||
"URL preview domain explicit allowlist",
|
||||
&self.url_preview_domain_explicit_allowlist.join(", "),
|
||||
),
|
||||
(
|
||||
"URL preview URL contains allowlist",
|
||||
&self.url_preview_url_contains_allowlist.join(", "),
|
||||
),
|
||||
("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()),
|
||||
("URL preview check root domain", &self.url_preview_check_root_domain.to_string()),
|
||||
];
|
||||
|
||||
let mut msg: String = "Active config values:\n\n".to_owned();
|
||||
let mut msg: String = "Active config values:\n\n".to_owned();
|
||||
|
||||
for line in lines.into_iter().enumerate() {
|
||||
let _ = writeln!(msg, "{}: {}", line.1 .0, line.1 .1);
|
||||
}
|
||||
for line in lines.into_iter().enumerate() {
|
||||
let _ = writeln!(msg, "{}: {}", line.1 .0, line.1 .1);
|
||||
}
|
||||
|
||||
write!(f, "{msg}")
|
||||
}
|
||||
write!(f, "{msg}")
|
||||
}
|
||||
}
|
||||
|
||||
fn true_fn() -> bool {
|
||||
true
|
||||
}
|
||||
fn true_fn() -> bool { true }
|
||||
|
||||
fn default_address() -> IpAddr {
|
||||
Ipv4Addr::LOCALHOST.into()
|
||||
}
|
||||
fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() }
|
||||
|
||||
fn default_port() -> ListeningPort {
|
||||
ListeningPort {
|
||||
ports: Either::Left(8008),
|
||||
}
|
||||
ListeningPort {
|
||||
ports: Either::Left(8008),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_unix_socket_perms() -> u32 {
|
||||
660
|
||||
}
|
||||
fn default_unix_socket_perms() -> u32 { 660 }
|
||||
|
||||
fn default_database_backend() -> String {
|
||||
"rocksdb".to_owned()
|
||||
}
|
||||
fn default_database_backend() -> String { "rocksdb".to_owned() }
|
||||
|
||||
fn default_db_cache_capacity_mb() -> f64 {
|
||||
300.0
|
||||
}
|
||||
fn default_db_cache_capacity_mb() -> f64 { 300.0 }
|
||||
|
||||
fn default_conduit_cache_capacity_modifier() -> f64 {
|
||||
1.0
|
||||
}
|
||||
fn default_conduit_cache_capacity_modifier() -> f64 { 1.0 }
|
||||
|
||||
fn default_pdu_cache_capacity() -> u32 {
|
||||
150_000
|
||||
}
|
||||
fn default_pdu_cache_capacity() -> u32 { 150_000 }
|
||||
|
||||
fn default_cleanup_second_interval() -> u32 {
|
||||
60 // every minute
|
||||
60 // every minute
|
||||
}
|
||||
|
||||
fn default_max_request_size() -> u32 {
|
||||
20 * 1024 * 1024 // Default to 20 MB
|
||||
20 * 1024 * 1024 // Default to 20 MB
|
||||
}
|
||||
|
||||
fn default_max_concurrent_requests() -> u16 {
|
||||
500
|
||||
}
|
||||
fn default_max_concurrent_requests() -> u16 { 500 }
|
||||
|
||||
fn default_max_fetch_prev_events() -> u16 {
|
||||
100_u16
|
||||
}
|
||||
fn default_max_fetch_prev_events() -> u16 { 100_u16 }
|
||||
|
||||
fn default_trusted_servers() -> Vec<OwnedServerName> {
|
||||
vec![OwnedServerName::try_from("matrix.org").unwrap()]
|
||||
}
|
||||
fn default_trusted_servers() -> Vec<OwnedServerName> { vec![OwnedServerName::try_from("matrix.org").unwrap()] }
|
||||
|
||||
fn default_log() -> String {
|
||||
"warn,state_res=warn".to_owned()
|
||||
}
|
||||
fn default_log() -> String { "warn,state_res=warn".to_owned() }
|
||||
|
||||
fn default_notification_push_path() -> String {
|
||||
"/_matrix/push/v1/notify".to_owned()
|
||||
}
|
||||
fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() }
|
||||
|
||||
fn default_turn_ttl() -> u64 {
|
||||
60 * 60 * 24
|
||||
}
|
||||
fn default_turn_ttl() -> u64 { 60 * 60 * 24 }
|
||||
|
||||
fn default_presence_idle_timeout_s() -> u64 {
|
||||
5 * 60
|
||||
}
|
||||
fn default_presence_idle_timeout_s() -> u64 { 5 * 60 }
|
||||
|
||||
fn default_presence_offline_timeout_s() -> u64 {
|
||||
30 * 60
|
||||
}
|
||||
fn default_presence_offline_timeout_s() -> u64 { 30 * 60 }
|
||||
|
||||
fn default_rocksdb_log_level() -> String {
|
||||
"warn".to_owned()
|
||||
}
|
||||
fn default_rocksdb_log_level() -> String { "warn".to_owned() }
|
||||
|
||||
fn default_rocksdb_log_time_to_roll() -> usize {
|
||||
0
|
||||
}
|
||||
fn default_rocksdb_log_time_to_roll() -> usize { 0 }
|
||||
|
||||
fn default_rocksdb_parallelism_threads() -> usize {
|
||||
num_cpus::get_physical() / 2
|
||||
}
|
||||
fn default_rocksdb_parallelism_threads() -> usize { num_cpus::get_physical() / 2 }
|
||||
|
||||
// I know, it's a great name
|
||||
pub(crate) fn default_default_room_version() -> RoomVersionId {
|
||||
RoomVersionId::V10
|
||||
}
|
||||
pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 }
|
||||
|
||||
fn default_rocksdb_max_log_file_size() -> usize {
|
||||
// 4 megabytes
|
||||
4 * 1024 * 1024
|
||||
// 4 megabytes
|
||||
4 * 1024 * 1024
|
||||
}
|
||||
|
||||
fn default_ip_range_denylist() -> Vec<String> {
|
||||
vec![
|
||||
"127.0.0.0/8".to_owned(),
|
||||
"10.0.0.0/8".to_owned(),
|
||||
"172.16.0.0/12".to_owned(),
|
||||
"192.168.0.0/16".to_owned(),
|
||||
"100.64.0.0/10".to_owned(),
|
||||
"192.0.0.0/24".to_owned(),
|
||||
"169.254.0.0/16".to_owned(),
|
||||
"192.88.99.0/24".to_owned(),
|
||||
"198.18.0.0/15".to_owned(),
|
||||
"192.0.2.0/24".to_owned(),
|
||||
"198.51.100.0/24".to_owned(),
|
||||
"203.0.113.0/24".to_owned(),
|
||||
"224.0.0.0/4".to_owned(),
|
||||
"::1/128".to_owned(),
|
||||
"fe80::/10".to_owned(),
|
||||
"fc00::/7".to_owned(),
|
||||
"2001:db8::/32".to_owned(),
|
||||
"ff00::/8".to_owned(),
|
||||
"fec0::/10".to_owned(),
|
||||
]
|
||||
vec![
|
||||
"127.0.0.0/8".to_owned(),
|
||||
"10.0.0.0/8".to_owned(),
|
||||
"172.16.0.0/12".to_owned(),
|
||||
"192.168.0.0/16".to_owned(),
|
||||
"100.64.0.0/10".to_owned(),
|
||||
"192.0.0.0/24".to_owned(),
|
||||
"169.254.0.0/16".to_owned(),
|
||||
"192.88.99.0/24".to_owned(),
|
||||
"198.18.0.0/15".to_owned(),
|
||||
"192.0.2.0/24".to_owned(),
|
||||
"198.51.100.0/24".to_owned(),
|
||||
"203.0.113.0/24".to_owned(),
|
||||
"224.0.0.0/4".to_owned(),
|
||||
"::1/128".to_owned(),
|
||||
"fe80::/10".to_owned(),
|
||||
"fc00::/7".to_owned(),
|
||||
"2001:db8::/32".to_owned(),
|
||||
"ff00::/8".to_owned(),
|
||||
"fec0::/10".to_owned(),
|
||||
]
|
||||
}
|
||||
|
||||
fn default_url_preview_max_spider_size() -> usize {
|
||||
1_000_000 // 1MB
|
||||
1_000_000 // 1MB
|
||||
}
|
||||
|
||||
fn default_new_user_displayname_suffix() -> String {
|
||||
"🏳️⚧️".to_owned()
|
||||
}
|
||||
fn default_new_user_displayname_suffix() -> String { "🏳️⚧️".to_owned() }
|
||||
|
||||
+98
-93
@@ -24,119 +24,124 @@ use crate::Result;
|
||||
/// ## Include vs. Exclude
|
||||
/// If include is an empty list, it is assumed to be `["*"]`.
|
||||
///
|
||||
/// If a domain matches both the exclude and include list, the proxy will only be used if it was
|
||||
/// included because of a more specific rule than it was excluded. In the above example, the proxy
|
||||
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
|
||||
/// If a domain matches both the exclude and include list, the proxy will only
|
||||
/// be used if it was included because of a more specific rule than it was
|
||||
/// excluded. In the above example, the proxy would be used for
|
||||
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
|
||||
#[derive(Clone, Default, Debug, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ProxyConfig {
|
||||
#[default]
|
||||
None,
|
||||
Global {
|
||||
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||
url: Url,
|
||||
},
|
||||
ByDomain(Vec<PartialProxyConfig>),
|
||||
#[default]
|
||||
None,
|
||||
Global {
|
||||
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||
url: Url,
|
||||
},
|
||||
ByDomain(Vec<PartialProxyConfig>),
|
||||
}
|
||||
impl ProxyConfig {
|
||||
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
|
||||
Ok(match self.clone() {
|
||||
ProxyConfig::None => None,
|
||||
ProxyConfig::Global { url } => Some(Proxy::all(url)?),
|
||||
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
|
||||
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
|
||||
})),
|
||||
})
|
||||
}
|
||||
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
|
||||
Ok(match self.clone() {
|
||||
ProxyConfig::None => None,
|
||||
ProxyConfig::Global {
|
||||
url,
|
||||
} => Some(Proxy::all(url)?),
|
||||
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
|
||||
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching
|
||||
// proxy
|
||||
})),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct PartialProxyConfig {
|
||||
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||
url: Url,
|
||||
#[serde(default)]
|
||||
include: Vec<WildCardedDomain>,
|
||||
#[serde(default)]
|
||||
exclude: Vec<WildCardedDomain>,
|
||||
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||
url: Url,
|
||||
#[serde(default)]
|
||||
include: Vec<WildCardedDomain>,
|
||||
#[serde(default)]
|
||||
exclude: Vec<WildCardedDomain>,
|
||||
}
|
||||
impl PartialProxyConfig {
|
||||
pub fn for_url(&self, url: &Url) -> Option<&Url> {
|
||||
let domain = url.domain()?;
|
||||
let mut included_because = None; // most specific reason it was included
|
||||
let mut excluded_because = None; // most specific reason it was excluded
|
||||
if self.include.is_empty() {
|
||||
// treat empty include list as `*`
|
||||
included_because = Some(&WildCardedDomain::WildCard);
|
||||
}
|
||||
for wc_domain in &self.include {
|
||||
if wc_domain.matches(domain) {
|
||||
match included_because {
|
||||
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||
_ => included_because = Some(wc_domain),
|
||||
}
|
||||
}
|
||||
}
|
||||
for wc_domain in &self.exclude {
|
||||
if wc_domain.matches(domain) {
|
||||
match excluded_because {
|
||||
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||
_ => excluded_because = Some(wc_domain),
|
||||
}
|
||||
}
|
||||
}
|
||||
match (included_because, excluded_because) {
|
||||
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
|
||||
(Some(_), None) => Some(&self.url),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub fn for_url(&self, url: &Url) -> Option<&Url> {
|
||||
let domain = url.domain()?;
|
||||
let mut included_because = None; // most specific reason it was included
|
||||
let mut excluded_because = None; // most specific reason it was excluded
|
||||
if self.include.is_empty() {
|
||||
// treat empty include list as `*`
|
||||
included_because = Some(&WildCardedDomain::WildCard);
|
||||
}
|
||||
for wc_domain in &self.include {
|
||||
if wc_domain.matches(domain) {
|
||||
match included_because {
|
||||
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||
_ => included_because = Some(wc_domain),
|
||||
}
|
||||
}
|
||||
}
|
||||
for wc_domain in &self.exclude {
|
||||
if wc_domain.matches(domain) {
|
||||
match excluded_because {
|
||||
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||
_ => excluded_because = Some(wc_domain),
|
||||
}
|
||||
}
|
||||
}
|
||||
match (included_because, excluded_because) {
|
||||
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for a more specific reason */
|
||||
// than excluded
|
||||
(Some(_), None) => Some(&self.url),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A domain name, that optionally allows a * as its first subdomain.
|
||||
#[derive(Clone, Debug)]
|
||||
enum WildCardedDomain {
|
||||
WildCard,
|
||||
WildCarded(String),
|
||||
Exact(String),
|
||||
WildCard,
|
||||
WildCarded(String),
|
||||
Exact(String),
|
||||
}
|
||||
impl WildCardedDomain {
|
||||
fn matches(&self, domain: &str) -> bool {
|
||||
match self {
|
||||
WildCardedDomain::WildCard => true,
|
||||
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
|
||||
WildCardedDomain::Exact(d) => domain == d,
|
||||
}
|
||||
}
|
||||
fn more_specific_than(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
||||
(_, WildCardedDomain::WildCard) => true,
|
||||
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
||||
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
|
||||
a != b && a.ends_with(b)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
fn matches(&self, domain: &str) -> bool {
|
||||
match self {
|
||||
WildCardedDomain::WildCard => true,
|
||||
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
|
||||
WildCardedDomain::Exact(d) => domain == d,
|
||||
}
|
||||
}
|
||||
|
||||
fn more_specific_than(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
||||
(_, WildCardedDomain::WildCard) => true,
|
||||
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
||||
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && a.ends_with(b),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::str::FromStr for WildCardedDomain {
|
||||
type Err = std::convert::Infallible;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// maybe do some domain validation?
|
||||
Ok(if s.starts_with("*.") {
|
||||
WildCardedDomain::WildCarded(s[1..].to_owned())
|
||||
} else if s == "*" {
|
||||
WildCardedDomain::WildCarded("".to_owned())
|
||||
} else {
|
||||
WildCardedDomain::Exact(s.to_owned())
|
||||
})
|
||||
}
|
||||
type Err = std::convert::Infallible;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// maybe do some domain validation?
|
||||
Ok(if s.starts_with("*.") {
|
||||
WildCardedDomain::WildCarded(s[1..].to_owned())
|
||||
} else if s == "*" {
|
||||
WildCardedDomain::WildCarded("".to_owned())
|
||||
} else {
|
||||
WildCardedDomain::Exact(s.to_owned())
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<'de> Deserialize<'de> for WildCardedDomain {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
crate::utils::deserialize_from_str(deserializer)
|
||||
}
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
crate::utils::deserialize_from_str(deserializer)
|
||||
}
|
||||
}
|
||||
|
||||
+29
-38
@@ -1,8 +1,8 @@
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
use super::Config;
|
||||
use crate::Result;
|
||||
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub mod sqlite;
|
||||
|
||||
@@ -13,53 +13,44 @@ pub(crate) mod rocksdb;
|
||||
pub(crate) mod watchers;
|
||||
|
||||
pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
|
||||
fn open(config: &Config) -> Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||
fn flush(&self) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
||||
}
|
||||
fn open(config: &Config) -> Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||
fn flush(&self) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()> { Ok(()) }
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
}
|
||||
|
||||
pub(crate) trait KvTree: Send + Sync {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>;
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>;
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()>;
|
||||
fn remove(&self, key: &[u8]) -> Result<()>;
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
|
||||
fn clear(&self) -> Result<()> {
|
||||
for (key, _) in self.iter() {
|
||||
self.remove(&key)?;
|
||||
}
|
||||
fn clear(&self) -> Result<()> {
|
||||
for (key, _) in self.iter() {
|
||||
self.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+202
-230
@@ -1,293 +1,265 @@
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, RwLock},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{utils, Result};
|
||||
|
||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
|
||||
pub(crate) struct Engine {
|
||||
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
||||
cache: rocksdb::Cache,
|
||||
old_cfs: Vec<String>,
|
||||
config: Config,
|
||||
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
||||
cache: rocksdb::Cache,
|
||||
old_cfs: Vec<String>,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
struct RocksDbEngineTree<'a> {
|
||||
db: Arc<Engine>,
|
||||
name: &'a str,
|
||||
watchers: Watchers,
|
||||
write_lock: RwLock<()>,
|
||||
db: Arc<Engine>,
|
||||
name: &'a str,
|
||||
watchers: Watchers,
|
||||
write_lock: RwLock<()>,
|
||||
}
|
||||
|
||||
fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Options {
|
||||
// block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html#
|
||||
let mut block_based_options = rocksdb::BlockBasedOptions::default();
|
||||
// block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html#
|
||||
let mut block_based_options = rocksdb::BlockBasedOptions::default();
|
||||
|
||||
block_based_options.set_block_cache(rocksdb_cache);
|
||||
block_based_options.set_block_cache(rocksdb_cache);
|
||||
|
||||
// "Difference of spinning disk"
|
||||
// https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html
|
||||
block_based_options.set_block_size(64 * 1024);
|
||||
block_based_options.set_cache_index_and_filter_blocks(true);
|
||||
// "Difference of spinning disk"
|
||||
// https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html
|
||||
block_based_options.set_block_size(64 * 1024);
|
||||
block_based_options.set_cache_index_and_filter_blocks(true);
|
||||
|
||||
// database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html#
|
||||
let mut db_opts = rocksdb::Options::default();
|
||||
// database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html#
|
||||
let mut db_opts = rocksdb::Options::default();
|
||||
|
||||
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
|
||||
"debug" => Debug,
|
||||
"info" => Info,
|
||||
"error" => Error,
|
||||
"fatal" => Fatal,
|
||||
_ => Warn,
|
||||
};
|
||||
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
|
||||
"debug" => Debug,
|
||||
"info" => Info,
|
||||
"error" => Error,
|
||||
"fatal" => Fatal,
|
||||
_ => Warn,
|
||||
};
|
||||
|
||||
let threads = if config.rocksdb_parallelism_threads == 0 {
|
||||
num_cpus::get_physical() // max cores if user specified 0
|
||||
} else {
|
||||
config.rocksdb_parallelism_threads
|
||||
};
|
||||
let threads = if config.rocksdb_parallelism_threads == 0 {
|
||||
num_cpus::get_physical() // max cores if user specified 0
|
||||
} else {
|
||||
config.rocksdb_parallelism_threads
|
||||
};
|
||||
|
||||
db_opts.set_log_level(rocksdb_log_level);
|
||||
db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
|
||||
db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
|
||||
db_opts.set_log_level(rocksdb_log_level);
|
||||
db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
|
||||
db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
|
||||
|
||||
if config.rocksdb_optimize_for_spinning_disks {
|
||||
db_opts.set_skip_stats_update_on_db_open(true);
|
||||
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
||||
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
||||
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for spinning hard drives. these are not really important
|
||||
} else {
|
||||
db_opts.set_skip_stats_update_on_db_open(false);
|
||||
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
||||
db_opts.set_use_direct_reads(true);
|
||||
db_opts.set_use_direct_io_for_flush_and_compaction(true);
|
||||
db_opts.set_keep_log_file_num(20);
|
||||
}
|
||||
if config.rocksdb_optimize_for_spinning_disks {
|
||||
db_opts.set_skip_stats_update_on_db_open(true);
|
||||
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
||||
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
||||
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for
|
||||
// spinning hard drives. these are not really
|
||||
// important
|
||||
} else {
|
||||
db_opts.set_skip_stats_update_on_db_open(false);
|
||||
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
||||
db_opts.set_use_direct_reads(true);
|
||||
db_opts.set_use_direct_io_for_flush_and_compaction(true);
|
||||
db_opts.set_keep_log_file_num(20);
|
||||
}
|
||||
|
||||
db_opts.set_block_based_table_factory(&block_based_options);
|
||||
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
||||
db_opts.create_if_missing(true);
|
||||
db_opts.increase_parallelism(
|
||||
threads
|
||||
.try_into()
|
||||
.expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||
);
|
||||
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
||||
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||
db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
|
||||
db_opts.optimize_level_style_compaction(10 * 1024 * 1024);
|
||||
db_opts.set_block_based_table_factory(&block_based_options);
|
||||
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
||||
db_opts.create_if_missing(true);
|
||||
db_opts.increase_parallelism(
|
||||
threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||
);
|
||||
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
||||
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||
db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
|
||||
db_opts.optimize_level_style_compaction(10 * 1024 * 1024);
|
||||
|
||||
// https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning
|
||||
db_opts.set_max_background_jobs(6);
|
||||
db_opts.set_bytes_per_sync(1_048_576);
|
||||
// https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning
|
||||
db_opts.set_max_background_jobs(6);
|
||||
db_opts.set_bytes_per_sync(1_048_576);
|
||||
|
||||
// https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
|
||||
//
|
||||
// Unclean shutdowns of a Matrix homeserver are likely to be fine when
|
||||
// recovered in this manner as it's likely any lost information will be
|
||||
// restored via federation.
|
||||
db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords);
|
||||
// https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
|
||||
//
|
||||
// Unclean shutdowns of a Matrix homeserver are likely to be fine when
|
||||
// recovered in this manner as it's likely any lost information will be
|
||||
// restored via federation.
|
||||
db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords);
|
||||
|
||||
let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1);
|
||||
db_opts.set_prefix_extractor(prefix_extractor);
|
||||
let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1);
|
||||
db_opts.set_prefix_extractor(prefix_extractor);
|
||||
|
||||
db_opts
|
||||
db_opts
|
||||
}
|
||||
|
||||
impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
|
||||
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes);
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
|
||||
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes);
|
||||
|
||||
let db_opts = db_options(&rocksdb_cache, config);
|
||||
let db_opts = db_options(&rocksdb_cache, config);
|
||||
|
||||
debug!("Listing column families in database");
|
||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
)
|
||||
.unwrap_or_default();
|
||||
debug!("Listing column families in database");
|
||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(&db_opts, &config.database_path)
|
||||
.unwrap_or_default();
|
||||
|
||||
debug!("Opening column family descriptors in database");
|
||||
info!("RocksDB database compaction will take place now, a delay in startup is expected");
|
||||
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
cfs.iter().map(|name| {
|
||||
rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))
|
||||
}),
|
||||
)?;
|
||||
debug!("Opening column family descriptors in database");
|
||||
info!("RocksDB database compaction will take place now, a delay in startup is expected");
|
||||
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
cfs.iter().map(|name| rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))),
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(Engine {
|
||||
rocks: db,
|
||||
cache: rocksdb_cache,
|
||||
old_cfs: cfs,
|
||||
config: config.clone(),
|
||||
}))
|
||||
}
|
||||
Ok(Arc::new(Engine {
|
||||
rocks: db,
|
||||
cache: rocksdb_cache,
|
||||
old_cfs: cfs,
|
||||
config: config.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
|
||||
if !self.old_cfs.contains(&name.to_owned()) {
|
||||
// Create if it didn't exist
|
||||
debug!("Creating new column family in database: {}", name);
|
||||
let _ = self
|
||||
.rocks
|
||||
.create_cf(name, &db_options(&self.cache, &self.config));
|
||||
}
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
|
||||
if !self.old_cfs.contains(&name.to_owned()) {
|
||||
// Create if it didn't exist
|
||||
debug!("Creating new column family in database: {}", name);
|
||||
let _ = self.rocks.create_cf(name, &db_options(&self.cache, &self.config));
|
||||
}
|
||||
|
||||
Ok(Arc::new(RocksDbEngineTree {
|
||||
name,
|
||||
db: Arc::clone(self),
|
||||
watchers: Watchers::default(),
|
||||
write_lock: RwLock::new(()),
|
||||
}))
|
||||
}
|
||||
Ok(Arc::new(RocksDbEngineTree {
|
||||
name,
|
||||
db: Arc::clone(self),
|
||||
watchers: Watchers::default(),
|
||||
write_lock: RwLock::new(()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<()> {
|
||||
// TODO?
|
||||
Ok(())
|
||||
}
|
||||
fn flush(&self) -> Result<()> {
|
||||
// TODO?
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
let stats =
|
||||
rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||
Ok(format!(
|
||||
"Approximate memory usage of all the mem-tables: {:.3} MB\n\
|
||||
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\
|
||||
Approximate memory usage of all the table readers: {:.3} MB\n\
|
||||
Approximate memory usage by cache: {:.3} MB\n\
|
||||
Approximate memory usage by cache pinned: {:.3} MB\n\
|
||||
",
|
||||
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
||||
stats.cache_total as f64 / 1024.0 / 1024.0,
|
||||
self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0,
|
||||
))
|
||||
}
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||
Ok(format!(
|
||||
"Approximate memory usage of all the mem-tables: {:.3} MB\nApproximate memory usage of un-flushed \
|
||||
mem-tables: {:.3} MB\nApproximate memory usage of all the table readers: {:.3} MB\nApproximate memory \
|
||||
usage by cache: {:.3} MB\nApproximate memory usage by cache pinned: {:.3} MB\n",
|
||||
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
||||
stats.cache_total as f64 / 1024.0 / 1024.0,
|
||||
self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0,
|
||||
))
|
||||
}
|
||||
|
||||
// TODO: figure out if this is needed for rocksdb
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
// TODO: figure out if this is needed for rocksdb
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
}
|
||||
|
||||
impl RocksDbEngineTree<'_> {
|
||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> {
|
||||
self.db.rocks.cf_handle(self.name).unwrap()
|
||||
}
|
||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> { self.db.rocks.cf_handle(self.name).unwrap() }
|
||||
}
|
||||
|
||||
impl KvTree for RocksDbEngineTree<'_> {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.db.rocks.get_cf(&self.cf(), key)?)
|
||||
}
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) }
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let lock = self.write_lock.read().unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
drop(lock);
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let lock = self.write_lock.read().unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
drop(lock);
|
||||
|
||||
self.watchers.wake(key);
|
||||
self.watchers.wake(key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
for (key, value) in iter {
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
}
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
for (key, value) in iter {
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
Ok(self.db.rocks.delete_cf(&self.cf(), key)?)
|
||||
}
|
||||
fn remove(&self, key: &[u8]) -> Result<()> { Ok(self.db.rocks.delete_cf(&self.cf(), key)?) }
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::Start)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::Start)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(
|
||||
from,
|
||||
if backwards {
|
||||
rocksdb::Direction::Reverse
|
||||
} else {
|
||||
rocksdb::Direction::Forward
|
||||
},
|
||||
),
|
||||
)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(
|
||||
from,
|
||||
if backwards {
|
||||
rocksdb::Direction::Reverse
|
||||
} else {
|
||||
rocksdb::Direction::Forward
|
||||
},
|
||||
),
|
||||
)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
|
||||
let old = self.db.rocks.get_cf(&self.cf(), key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, &new)?;
|
||||
let old = self.db.rocks.get_cf(&self.cf(), key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, &new)?;
|
||||
|
||||
drop(lock);
|
||||
Ok(new)
|
||||
}
|
||||
drop(lock);
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
|
||||
for key in iter {
|
||||
let old = self.db.rocks.get_cf(&self.cf(), &key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, new)?;
|
||||
}
|
||||
for key in iter {
|
||||
let old = self.db.rocks.get_cf(&self.cf(), &key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, new)?;
|
||||
}
|
||||
|
||||
drop(lock);
|
||||
drop(lock);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward),
|
||||
)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward))
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
}
|
||||
|
||||
+216
-251
@@ -1,340 +1,305 @@
|
||||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
future::Future,
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
future::Future,
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use thread_local::ThreadLocal;
|
||||
use tracing::debug;
|
||||
|
||||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
|
||||
thread_local! {
|
||||
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
struct PreparedStatementIterator<'a> {
|
||||
pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>,
|
||||
pub _statement_ref: NonAliasingBox<rusqlite::Statement<'a>>,
|
||||
pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>,
|
||||
pub _statement_ref: NonAliasingBox<rusqlite::Statement<'a>>,
|
||||
}
|
||||
|
||||
impl Iterator for PreparedStatementIterator<'_> {
|
||||
type Item = TupleOfBytes;
|
||||
type Item = TupleOfBytes;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.iterator.next()
|
||||
}
|
||||
fn next(&mut self) -> Option<Self::Item> { self.iterator.next() }
|
||||
}
|
||||
|
||||
struct NonAliasingBox<T>(*mut T);
|
||||
impl<T> Drop for NonAliasingBox<T> {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
let _ = Box::from_raw(self.0);
|
||||
};
|
||||
}
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
let _ = Box::from_raw(self.0);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Engine {
|
||||
writer: Mutex<Connection>,
|
||||
read_conn_tls: ThreadLocal<Connection>,
|
||||
read_iterator_conn_tls: ThreadLocal<Connection>,
|
||||
writer: Mutex<Connection>,
|
||||
read_conn_tls: ThreadLocal<Connection>,
|
||||
read_iterator_conn_tls: ThreadLocal<Connection>,
|
||||
|
||||
path: PathBuf,
|
||||
cache_size_per_thread: u32,
|
||||
path: PathBuf,
|
||||
cache_size_per_thread: u32,
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> {
|
||||
let conn = Connection::open(path)?;
|
||||
fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> {
|
||||
let conn = Connection::open(path)?;
|
||||
|
||||
conn.pragma_update(Some(Main), "page_size", 2048)?;
|
||||
conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
|
||||
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
|
||||
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?;
|
||||
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
|
||||
conn.pragma_update(Some(Main), "page_size", 2048)?;
|
||||
conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
|
||||
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
|
||||
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?;
|
||||
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
fn write_lock(&self) -> MutexGuard<'_, Connection> {
|
||||
self.writer.lock()
|
||||
}
|
||||
fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() }
|
||||
|
||||
fn read_lock(&self) -> &Connection {
|
||||
self.read_conn_tls
|
||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
fn read_lock(&self) -> &Connection {
|
||||
self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
|
||||
fn read_lock_iterator(&self) -> &Connection {
|
||||
self.read_iterator_conn_tls
|
||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
fn read_lock_iterator(&self) -> &Connection {
|
||||
self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
|
||||
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
||||
self.write_lock()
|
||||
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
||||
self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let path = Path::new(&config.database_path).join("conduit.db");
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let path = Path::new(&config.database_path).join("conduit.db");
|
||||
|
||||
// calculates cache-size per permanent connection
|
||||
// 1. convert MB to KiB
|
||||
// 2. divide by permanent connections + permanent iter connections + write connection
|
||||
// 3. round down to nearest integer
|
||||
let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0)
|
||||
/ ((num_cpus::get().max(1) * 2) + 1) as f64)
|
||||
as u32;
|
||||
// calculates cache-size per permanent connection
|
||||
// 1. convert MB to KiB
|
||||
// 2. divide by permanent connections + permanent iter connections + write
|
||||
// connection
|
||||
// 3. round down to nearest integer
|
||||
let cache_size_per_thread: u32 =
|
||||
((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32;
|
||||
|
||||
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
||||
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
||||
|
||||
let arc = Arc::new(Engine {
|
||||
writer,
|
||||
read_conn_tls: ThreadLocal::new(),
|
||||
read_iterator_conn_tls: ThreadLocal::new(),
|
||||
path,
|
||||
cache_size_per_thread,
|
||||
});
|
||||
let arc = Arc::new(Engine {
|
||||
writer,
|
||||
read_conn_tls: ThreadLocal::new(),
|
||||
read_iterator_conn_tls: ThreadLocal::new(),
|
||||
path,
|
||||
cache_size_per_thread,
|
||||
});
|
||||
|
||||
Ok(arc)
|
||||
}
|
||||
Ok(arc)
|
||||
}
|
||||
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?;
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||
self.write_lock().execute(
|
||||
&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"),
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(SqliteTable {
|
||||
engine: Arc::clone(self),
|
||||
name: name.to_owned(),
|
||||
watchers: Watchers::default(),
|
||||
}))
|
||||
}
|
||||
Ok(Arc::new(SqliteTable {
|
||||
engine: Arc::clone(self),
|
||||
name: name.to_owned(),
|
||||
watchers: Watchers::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<()> {
|
||||
// we enabled PRAGMA synchronous=normal, so this should not be necessary
|
||||
Ok(())
|
||||
}
|
||||
fn flush(&self) -> Result<()> {
|
||||
// we enabled PRAGMA synchronous=normal, so this should not be necessary
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self.flush_wal()
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { self.flush_wal() }
|
||||
}
|
||||
|
||||
pub struct SqliteTable {
|
||||
engine: Arc<Engine>,
|
||||
name: String,
|
||||
watchers: Watchers,
|
||||
engine: Arc<Engine>,
|
||||
name: String,
|
||||
watchers: Watchers,
|
||||
}
|
||||
|
||||
type TupleOfBytes = (Vec<u8>, Vec<u8>);
|
||||
|
||||
impl SqliteTable {
|
||||
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(guard
|
||||
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())?
|
||||
.query_row([key], |row| row.get(0))
|
||||
.optional()?)
|
||||
}
|
||||
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(guard
|
||||
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())?
|
||||
.query_row([key], |row| row.get(0))
|
||||
.optional()?)
|
||||
}
|
||||
|
||||
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
guard.execute(
|
||||
format!(
|
||||
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)",
|
||||
self.name
|
||||
)
|
||||
.as_str(),
|
||||
[key, value],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
guard.execute(
|
||||
format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(),
|
||||
[key, value],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn iter_with_guard<'a>(
|
||||
&'a self,
|
||||
guard: &'a Connection,
|
||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
||||
//let name = self.name.clone();
|
||||
//let name = self.name.clone();
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
let iterator = Box::new(
|
||||
statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(move |r| r.unwrap()),
|
||||
);
|
||||
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl KvTree for SqliteTable {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
self.get_with_guard(self.engine.read_lock(), key)
|
||||
}
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { self.get_with_guard(self.engine.read_lock(), key) }
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
self.insert_with_guard(&guard, key, value)?;
|
||||
drop(guard);
|
||||
self.watchers.wake(key);
|
||||
Ok(())
|
||||
}
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
self.insert_with_guard(&guard, key, value)?;
|
||||
drop(guard);
|
||||
self.watchers.wake(key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute("BEGIN", [])?;
|
||||
for (key, value) in iter {
|
||||
self.insert_with_guard(&guard, &key, &value)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
guard.execute("BEGIN", [])?;
|
||||
for (key, value) in iter {
|
||||
self.insert_with_guard(&guard, &key, &value)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
|
||||
drop(guard);
|
||||
drop(guard);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute("BEGIN", [])?;
|
||||
for key in iter {
|
||||
let old = self.get_with_guard(&guard, &key)?;
|
||||
let new = crate::utils::increment(old.as_deref())
|
||||
.expect("utils::increment always returns Some");
|
||||
self.insert_with_guard(&guard, &key, &new)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
guard.execute("BEGIN", [])?;
|
||||
for key in iter {
|
||||
let old = self.get_with_guard(&guard, &key)?;
|
||||
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
self.insert_with_guard(&guard, &key, &new)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
|
||||
drop(guard);
|
||||
drop(guard);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute(
|
||||
format!("DELETE FROM {} WHERE key = ?", self.name).as_str(),
|
||||
[key],
|
||||
)?;
|
||||
guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
|
||||
self.iter_with_guard(guard)
|
||||
}
|
||||
self.iter_with_guard(guard)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
let from = from.to_vec(); // TODO change interface?
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
let from = from.to_vec(); // TODO change interface?
|
||||
|
||||
//let name = self.name.clone();
|
||||
//let name = self.name.clone();
|
||||
|
||||
if backwards {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
if backwards {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
} else {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
} else {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
}
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
let old = self.get_with_guard(&guard, key)?;
|
||||
let old = self.get_with_guard(&guard, key)?;
|
||||
|
||||
let new =
|
||||
crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
|
||||
self.insert_with_guard(&guard, key, &new)?;
|
||||
self.insert_with_guard(&guard, key, &new)?;
|
||||
|
||||
Ok(new)
|
||||
}
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
Box::new(
|
||||
self.iter_from(&prefix, false)
|
||||
.take_while(move |(key, _)| key.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix)))
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
|
||||
fn clear(&self) -> Result<()> {
|
||||
debug!("clear: running");
|
||||
self.engine
|
||||
.write_lock()
|
||||
.execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||
debug!("clear: ran");
|
||||
Ok(())
|
||||
}
|
||||
fn clear(&self) -> Result<()> {
|
||||
debug!("clear: running");
|
||||
self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||
debug!("clear: ran");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,56 +1,55 @@
|
||||
use std::{
|
||||
collections::{hash_map, HashMap},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::RwLock,
|
||||
collections::{hash_map, HashMap},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::RwLock,
|
||||
};
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(super) struct Watchers {
|
||||
watchers: Watcher,
|
||||
watchers: Watcher,
|
||||
}
|
||||
|
||||
impl Watchers {
|
||||
pub(super) fn watch<'a>(
|
||||
&'a self,
|
||||
prefix: &[u8],
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
||||
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
v.insert((tx, rx.clone()));
|
||||
rx
|
||||
}
|
||||
};
|
||||
pub(super) fn watch<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
||||
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
v.insert((tx, rx.clone()));
|
||||
rx
|
||||
},
|
||||
};
|
||||
|
||||
Box::pin(async move {
|
||||
// Tx is never destroyed
|
||||
rx.changed().await.unwrap();
|
||||
})
|
||||
}
|
||||
pub(super) fn wake(&self, key: &[u8]) {
|
||||
let watchers = self.watchers.read().unwrap();
|
||||
let mut triggered = Vec::new();
|
||||
Box::pin(async move {
|
||||
// Tx is never destroyed
|
||||
rx.changed().await.unwrap();
|
||||
})
|
||||
}
|
||||
|
||||
for length in 0..=key.len() {
|
||||
if watchers.contains_key(&key[..length]) {
|
||||
triggered.push(&key[..length]);
|
||||
}
|
||||
}
|
||||
pub(super) fn wake(&self, key: &[u8]) {
|
||||
let watchers = self.watchers.read().unwrap();
|
||||
let mut triggered = Vec::new();
|
||||
|
||||
drop(watchers);
|
||||
for length in 0..=key.len() {
|
||||
if watchers.contains_key(&key[..length]) {
|
||||
triggered.push(&key[..length]);
|
||||
}
|
||||
}
|
||||
|
||||
if !triggered.is_empty() {
|
||||
let mut watchers = self.watchers.write().unwrap();
|
||||
for prefix in triggered {
|
||||
if let Some(tx) = watchers.remove(prefix) {
|
||||
let _ = tx.0.send(());
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
drop(watchers);
|
||||
|
||||
if !triggered.is_empty() {
|
||||
let mut watchers = self.watchers.write().unwrap();
|
||||
for prefix in triggered {
|
||||
if let Some(tx) = watchers.remove(prefix) {
|
||||
let _ = tx.0.send(());
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,148 +1,120 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
api::client::error::ErrorKind,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::account_data::Data for KeyValueDatabase {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
fn update(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xff);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xFF);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let mut key = prefix;
|
||||
key.extend_from_slice(event_type.to_string().as_bytes());
|
||||
let mut key = prefix;
|
||||
key.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
if data.get("type").is_none() || data.get("content").is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Account data doesn't have all required fields.",
|
||||
));
|
||||
}
|
||||
if data.get("type").is_none() || data.get("content").is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Account data doesn't have all required fields.",
|
||||
));
|
||||
}
|
||||
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&roomuserdataid,
|
||||
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&roomuserdataid,
|
||||
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.insert(&key, &roomuserdataid)?;
|
||||
self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
|
||||
|
||||
// Remove old entry
|
||||
if let Some(prev) = prev {
|
||||
self.roomuserdataid_accountdata.remove(&prev)?;
|
||||
}
|
||||
// Remove old entry
|
||||
if let Some(prev) = prev {
|
||||
self.roomuserdataid_accountdata.remove(&prev)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
let mut key = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(kind.to_string().as_bytes());
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
fn get(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
let mut key = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(kind.to_string().as_bytes());
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.get(&key)?
|
||||
.and_then(|roomuserdataid| {
|
||||
self.roomuserdataid_accountdata
|
||||
.get(&roomuserdataid)
|
||||
.transpose()
|
||||
})
|
||||
.transpose()?
|
||||
.map(|data| {
|
||||
serde_json::from_slice(&data)
|
||||
.map_err(|_| Error::bad_database("could not deserialize"))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.roomusertype_roomuserdataid
|
||||
.get(&key)?
|
||||
.and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose())
|
||||
.transpose()?
|
||||
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
let mut userdata = HashMap::new();
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
fn changes_since(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
let mut userdata = HashMap::new();
|
||||
|
||||
let mut prefix = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
// Skip the data that's exactly at since, because we sent that last time
|
||||
let mut first_possible = prefix.clone();
|
||||
first_possible.extend_from_slice(&(since + 1).to_be_bytes());
|
||||
// Skip the data that's exactly at since, because we sent that last time
|
||||
let mut first_possible = prefix.clone();
|
||||
first_possible.extend_from_slice(&(since + 1).to_be_bytes());
|
||||
|
||||
for r in self
|
||||
.roomuserdataid_accountdata
|
||||
.iter_from(&first_possible, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
RoomAccountDataEventType::from(
|
||||
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|
||||
|| Error::bad_database("RoomUserData ID in db is invalid."),
|
||||
)?)
|
||||
.map_err(|e| {
|
||||
warn!("RoomUserData ID in database is invalid: {}", e);
|
||||
Error::bad_database("RoomUserData ID in db is invalid.")
|
||||
})?,
|
||||
),
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| {
|
||||
Error::bad_database("Database contains invalid account data.")
|
||||
})?,
|
||||
))
|
||||
})
|
||||
{
|
||||
let (kind, data) = r?;
|
||||
userdata.insert(kind, data);
|
||||
}
|
||||
for r in self
|
||||
.roomuserdataid_accountdata
|
||||
.iter_from(&first_possible, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
RoomAccountDataEventType::from(
|
||||
utils::string_from_bytes(
|
||||
k.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("RoomUserData ID in database is invalid: {}", e);
|
||||
Error::bad_database("RoomUserData ID in db is invalid.")
|
||||
})?,
|
||||
),
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
|
||||
))
|
||||
}) {
|
||||
let (kind, data) = r?;
|
||||
userdata.insert(kind, data);
|
||||
}
|
||||
|
||||
Ok(userdata)
|
||||
}
|
||||
Ok(userdata)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,78 +3,58 @@ use ruma::api::appservice::Registration;
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::appservice::Data for KeyValueDatabase {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(
|
||||
id.as_bytes(),
|
||||
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
|
||||
)?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(id.to_owned(), yaml.clone());
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
|
||||
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations
|
||||
.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(id)
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid registration bytes in id_appserviceregistrations.",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
}
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations.read().unwrap().get(id).map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(
|
||||
|(id, _)| {
|
||||
utils::string_from_bytes(&id).map_err(|_| {
|
||||
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
|
||||
})
|
||||
},
|
||||
)))
|
||||
}
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||
utils::string_from_bytes(&id)
|
||||
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
|
||||
})))
|
||||
}
|
||||
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
self.iter_ids()?
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?
|
||||
.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
self.iter_ids()?
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
+205
-244
@@ -4,9 +4,9 @@ use async_trait::async_trait;
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use lru_cache::LruCache;
|
||||
use ruma::{
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
@@ -16,139 +16,118 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
|
||||
|
||||
#[async_trait]
|
||||
impl service::globals::Data for KeyValueDatabase {
|
||||
fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
}
|
||||
fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
}
|
||||
|
||||
fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
self.global
|
||||
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
|
||||
.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("last check for updates count has invalid bytes.")
|
||||
})
|
||||
})
|
||||
}
|
||||
fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||
self.global
|
||||
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||
self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let userid_bytes = user_id.as_bytes().to_vec();
|
||||
let mut userid_prefix = userid_bytes.clone();
|
||||
userid_prefix.push(0xff);
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let userid_bytes = user_id.as_bytes().to_vec();
|
||||
let mut userid_prefix = userid_bytes.clone();
|
||||
userid_prefix.push(0xFF);
|
||||
|
||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||
userdeviceid_prefix.push(0xff);
|
||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||
userdeviceid_prefix.push(0xFF);
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
// Return when *any* user changed his key
|
||||
// TODO: only send for user they share a room with
|
||||
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
|
||||
// Return when *any* user changed his key
|
||||
// TODO: only send for user they share a room with
|
||||
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
|
||||
|
||||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||
futures.push(
|
||||
self.userroomid_notificationcount
|
||||
.watch_prefix(&userid_prefix),
|
||||
);
|
||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||
|
||||
// Events for rooms we are in
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
let short_roomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(&room_id)
|
||||
.ok()
|
||||
.flatten()
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
// Events for rooms we are in
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
|
||||
let short_roomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(&room_id)
|
||||
.ok()
|
||||
.flatten()
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||
let mut roomid_prefix = roomid_bytes.clone();
|
||||
roomid_prefix.push(0xff);
|
||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||
let mut roomid_prefix = roomid_bytes.clone();
|
||||
roomid_prefix.push(0xFF);
|
||||
|
||||
// PDUs
|
||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||
// PDUs
|
||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||
|
||||
// EDUs
|
||||
futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes));
|
||||
// EDUs
|
||||
futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes));
|
||||
|
||||
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
|
||||
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
|
||||
|
||||
// Key changes
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
|
||||
// Key changes
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
|
||||
|
||||
// Room account data
|
||||
let mut roomuser_prefix = roomid_prefix.clone();
|
||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||
// Room account data
|
||||
let mut roomuser_prefix = roomid_prefix.clone();
|
||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&roomuser_prefix),
|
||||
);
|
||||
}
|
||||
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix));
|
||||
}
|
||||
|
||||
let mut globaluserdata_prefix = vec![0xff];
|
||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||
let mut globaluserdata_prefix = vec![0xFF];
|
||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&globaluserdata_prefix),
|
||||
);
|
||||
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix));
|
||||
|
||||
// More key changes (used when user is not joined to any rooms)
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||
// More key changes (used when user is not joined to any rooms)
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||
|
||||
// One time keys
|
||||
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
|
||||
// One time keys
|
||||
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
|
||||
|
||||
futures.push(Box::pin(services().globals.rotate.watch()));
|
||||
futures.push(Box::pin(services().globals.rotate.watch()));
|
||||
|
||||
// Wait until one of them finds something
|
||||
futures.next().await;
|
||||
// Wait until one of them finds something
|
||||
futures.next().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self.db.cleanup()
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
|
||||
|
||||
fn memory_usage(&self) -> String {
|
||||
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
||||
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
|
||||
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
|
||||
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
|
||||
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len();
|
||||
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
|
||||
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
|
||||
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
|
||||
fn memory_usage(&self) -> String {
|
||||
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
||||
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
|
||||
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
|
||||
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
|
||||
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len();
|
||||
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
|
||||
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
|
||||
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
|
||||
|
||||
let mut response = format!(
|
||||
"\
|
||||
let mut response = format!(
|
||||
"\
|
||||
pdu_cache: {pdu_cache}
|
||||
shorteventid_cache: {shorteventid_cache}
|
||||
auth_chain_cache: {auth_chain_cache}
|
||||
@@ -157,155 +136,137 @@ statekeyshort_cache: {statekeyshort_cache}
|
||||
our_real_users_cache: {our_real_users_cache}
|
||||
appservice_in_room_cache: {appservice_in_room_cache}
|
||||
lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
||||
);
|
||||
if let Ok(db_stats) = self.db.memory_usage() {
|
||||
response += &db_stats;
|
||||
}
|
||||
);
|
||||
if let Ok(db_stats) = self.db.memory_usage() {
|
||||
response += &db_stats;
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
fn clear_caches(&self, amount: u32) {
|
||||
if amount > 0 {
|
||||
let c = &mut *self.pdu_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 1 {
|
||||
let c = &mut *self.shorteventid_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 2 {
|
||||
let c = &mut *self.auth_chain_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 3 {
|
||||
let c = &mut *self.eventidshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 4 {
|
||||
let c = &mut *self.statekeyshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 5 {
|
||||
let c = &mut *self.our_real_users_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 6 {
|
||||
let c = &mut *self.appservice_in_room_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 7 {
|
||||
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
}
|
||||
fn clear_caches(&self, amount: u32) {
|
||||
if amount > 0 {
|
||||
let c = &mut *self.pdu_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 1 {
|
||||
let c = &mut *self.shorteventid_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 2 {
|
||||
let c = &mut *self.auth_chain_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 3 {
|
||||
let c = &mut *self.eventidshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 4 {
|
||||
let c = &mut *self.statekeyshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 5 {
|
||||
let c = &mut *self.our_real_users_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 6 {
|
||||
let c = &mut *self.appservice_in_room_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 7 {
|
||||
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
}
|
||||
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
|
||||
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
self.global.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
|
||||
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
self.global.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
|
||||
|
||||
utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts
|
||||
.next()
|
||||
.expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
})
|
||||
}
|
||||
fn remove_keypair(&self) -> Result<()> {
|
||||
self.global.remove(b"keypair")
|
||||
}
|
||||
utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts.next().expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
fn add_signing_key(
|
||||
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
||||
keys.verify_keys.extend(verify_keys);
|
||||
keys.old_verify_keys.extend(old_verify_keys);
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
keys.verify_keys.extend(verify_keys);
|
||||
keys.old_verify_keys.extend(old_verify_keys);
|
||||
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
Ok(signingkeys)
|
||||
}
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||
/// for the server.
|
||||
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
|
||||
fn database_version(&self) -> Result<u64> {
|
||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
Ok(signingkeys)
|
||||
}
|
||||
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.global.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn database_version(&self) -> Result<u64> {
|
||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.global.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,364 +1,292 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::{
|
||||
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
error::ErrorKind,
|
||||
},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
api::client::{
|
||||
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
error::ErrorKind,
|
||||
},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::key_backups::Data for KeyValueDatabase {
|
||||
fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
let version = services().globals.next_count()?.to_string();
|
||||
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
let version = services().globals.next_count()?.to_string();
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version)
|
||||
}
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.remove(&key)?;
|
||||
self.backupid_etag.remove(&key)?;
|
||||
self.backupid_algorithm.remove(&key)?;
|
||||
self.backupid_etag.remove(&key)?;
|
||||
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
}
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||
}
|
||||
|
||||
self.backupid_algorithm
|
||||
.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, _)| {
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, _)| {
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||
|
||||
Ok((
|
||||
version,
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
|
||||
})?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
Ok((
|
||||
version,
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||
})
|
||||
}
|
||||
self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn add_key(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
}
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||
}
|
||||
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.insert(&key, key_data.json().get().as_bytes())?;
|
||||
self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self
|
||||
.backupid_etag
|
||||
.get(&key)?
|
||||
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||
.to_string())
|
||||
}
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||
|
||||
for result in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
let session_id = utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
|
||||
})?;
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
})
|
||||
{
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
.or_insert_with(|| RoomKeyBackup {
|
||||
sessions: BTreeMap::new(),
|
||||
})
|
||||
.sessions
|
||||
.insert(session_id, key_data);
|
||||
}
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
}) {
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
.or_insert_with(|| RoomKeyBackup {
|
||||
sessions: BTreeMap::new(),
|
||||
})
|
||||
.sessions
|
||||
.insert(session_id, key_data);
|
||||
}
|
||||
|
||||
Ok(rooms)
|
||||
}
|
||||
Ok(rooms)
|
||||
}
|
||||
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn get_room(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
let session_id = utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||
|
||||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect())
|
||||
}
|
||||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
fn get_session(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.get(&key)?
|
||||
.map(|value| {
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.backupkeyid_backup
|
||||
.get(&key)?
|
||||
.map(|value| {
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
+145
-208
@@ -2,245 +2,182 @@ use ruma::api::client::error::ErrorKind;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, media::UrlPreviewData},
|
||||
utils, Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, media::UrlPreviewData},
|
||||
utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::media::Data for KeyValueDatabase {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
fn create_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
debug!("MXC db prefix: {:?}", prefix);
|
||||
debug!("MXC db prefix: {:?}", prefix);
|
||||
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
debug!("Deleting key: {:?}", key);
|
||||
self.mediaid_file.remove(&key)?;
|
||||
}
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
debug!("Deleting key: {:?}", key);
|
||||
self.mediaid_file.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Searches for all files with the given MXC
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
/// Searches for all files with the given MXC
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
keys.push(key);
|
||||
}
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
keys.push(key);
|
||||
}
|
||||
|
||||
if keys.is_empty() {
|
||||
return Err(Error::bad_database(
|
||||
"Failed to find any keys in database with the provided MXC.",
|
||||
));
|
||||
}
|
||||
if keys.is_empty() {
|
||||
return Err(Error::bad_database(
|
||||
"Failed to find any keys in database with the provided MXC.",
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Got the following keys: {:?}", keys);
|
||||
debug!("Got the following keys: {:?}", keys);
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(&width.to_be_bytes());
|
||||
prefix.extend_from_slice(&height.to_be_bytes());
|
||||
prefix.push(0xff);
|
||||
fn search_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(&width.to_be_bytes());
|
||||
prefix.extend_from_slice(&height.to_be_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
let (key, _) = self
|
||||
.mediaid_file
|
||||
.scan_prefix(prefix)
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||
let (key, _) = self
|
||||
.mediaid_file
|
||||
.scan_prefix(prefix)
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
let content_disposition_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
};
|
||||
Ok((content_disposition, content_type, key))
|
||||
}
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes)
|
||||
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
|
||||
)
|
||||
};
|
||||
Ok((content_disposition, content_type, key))
|
||||
}
|
||||
|
||||
/// Gets all the media keys in our database (this includes all the metadata associated with it such as width, height, content-type, etc)
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
/// Gets all the media keys in our database (this includes all the metadata
|
||||
/// associated with it such as width, height, content-type, etc)
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
|
||||
for (key, _) in self.mediaid_file.iter() {
|
||||
keys.push(key);
|
||||
}
|
||||
for (key, _) in self.mediaid_file.iter() {
|
||||
keys.push(key);
|
||||
}
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()> {
|
||||
self.url_previews.remove(url.as_bytes())
|
||||
}
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
|
||||
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &UrlPreviewData,
|
||||
timestamp: std::time::Duration,
|
||||
) -> Result<()> {
|
||||
let mut value = Vec::<u8>::new();
|
||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.title
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.description
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.image
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
||||
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
|
||||
let mut value = Vec::<u8>::new();
|
||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.title.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.description.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.image.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
||||
|
||||
self.url_previews.insert(url.as_bytes(), &value)
|
||||
}
|
||||
self.url_previews.insert(url.as_bytes(), &value)
|
||||
}
|
||||
|
||||
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
||||
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
||||
|
||||
let mut values = values.split(|&b| b == 0xff);
|
||||
let mut values = values.split(|&b| b == 0xFF);
|
||||
|
||||
let _ts = match values
|
||||
.next()
|
||||
.map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let title = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let description = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image_size = match values
|
||||
.next()
|
||||
.map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_width = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_height = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let _ts = match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
|
||||
Some(UrlPreviewData {
|
||||
title,
|
||||
description,
|
||||
image,
|
||||
image_size,
|
||||
image_width,
|
||||
image_height,
|
||||
})
|
||||
}
|
||||
Some(UrlPreviewData {
|
||||
title,
|
||||
description,
|
||||
image,
|
||||
image_size,
|
||||
image_width,
|
||||
image_height,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,79 +1,63 @@
|
||||
use ruma::{
|
||||
api::client::push::{set_pusher, Pusher},
|
||||
UserId,
|
||||
api::client::push::{set_pusher, Pusher},
|
||||
UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::pusher::Data for KeyValueDatabase {
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
|
||||
match &pusher {
|
||||
set_pusher::v3::PusherAction::Post(data) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
|
||||
match &pusher {
|
||||
set_pusher::v3::PusherAction::Post(data) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher
|
||||
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
|
||||
Ok(())
|
||||
},
|
||||
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher.remove(&key).map(|_| ()).map_err(Into::into)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
let mut senderkey = sender.as_bytes().to_vec();
|
||||
senderkey.push(0xff);
|
||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
let mut senderkey = sender.as_bytes().to_vec();
|
||||
senderkey.push(0xFF);
|
||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||
|
||||
self.senderkey_pusher
|
||||
.get(&senderkey)?
|
||||
.map(|push| {
|
||||
serde_json::from_slice(&push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.senderkey_pusher
|
||||
.get(&senderkey)?
|
||||
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| {
|
||||
serde_json::from_slice(&push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_pushkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||
let mut parts = k.splitn(2, |&b| b == 0xff);
|
||||
let _senderkey = parts.next();
|
||||
let push_key = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key_string = utils::string_from_bytes(push_key)
|
||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||
let mut parts = k.splitn(2, |&b| b == 0xFF);
|
||||
let _senderkey = parts.next();
|
||||
let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key_string = utils::string_from_bytes(push_key)
|
||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||
|
||||
Ok(push_key_string)
|
||||
}))
|
||||
}
|
||||
Ok(push_key_string)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,82 +3,68 @@ use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAli
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::alias::Data for KeyValueDatabase {
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.alias_roomid
|
||||
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
let mut aliasid = room_id.as_bytes().to_vec();
|
||||
aliasid.push(0xff);
|
||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
let mut aliasid = room_id.as_bytes().to_vec();
|
||||
aliasid.push(0xFF);
|
||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||
let mut prefix = room_id;
|
||||
prefix.push(0xff);
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||
let mut prefix = room_id;
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||
self.aliasid_alias.remove(&key)?;
|
||||
}
|
||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Alias does not exist.",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||
self.aliasid_alias.remove(&key)?;
|
||||
}
|
||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
self.alias_roomid
|
||||
.get(alias.alias().as_bytes())?
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
self.alias_roomid
|
||||
.get(alias.alias().as_bytes())?
|
||||
.map(|bytes| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self, room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
|
||||
}))
|
||||
}
|
||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
|
||||
}))
|
||||
}
|
||||
|
||||
fn all_local_aliases<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
Box::new(
|
||||
self.alias_roomid
|
||||
.iter()
|
||||
.map(|(room_alias_bytes, room_id_bytes)| {
|
||||
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid alias bytes in aliasid_alias.")
|
||||
})?;
|
||||
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| {
|
||||
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
|
||||
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room_id bytes in aliasid_alias.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
||||
|
||||
Ok((room_id, room_alias_localpart))
|
||||
}),
|
||||
)
|
||||
}
|
||||
Ok((room_id, room_alias_localpart))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,59 +3,47 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
|
||||
use crate::{database::KeyValueDatabase, service, utils, Result};
|
||||
|
||||
impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
||||
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
// Check RAM cache
|
||||
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
|
||||
return Ok(Some(Arc::clone(result)));
|
||||
}
|
||||
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
// Check RAM cache
|
||||
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
|
||||
return Ok(Some(Arc::clone(result)));
|
||||
}
|
||||
|
||||
// We only save auth chains for single events in the db
|
||||
if key.len() == 1 {
|
||||
// Check DB cache
|
||||
let chain = self
|
||||
.shorteventid_authchain
|
||||
.get(&key[0].to_be_bytes())?
|
||||
.map(|chain| {
|
||||
chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
||||
.collect()
|
||||
});
|
||||
// We only save auth chains for single events in the db
|
||||
if key.len() == 1 {
|
||||
// Check DB cache
|
||||
let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| {
|
||||
chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
||||
.collect()
|
||||
});
|
||||
|
||||
if let Some(chain) = chain {
|
||||
let chain = Arc::new(chain);
|
||||
if let Some(chain) = chain {
|
||||
let chain = Arc::new(chain);
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(vec![key[0]], Arc::clone(&chain));
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain));
|
||||
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
}
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
// Only persist single events in db
|
||||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain
|
||||
.iter()
|
||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
||||
.collect::<Vec<u8>>(),
|
||||
)?;
|
||||
}
|
||||
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
// Only persist single events in db
|
||||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::<Vec<u8>>(),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(key, auth_chain);
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,26 +3,21 @@ use ruma::{OwnedRoomId, RoomId};
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::directory::Data for KeyValueDatabase {
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.insert(room_id.as_bytes(), &[])
|
||||
}
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
|
||||
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.remove(room_id.as_bytes())
|
||||
}
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
|
||||
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||
}))
|
||||
}
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,178 +1,155 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use ruma::{
|
||||
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
|
||||
};
|
||||
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::edus::presence::Presence},
|
||||
services,
|
||||
utils::{self, user_id_from_bytes},
|
||||
Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::edus::presence::Presence},
|
||||
services,
|
||||
utils::{self, user_id_from_bytes},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
|
||||
let key = presence_key(room_id, user_id);
|
||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
|
||||
let key = presence_key(room_id, user_id);
|
||||
|
||||
self.roomuserid_presence
|
||||
.get(&key)?
|
||||
.map(|presence_bytes| -> Result<PresenceEvent> {
|
||||
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.roomuserid_presence
|
||||
.get(&key)?
|
||||
.map(|presence_bytes| -> Result<PresenceEvent> {
|
||||
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let mut state_changed = false;
|
||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let mut state_changed = false;
|
||||
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
|
||||
if let Some(presence_bytes) = presence_bytes {
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
if presence.state != new_state {
|
||||
state_changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(presence_bytes) = presence_bytes {
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
if presence.state != new_state {
|
||||
state_changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = if state_changed {
|
||||
services().globals.next_count()?
|
||||
} else {
|
||||
services().globals.current_count()?
|
||||
};
|
||||
let count = if state_changed {
|
||||
services().globals.next_count()?
|
||||
} else {
|
||||
services().globals.current_count()?
|
||||
};
|
||||
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
|
||||
let new_presence = match presence_bytes {
|
||||
Some(presence_bytes) => {
|
||||
let mut presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
presence.state = new_state.clone();
|
||||
presence.currently_active = presence.state == PresenceState::Online;
|
||||
presence.last_active_ts = now;
|
||||
presence.last_count = count;
|
||||
let new_presence = match presence_bytes {
|
||||
Some(presence_bytes) => {
|
||||
let mut presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
presence.state = new_state.clone();
|
||||
presence.currently_active = presence.state == PresenceState::Online;
|
||||
presence.last_active_ts = now;
|
||||
presence.last_count = count;
|
||||
|
||||
presence
|
||||
}
|
||||
None => Presence::new(
|
||||
new_state.clone(),
|
||||
new_state == PresenceState::Online,
|
||||
now,
|
||||
count,
|
||||
None,
|
||||
),
|
||||
};
|
||||
presence
|
||||
},
|
||||
None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None),
|
||||
};
|
||||
|
||||
self.roomuserid_presence
|
||||
.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||
}
|
||||
self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||
}
|
||||
|
||||
let timeout = match new_state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
let timeout = match new_state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.presence_timer_sender
|
||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
||||
.map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})
|
||||
}
|
||||
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})
|
||||
}
|
||||
|
||||
fn set_presence(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
presence_state: PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ts = match last_active_ago {
|
||||
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
|
||||
None => now,
|
||||
};
|
||||
fn set_presence(
|
||||
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ts = match last_active_ago {
|
||||
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
|
||||
None => now,
|
||||
};
|
||||
|
||||
let key = presence_key(room_id, user_id);
|
||||
let key = presence_key(room_id, user_id);
|
||||
|
||||
let presence = Presence::new(
|
||||
presence_state,
|
||||
currently_active.unwrap_or(false),
|
||||
last_active_ts,
|
||||
services().globals.next_count()?,
|
||||
status_msg,
|
||||
);
|
||||
let presence = Presence::new(
|
||||
presence_state,
|
||||
currently_active.unwrap_or(false),
|
||||
last_active_ts,
|
||||
services().globals.next_count()?,
|
||||
status_msg,
|
||||
);
|
||||
|
||||
let timeout = match presence.state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
let timeout = match presence.state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.presence_timer_sender
|
||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
||||
.map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})?;
|
||||
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})?;
|
||||
|
||||
self.roomuserid_presence
|
||||
.insert(&key, &presence.to_json_bytes()?)?;
|
||||
self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
|
||||
self.roomuserid_presence.remove(&key)?;
|
||||
}
|
||||
self.roomuserid_presence.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn presence_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
||||
let prefix = [room_id.as_bytes(), &[0xff]].concat();
|
||||
fn presence_since<'a>(
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
||||
let prefix = [room_id.as_bytes(), &[0xFF]].concat();
|
||||
|
||||
Box::new(
|
||||
self.roomuserid_presence
|
||||
.scan_prefix(prefix)
|
||||
.flat_map(
|
||||
|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||
let user_id = user_id_from_bytes(
|
||||
key.rsplit(|byte| *byte == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("No UserID bytes in presence key")
|
||||
})?,
|
||||
)?;
|
||||
Box::new(
|
||||
self.roomuserid_presence
|
||||
.scan_prefix(prefix)
|
||||
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||
let user_id = user_id_from_bytes(
|
||||
key.rsplit(|byte| *byte == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("No UserID bytes in presence key"))?,
|
||||
)?;
|
||||
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
let presence_event = presence.to_presence_event(&user_id)?;
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
let presence_event = presence.to_presence_event(&user_id)?;
|
||||
|
||||
Ok((user_id, presence.last_count, presence_event))
|
||||
},
|
||||
)
|
||||
.filter(move |(_, count, _)| *count > since),
|
||||
)
|
||||
}
|
||||
Ok((user_id, presence.last_count, presence_event))
|
||||
})
|
||||
.filter(move |(_, count, _)| *count > since),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec<u8> {
|
||||
[room_id.as_bytes(), &[0xff], user_id.as_bytes()].concat()
|
||||
[room_id.as_bytes(), &[0xFF], user_id.as_bytes()].concat()
|
||||
}
|
||||
|
||||
@@ -1,150 +1,113 @@
|
||||
use std::mem;
|
||||
|
||||
use ruma::{
|
||||
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||
fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: ReceiptEvent,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
== user_id.as_bytes()
|
||||
})
|
||||
{
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes()
|
||||
}) {
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xff);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xFF);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<
|
||||
dyn Iterator<
|
||||
Item = Result<(
|
||||
OwnedUserId,
|
||||
u64,
|
||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
||||
)>,
|
||||
> + 'a,
|
||||
> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let prefix2 = prefix.clone();
|
||||
fn readreceipts_since<'a>(
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
let prefix2 = prefix.clone();
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
|
||||
Box::new(
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count = utils::u64_from_bytes(
|
||||
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
Box::new(
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
|
||||
let mut json =
|
||||
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Read receipt in roomlatestid_roomlatest is invalid json.",
|
||||
)
|
||||
})?;
|
||||
json.remove("room_id");
|
||||
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
|
||||
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
|
||||
json.remove("room_id");
|
||||
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&json)
|
||||
.expect("json is valid raw value"),
|
||||
),
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.insert(&key, &count.to_be_bytes())?;
|
||||
self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
}
|
||||
self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
}
|
||||
|
||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
|
||||
Ok(Some(
|
||||
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,123 +5,111 @@ use ruma::{OwnedUserId, RoomId, UserId};
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xff);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xFF);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
|
||||
self.typingid_userid
|
||||
.insert(&room_typing_id, user_id.as_bytes())?;
|
||||
self.typingid_userid.insert(&room_typing_id, user_id.as_bytes())?;
|
||||
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &count)?;
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &count)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
|
||||
let mut found_outdated = false;
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self.typingid_userid.scan_prefix(prefix).filter(|(_, v)| &**v == user_id.as_bytes()) {
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
let mut found_outdated = false;
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| {
|
||||
Error::bad_database("RoomTyping has invalid timestamp or delimiters.")
|
||||
})?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xFF)
|
||||
.nth(1)
|
||||
.ok_or_else(|| Error::bad_database("RoomTyping has invalid timestamp or delimiters."))?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut user_ids = HashSet::new();
|
||||
let mut user_ids = HashSet::new();
|
||||
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&user_id)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
|
||||
Ok(user_ids)
|
||||
}
|
||||
Ok(user_ids)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,63 +3,51 @@ use ruma::{DeviceId, RoomId, UserId};
|
||||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
||||
fn lazy_load_was_sent_before(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
ll_user: &UserId,
|
||||
) -> Result<bool> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(ll_user.as_bytes());
|
||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||
}
|
||||
fn lazy_load_was_sent_before(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
|
||||
) -> Result<bool> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(ll_user.as_bytes());
|
||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn lazy_load_confirm_delivery(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn lazy_load_confirm_delivery(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
|
||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
for ll_id in confirmed_user_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(ll_id.as_bytes());
|
||||
self.lazyloadedids.insert(&key, &[])?;
|
||||
}
|
||||
for ll_id in confirmed_user_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(ll_id.as_bytes());
|
||||
self.lazyloadedids.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn lazy_load_reset(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||
self.lazyloadedids.remove(&key)?;
|
||||
}
|
||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||
self.lazyloadedids.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,76 +4,68 @@ use tracing::error;
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::metadata::Data for KeyValueDatabase {
|
||||
fn exists(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
|
||||
Some(b) => b.to_be_bytes().to_vec(),
|
||||
None => return Ok(false),
|
||||
};
|
||||
fn exists(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
|
||||
Some(b) => b.to_be_bytes().to_vec(),
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// Look for PDUs in that room.
|
||||
Ok(self
|
||||
.pduid_pdu
|
||||
.iter_from(&prefix, false)
|
||||
.next()
|
||||
.filter(|(k, _)| k.starts_with(&prefix))
|
||||
.is_some())
|
||||
}
|
||||
// Look for PDUs in that room.
|
||||
Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some())
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||
}))
|
||||
}
|
||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
|
||||
if disabled {
|
||||
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.disabledroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
|
||||
if disabled {
|
||||
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.disabledroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
|
||||
|
||||
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
||||
if banned {
|
||||
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.bannedroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
||||
if banned {
|
||||
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.bannedroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.bannedroomids.iter().map(
|
||||
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id bytes in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids")
|
||||
})?;
|
||||
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.bannedroomids.iter().map(
|
||||
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id bytes in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids")
|
||||
})?;
|
||||
|
||||
Ok(room_id)
|
||||
},
|
||||
))
|
||||
}
|
||||
Ok(room_id)
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,26 +3,22 @@ use ruma::{CanonicalJsonObject, EventId};
|
||||
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
|
||||
|
||||
impl service::rooms::outlier::Data for KeyValueDatabase {
|
||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
|
||||
self.eventid_outlierpdu.insert(
|
||||
event_id.as_bytes(),
|
||||
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
|
||||
)
|
||||
}
|
||||
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
|
||||
self.eventid_outlierpdu.insert(
|
||||
event_id.as_bytes(),
|
||||
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,85 +3,78 @@ use std::{mem, sync::Arc};
|
||||
use ruma::{EventId, RoomId, UserId};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::timeline::PduCount},
|
||||
services, utils, Error, PduEvent, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::timeline::PduCount},
|
||||
services, utils, Error, PduEvent, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
|
||||
let mut key = to.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&from.to_be_bytes());
|
||||
self.tofrom_relation.insert(&key, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
|
||||
let mut key = to.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&from.to_be_bytes());
|
||||
self.tofrom_relation.insert(&key, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn relations_until<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
shortroomid: u64,
|
||||
target: u64,
|
||||
until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let prefix = target.to_be_bytes().to_vec();
|
||||
let mut current = prefix.clone();
|
||||
fn relations_until<'a>(
|
||||
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let prefix = target.to_be_bytes().to_vec();
|
||||
let mut current = prefix.clone();
|
||||
|
||||
let count_raw = match until {
|
||||
PduCount::Normal(x) => x - 1,
|
||||
PduCount::Backfilled(x) => {
|
||||
current.extend_from_slice(&0_u64.to_be_bytes());
|
||||
u64::MAX - x - 1
|
||||
}
|
||||
};
|
||||
current.extend_from_slice(&count_raw.to_be_bytes());
|
||||
let count_raw = match until {
|
||||
PduCount::Normal(x) => x - 1,
|
||||
PduCount::Backfilled(x) => {
|
||||
current.extend_from_slice(&0_u64.to_be_bytes());
|
||||
u64::MAX - x - 1
|
||||
},
|
||||
};
|
||||
current.extend_from_slice(&count_raw.to_be_bytes());
|
||||
|
||||
Ok(Box::new(
|
||||
self.tofrom_relation
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(tofrom, _data)| {
|
||||
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
||||
Ok(Box::new(
|
||||
self.tofrom_relation.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(tofrom, _data)| {
|
||||
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
||||
|
||||
let mut pduid = shortroomid.to_be_bytes().to_vec();
|
||||
pduid.extend_from_slice(&from.to_be_bytes());
|
||||
let mut pduid = shortroomid.to_be_bytes().to_vec();
|
||||
pduid.extend_from_slice(&from.to_be_bytes());
|
||||
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((PduCount::Normal(from), pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((PduCount::Normal(from), pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
|
||||
for prev in event_ids {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(prev.as_bytes());
|
||||
self.referencedevents.insert(&key, &[])?;
|
||||
}
|
||||
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
|
||||
for prev in event_ids {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(prev.as_bytes());
|
||||
self.referencedevents.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
Ok(self.referencedevents.get(&key)?.is_some())
|
||||
}
|
||||
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
Ok(self.referencedevents.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
|
||||
self.softfailedeventids.insert(event_id.as_bytes(), &[])
|
||||
}
|
||||
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
|
||||
self.softfailedeventids.insert(event_id.as_bytes(), &[])
|
||||
}
|
||||
|
||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||
self.softfailedeventids
|
||||
.get(event_id.as_bytes())
|
||||
.map(|o| o.is_some())
|
||||
}
|
||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,61 +5,55 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result};
|
||||
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
|
||||
|
||||
impl service::rooms::search::Data for KeyValueDatabase {
|
||||
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
|
||||
let mut batch = message_body
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= 50)
|
||||
.map(str::to_lowercase)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
|
||||
(key, Vec::new())
|
||||
});
|
||||
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
|
||||
let mut batch = message_body
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= 50)
|
||||
.map(str::to_lowercase)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
|
||||
(key, Vec::new())
|
||||
});
|
||||
|
||||
self.tokenids.insert_batch(&mut batch)
|
||||
}
|
||||
self.tokenids.insert_batch(&mut batch)
|
||||
}
|
||||
|
||||
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
|
||||
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||
|
||||
let words: Vec<_> = search_string
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(str::to_lowercase)
|
||||
.collect();
|
||||
let words: Vec<_> = search_string
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(str::to_lowercase)
|
||||
.collect();
|
||||
|
||||
let iterators = words.clone().into_iter().map(move |word| {
|
||||
let mut prefix2 = prefix.clone();
|
||||
prefix2.extend_from_slice(word.as_bytes());
|
||||
prefix2.push(0xff);
|
||||
let prefix3 = prefix2.clone();
|
||||
let iterators = words.clone().into_iter().map(move |word| {
|
||||
let mut prefix2 = prefix.clone();
|
||||
prefix2.extend_from_slice(word.as_bytes());
|
||||
prefix2.push(0xFF);
|
||||
let prefix3 = prefix2.clone();
|
||||
|
||||
let mut last_possible_id = prefix2.clone();
|
||||
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
let mut last_possible_id = prefix2.clone();
|
||||
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.tokenids
|
||||
.iter_from(&last_possible_id, true) // Newest pdus first
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(key, _)| key[prefix3.len()..].to_vec())
|
||||
});
|
||||
self.tokenids
|
||||
.iter_from(&last_possible_id, true) // Newest pdus first
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(key, _)| key[prefix3.len()..].to_vec())
|
||||
});
|
||||
|
||||
let common_elements = match utils::common_elements(iterators, |a, b| {
|
||||
// We compare b with a because we reversed the iterator earlier
|
||||
b.cmp(a)
|
||||
}) {
|
||||
Some(it) => it,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let common_elements = match utils::common_elements(iterators, |a, b| {
|
||||
// We compare b with a because we reversed the iterator earlier
|
||||
b.cmp(a)
|
||||
}) {
|
||||
Some(it) => it,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
Ok(Some((Box::new(common_elements), words)))
|
||||
}
|
||||
Ok(Some((Box::new(common_elements), words)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,214 +6,165 @@ use tracing::warn;
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::short::Data for KeyValueDatabase {
|
||||
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
|
||||
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(*short);
|
||||
}
|
||||
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
|
||||
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
|
||||
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
|
||||
None => {
|
||||
let shorteventid = services().globals.next_count()?;
|
||||
self.eventid_shorteventid
|
||||
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid
|
||||
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
shorteventid
|
||||
}
|
||||
};
|
||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||
Some(shorteventid) => {
|
||||
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
|
||||
},
|
||||
None => {
|
||||
let shorteventid = services().globals.next_count()?;
|
||||
self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
shorteventid
|
||||
},
|
||||
};
|
||||
|
||||
self.eventidshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), short);
|
||||
self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<u64>> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(Some(*short));
|
||||
}
|
||||
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
|
||||
if let Some(short) =
|
||||
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(Some(*short));
|
||||
}
|
||||
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xff);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xFF);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = self
|
||||
.statekey_shortstatekey
|
||||
.get(&statekey_vec)?
|
||||
.map(|shortstatekey| {
|
||||
utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
})
|
||||
.transpose()?;
|
||||
let short = self
|
||||
.statekey_shortstatekey
|
||||
.get(&statekey_vec)?
|
||||
.map(|shortstatekey| {
|
||||
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
if let Some(s) = short {
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), s);
|
||||
}
|
||||
if let Some(s) = short {
|
||||
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s);
|
||||
}
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_or_create_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<u64> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(*short);
|
||||
}
|
||||
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
|
||||
if let Some(short) =
|
||||
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xff);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xFF);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
||||
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||
None => {
|
||||
let shortstatekey = services().globals.next_count()?;
|
||||
self.statekey_shortstatekey
|
||||
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey
|
||||
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||
shortstatekey
|
||||
}
|
||||
};
|
||||
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
||||
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||
None => {
|
||||
let shortstatekey = services().globals.next_count()?;
|
||||
self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||
shortstatekey
|
||||
},
|
||||
};
|
||||
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), short);
|
||||
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||
if let Some(id) = self
|
||||
.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shorteventid)
|
||||
{
|
||||
return Ok(Arc::clone(id));
|
||||
}
|
||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||
if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) {
|
||||
return Ok(Arc::clone(id));
|
||||
}
|
||||
|
||||
let bytes = self
|
||||
.shorteventid_eventid
|
||||
.get(&shorteventid.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||
let bytes = self
|
||||
.shorteventid_eventid
|
||||
.get(&shorteventid.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||
|
||||
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||
let event_id = EventId::parse_arc(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||
|
||||
self.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shorteventid, Arc::clone(&event_id));
|
||||
self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id));
|
||||
|
||||
Ok(event_id)
|
||||
}
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||
if let Some(id) = self
|
||||
.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shortstatekey)
|
||||
{
|
||||
return Ok(id.clone());
|
||||
}
|
||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||
if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) {
|
||||
return Ok(id.clone());
|
||||
}
|
||||
|
||||
let bytes = self
|
||||
.shortstatekey_statekey
|
||||
.get(&shortstatekey.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||
let bytes = self
|
||||
.shortstatekey_statekey
|
||||
.get(&shortstatekey.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xff);
|
||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||
let statekey_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
|
||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||
let statekey_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
|
||||
let event_type =
|
||||
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
||||
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
||||
})?);
|
||||
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
||||
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
||||
})?);
|
||||
|
||||
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
|
||||
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
|
||||
})?;
|
||||
let state_key = utils::string_from_bytes(statekey_bytes)
|
||||
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
|
||||
|
||||
let result = (event_type, state_key);
|
||||
let result = (event_type, state_key);
|
||||
|
||||
self.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shortstatekey, result.clone());
|
||||
self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns (shortstatehash, already_existed)
|
||||
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
|
||||
Ok(match self.statehash_shortstatehash.get(state_hash)? {
|
||||
Some(shortstatehash) => (
|
||||
utils::u64_from_bytes(&shortstatehash)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
|
||||
true,
|
||||
),
|
||||
None => {
|
||||
let shortstatehash = services().globals.next_count()?;
|
||||
self.statehash_shortstatehash
|
||||
.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
(shortstatehash, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Returns (shortstatehash, already_existed)
|
||||
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
|
||||
Ok(match self.statehash_shortstatehash.get(state_hash)? {
|
||||
Some(shortstatehash) => (
|
||||
utils::u64_from_bytes(&shortstatehash)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
|
||||
true,
|
||||
),
|
||||
None => {
|
||||
let shortstatehash = services().globals.next_count()?;
|
||||
self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
(shortstatehash, false)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||
Some(short) => utils::u64_from_bytes(&short)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
|
||||
None => {
|
||||
let short = services().globals.next_count()?;
|
||||
self.roomid_shortroomid
|
||||
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
short
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||
Some(short) => {
|
||||
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
|
||||
},
|
||||
None => {
|
||||
let short = services().globals.next_count()?;
|
||||
self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
short
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,73 +1,69 @@
|
||||
use ruma::{EventId, OwnedEventId, RoomId};
|
||||
use std::collections::HashSet;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use std::sync::Arc;
|
||||
use ruma::{EventId, OwnedEventId, RoomId};
|
||||
use tokio::sync::MutexGuard;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::state::Data for KeyValueDatabase {
|
||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortstatehash
|
||||
.get(room_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
fn set_room_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
new_shortstatehash: u64,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
self.roomid_shortstatehash
|
||||
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn set_room_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
new_shortstatehash: u64,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||
self.shorteventid_shortstatehash
|
||||
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||
self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.roomid_pduleaves
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, bytes)| {
|
||||
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
self.roomid_pduleaves
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, bytes)| {
|
||||
EventId::parse_arc(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn set_forward_extremities(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<OwnedEventId>,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn set_forward_extremities(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<OwnedEventId>,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||
self.roomid_pduleaves.remove(&key)?;
|
||||
}
|
||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||
self.roomid_pduleaves.remove(&key)?;
|
||||
}
|
||||
|
||||
for event_id in event_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
|
||||
}
|
||||
for event_id in event_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,186 +1,144 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
use async_trait::async_trait;
|
||||
use ruma::{events::StateEventType, EventId, RoomId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
#[async_trait]
|
||||
impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let parsed = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
result.insert(parsed.0, parsed.1);
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||
result.insert(parsed.0, parsed.1);
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn state_full(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let (_, eventid) = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||
result.insert(
|
||||
(
|
||||
pdu.kind.to_string().into(),
|
||||
pdu.state_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))?
|
||||
.clone(),
|
||||
),
|
||||
pdu,
|
||||
);
|
||||
}
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||
result.insert(
|
||||
(
|
||||
pdu.kind.to_string().into(),
|
||||
pdu.state_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))?
|
||||
.clone(),
|
||||
),
|
||||
pdu,
|
||||
);
|
||||
}
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn state_get_id(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
let shortstatekey = match services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortstatekey(event_type, state_key)?
|
||||
{
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
Ok(full_state
|
||||
.iter()
|
||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
|
||||
.and_then(|compressed| {
|
||||
services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)
|
||||
.ok()
|
||||
.map(|(_, id)| id)
|
||||
}))
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn state_get_id(
|
||||
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? {
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
Ok(
|
||||
full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| {
|
||||
services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn state_get(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||
.map_or(Ok(None), |event_id| {
|
||||
services().rooms.timeline.get_pdu(&event_id)
|
||||
})
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn state_get(
|
||||
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
|
||||
}
|
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |shorteventid| {
|
||||
self.shorteventid_shortstatehash
|
||||
.get(&shorteventid)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
/// Returns the state hash for this pdu.
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| {
|
||||
self.shorteventid_shortstatehash
|
||||
.get(&shorteventid)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash"))
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the full room state.
|
||||
async fn room_state_full(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_full(current_shortstatehash).await
|
||||
} else {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
/// Returns the full room state.
|
||||
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_full(current_shortstatehash).await
|
||||
} else {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn room_state_get_id(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn room_state_get_id(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn room_state_get(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_get(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn room_state_get(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_get(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,61 +1,63 @@
|
||||
use std::{collections::HashSet, mem::size_of, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::state_compressor::data::StateDiff},
|
||||
utils, Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::state_compressor::data::StateDiff},
|
||||
utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::state_compressor::Data for KeyValueDatabase {
|
||||
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
|
||||
let value = self
|
||||
.shortstatehash_statediff
|
||||
.get(&shortstatehash.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
||||
let parent =
|
||||
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
||||
let parent = if parent != 0 { Some(parent) } else { None };
|
||||
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
|
||||
let value = self
|
||||
.shortstatehash_statediff
|
||||
.get(&shortstatehash.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
||||
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
||||
let parent = if parent != 0 {
|
||||
Some(parent)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut add_mode = true;
|
||||
let mut added = HashSet::new();
|
||||
let mut removed = HashSet::new();
|
||||
let mut add_mode = true;
|
||||
let mut added = HashSet::new();
|
||||
let mut removed = HashSet::new();
|
||||
|
||||
let mut i = size_of::<u64>();
|
||||
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
|
||||
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
|
||||
add_mode = false;
|
||||
i += size_of::<u64>();
|
||||
continue;
|
||||
}
|
||||
if add_mode {
|
||||
added.insert(v.try_into().expect("we checked the size above"));
|
||||
} else {
|
||||
removed.insert(v.try_into().expect("we checked the size above"));
|
||||
}
|
||||
i += 2 * size_of::<u64>();
|
||||
}
|
||||
let mut i = size_of::<u64>();
|
||||
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
|
||||
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
|
||||
add_mode = false;
|
||||
i += size_of::<u64>();
|
||||
continue;
|
||||
}
|
||||
if add_mode {
|
||||
added.insert(v.try_into().expect("we checked the size above"));
|
||||
} else {
|
||||
removed.insert(v.try_into().expect("we checked the size above"));
|
||||
}
|
||||
i += 2 * size_of::<u64>();
|
||||
}
|
||||
|
||||
Ok(StateDiff {
|
||||
parent,
|
||||
added: Arc::new(added),
|
||||
removed: Arc::new(removed),
|
||||
})
|
||||
}
|
||||
Ok(StateDiff {
|
||||
parent,
|
||||
added: Arc::new(added),
|
||||
removed: Arc::new(removed),
|
||||
})
|
||||
}
|
||||
|
||||
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
|
||||
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
|
||||
for new in diff.added.iter() {
|
||||
value.extend_from_slice(&new[..]);
|
||||
}
|
||||
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
|
||||
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
|
||||
for new in diff.added.iter() {
|
||||
value.extend_from_slice(&new[..]);
|
||||
}
|
||||
|
||||
if !diff.removed.is_empty() {
|
||||
value.extend_from_slice(&0_u64.to_be_bytes());
|
||||
for removed in diff.removed.iter() {
|
||||
value.extend_from_slice(&removed[..]);
|
||||
}
|
||||
}
|
||||
if !diff.removed.is_empty() {
|
||||
value.extend_from_slice(&0_u64.to_be_bytes());
|
||||
for removed in diff.removed.iter() {
|
||||
value.extend_from_slice(&removed[..]);
|
||||
}
|
||||
}
|
||||
|
||||
self.shortstatehash_statediff
|
||||
.insert(&shortstatehash.to_be_bytes(), &value)
|
||||
}
|
||||
self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,74 +7,58 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven
|
||||
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
|
||||
|
||||
impl service::rooms::threads::Data for KeyValueDatabase {
|
||||
fn threads_until<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
room_id: &'a RoomId,
|
||||
until: u64,
|
||||
_include: &'a IncludeThreads,
|
||||
) -> PduEventIterResult<'a> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
fn threads_until<'a>(
|
||||
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
|
||||
) -> PduEventIterResult<'a> {
|
||||
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(until - 1).to_be_bytes());
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(until - 1).to_be_bytes());
|
||||
|
||||
Ok(Box::new(
|
||||
self.threadid_userids
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pduid, _users)| {
|
||||
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("Invalid pduid reference in threadid_userids")
|
||||
})?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
Ok(Box::new(
|
||||
self.threadid_userids.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pduid, _users)| {
|
||||
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((count, pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
||||
let users = participants
|
||||
.iter()
|
||||
.map(|user| user.as_bytes())
|
||||
.collect::<Vec<_>>()
|
||||
.join(&[0xff][..]);
|
||||
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
||||
let users = participants.iter().map(|user| user.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..]);
|
||||
|
||||
self.threadid_userids.insert(root_id, &users)?;
|
||||
self.threadid_userids.insert(root_id, &users)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
|
||||
if let Some(users) = self.threadid_userids.get(root_id)? {
|
||||
Ok(Some(
|
||||
users
|
||||
.split(|b| *b == 0xff)
|
||||
.map(|bytes| {
|
||||
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid UserId bytes in threadid_userids.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect(),
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
|
||||
if let Some(users) = self.threadid_userids.get(root_id)? {
|
||||
Ok(Some(
|
||||
users
|
||||
.split(|b| *b == 0xFF)
|
||||
.map(|bytes| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect(),
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,364 +1,286 @@
|
||||
use std::{collections::hash_map, mem::size_of, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
|
||||
use service::rooms::timeline::PduCount;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
use service::rooms::timeline::PduCount;
|
||||
|
||||
impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
||||
match self
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
{
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
if let Some(last_count) = self
|
||||
.pdus_until(sender_user, room_id, PduCount::max())?
|
||||
.find_map(|r| {
|
||||
// Filter out buggy events
|
||||
if r.is_err() {
|
||||
error!("Bad pdu in pdus_since: {:?}", r);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
{
|
||||
Ok(*v.insert(last_count.0))
|
||||
} else {
|
||||
Ok(PduCount::Normal(0))
|
||||
}
|
||||
}
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
}
|
||||
}
|
||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
||||
match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) {
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| {
|
||||
// Filter out buggy events
|
||||
if r.is_err() {
|
||||
error!("Bad pdu in pdus_since: {:?}", r);
|
||||
}
|
||||
r.ok()
|
||||
}) {
|
||||
Ok(*v.insert(last_count.0))
|
||||
} else {
|
||||
Ok(PduCount::Normal(0))
|
||||
}
|
||||
},
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu_id| pdu_count(&pdu_id))
|
||||
.transpose()
|
||||
}
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
||||
self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose()
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)
|
||||
}
|
||||
/// Returns the json of a pdu.
|
||||
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
/// Returns the json of a pdu.
|
||||
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu's id.
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
|
||||
self.eventid_pduid.get(event_id.as_bytes())
|
||||
}
|
||||
/// Returns the pdu's id.
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
|
||||
|
||||
/// Returns the pdu.
|
||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
/// Returns the pdu.
|
||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(Some(Arc::clone(p)));
|
||||
}
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(Some(Arc::clone(p)));
|
||||
}
|
||||
|
||||
if let Some(pdu) = self
|
||||
.get_non_outlier_pdu(event_id)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)?
|
||||
.map(Arc::new)
|
||||
{
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
Ok(Some(pdu))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
if let Some(pdu) = self
|
||||
.get_non_outlier_pdu(event_id)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)?
|
||||
.map(Arc::new)
|
||||
{
|
||||
self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
Ok(Some(pdu))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn append_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu: &PduEvent,
|
||||
json: &CanonicalJsonObject,
|
||||
count: u64,
|
||||
) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
|
||||
self.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||
self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||
|
||||
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepend_backfill_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
event_id: &EventId,
|
||||
json: &CanonicalJsonObject,
|
||||
) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
|
||||
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
|
||||
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
fn replace_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu_json: &CanonicalJsonObject,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PDU does not exist.",
|
||||
));
|
||||
}
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
|
||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
|
||||
}
|
||||
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&(*pdu.event_id).to_owned());
|
||||
self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over all events and their tokens in a room that happened before the
|
||||
/// event with id `until` in reverse-chronological order.
|
||||
fn pdus_until<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
||||
/// Returns an iterator over all events and their tokens in a room that
|
||||
/// happened before the event with id `until` in reverse-chronological
|
||||
/// order.
|
||||
fn pdus_until<'a>(
|
||||
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn pdus_after<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
from: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
||||
fn pdus_after<'a>(
|
||||
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(¤t, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu.iter_from(¤t, false).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn increment_notification_counts(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
notifies: Vec<OwnedUserId>,
|
||||
highlights: Vec<OwnedUserId>,
|
||||
) -> Result<()> {
|
||||
let mut notifies_batch = Vec::new();
|
||||
let mut highlights_batch = Vec::new();
|
||||
for user in notifies {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
notifies_batch.push(userroom_id);
|
||||
}
|
||||
for user in highlights {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
highlights_batch.push(userroom_id);
|
||||
}
|
||||
fn increment_notification_counts(
|
||||
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
|
||||
) -> Result<()> {
|
||||
let mut notifies_batch = Vec::new();
|
||||
let mut highlights_batch = Vec::new();
|
||||
for user in notifies {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
notifies_batch.push(userroom_id);
|
||||
}
|
||||
for user in highlights {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
highlights_batch.push(userroom_id);
|
||||
}
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount
|
||||
.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
Ok(())
|
||||
}
|
||||
self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
||||
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
||||
let second_last_u64 = utils::u64_from_bytes(
|
||||
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()],
|
||||
);
|
||||
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
||||
let second_last_u64 =
|
||||
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
|
||||
|
||||
if matches!(second_last_u64, Ok(0)) {
|
||||
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
||||
} else {
|
||||
Ok(PduCount::Normal(last_u64))
|
||||
}
|
||||
if matches!(second_last_u64, Ok(0)) {
|
||||
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
||||
} else {
|
||||
Ok(PduCount::Normal(last_u64))
|
||||
}
|
||||
}
|
||||
|
||||
fn count_to_id(
|
||||
room_id: &RoomId,
|
||||
count: PduCount,
|
||||
offset: u64,
|
||||
subtract: bool,
|
||||
) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
let mut pdu_id = prefix.clone();
|
||||
// +1 so we don't send the base event
|
||||
let count_raw = match count {
|
||||
PduCount::Normal(x) => {
|
||||
if subtract {
|
||||
x - offset
|
||||
} else {
|
||||
x + offset
|
||||
}
|
||||
}
|
||||
PduCount::Backfilled(x) => {
|
||||
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
||||
let num = u64::MAX - x;
|
||||
if subtract {
|
||||
if num > 0 {
|
||||
num - offset
|
||||
} else {
|
||||
num
|
||||
}
|
||||
} else {
|
||||
num + offset
|
||||
}
|
||||
}
|
||||
};
|
||||
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
||||
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
let mut pdu_id = prefix.clone();
|
||||
// +1 so we don't send the base event
|
||||
let count_raw = match count {
|
||||
PduCount::Normal(x) => {
|
||||
if subtract {
|
||||
x - offset
|
||||
} else {
|
||||
x + offset
|
||||
}
|
||||
},
|
||||
PduCount::Backfilled(x) => {
|
||||
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
||||
let num = u64::MAX - x;
|
||||
if subtract {
|
||||
if num > 0 {
|
||||
num - offset
|
||||
} else {
|
||||
num
|
||||
}
|
||||
} else {
|
||||
num + offset
|
||||
}
|
||||
},
|
||||
};
|
||||
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
||||
|
||||
Ok((prefix, pdu_id))
|
||||
Ok((prefix, pdu_id))
|
||||
}
|
||||
|
||||
@@ -3,147 +3,122 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::user::Data for KeyValueDatabase {
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastnotificationread.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
self.userroomid_notificationcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_highlightcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
self.userroomid_highlightcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastnotificationread
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
Ok(self
|
||||
.roomuserid_lastnotificationread
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn associate_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
shortstatehash: u64,
|
||||
) -> Result<()> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
|
||||
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.roomsynctoken_shortstatehash
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self,
|
||||
users: Vec<OwnedUserId>,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
||||
let iterators = users.into_iter().map(move |user_id| {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self, users: Vec<OwnedUserId>,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
||||
let iterators = users.into_iter().map(move |user_id| {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.userroomid_joined
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
let roomid_index = key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &b)| b == 0xff)
|
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||
.0
|
||||
+ 1; // +1 because the room id starts AFTER the separator
|
||||
self.userroomid_joined
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
let roomid_index = key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &b)| b == 0xFF)
|
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||
.0 + 1; // +1 because the room id starts AFTER the separator
|
||||
|
||||
let room_id = key[roomid_index..].to_vec();
|
||||
let room_id = key[roomid_index..].to_vec();
|
||||
|
||||
Ok::<_, Error>(room_id)
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
});
|
||||
Ok::<_, Error>(room_id)
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
});
|
||||
|
||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
||||
Ok(Box::new(
|
||||
utils::common_elements(iterators, Ord::cmp)
|
||||
.expect("users is not empty")
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}),
|
||||
))
|
||||
}
|
||||
// We use the default compare function because keys are sorted correctly (not
|
||||
// reversed)
|
||||
Ok(Box::new(
|
||||
utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
+148
-172
@@ -1,205 +1,181 @@
|
||||
use ruma::{ServerName, UserId};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{
|
||||
self,
|
||||
sending::{OutgoingKind, SendingEventType},
|
||||
},
|
||||
services, utils, Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{
|
||||
self,
|
||||
sending::{OutgoingKind, SendingEventType},
|
||||
},
|
||||
services, utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::sending::Data for KeyValueDatabase {
|
||||
fn active_requests<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.iter()
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
|
||||
)
|
||||
}
|
||||
fn active_requests<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.iter()
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
|
||||
)
|
||||
}
|
||||
|
||||
fn active_requests_for<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
|
||||
)
|
||||
}
|
||||
fn active_requests_for<'a>(
|
||||
&'a self, outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
|
||||
)
|
||||
}
|
||||
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
|
||||
self.servercurrentevent_data.remove(&key)
|
||||
}
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
|
||||
|
||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
|
||||
self.servercurrentevent_data.remove(&key)?;
|
||||
}
|
||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
|
||||
self.servercurrentevent_data.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
|
||||
self.servernameevent_data.remove(&key).unwrap();
|
||||
}
|
||||
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
|
||||
self.servernameevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn queue_requests(
|
||||
&self,
|
||||
requests: &[(&OutgoingKind, SendingEventType)],
|
||||
) -> Result<Vec<Vec<u8>>> {
|
||||
let mut batch = Vec::new();
|
||||
let mut keys = Vec::new();
|
||||
for (outgoing_kind, event) in requests {
|
||||
let mut key = outgoing_kind.get_prefix();
|
||||
if let SendingEventType::Pdu(value) = &event {
|
||||
key.extend_from_slice(value);
|
||||
} else {
|
||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
}
|
||||
let value = if let SendingEventType::Edu(value) = &event {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
batch.push((key.clone(), value.to_owned()));
|
||||
keys.push(key);
|
||||
}
|
||||
self.servernameevent_data
|
||||
.insert_batch(&mut batch.into_iter())?;
|
||||
Ok(keys)
|
||||
}
|
||||
fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result<Vec<Vec<u8>>> {
|
||||
let mut batch = Vec::new();
|
||||
let mut keys = Vec::new();
|
||||
for (outgoing_kind, event) in requests {
|
||||
let mut key = outgoing_kind.get_prefix();
|
||||
if let SendingEventType::Pdu(value) = &event {
|
||||
key.extend_from_slice(value);
|
||||
} else {
|
||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
}
|
||||
let value = if let SendingEventType::Edu(value) = &event {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
batch.push((key.clone(), value.to_owned()));
|
||||
keys.push(key);
|
||||
}
|
||||
self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn queued_requests<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
return Box::new(
|
||||
self.servernameevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
|
||||
);
|
||||
}
|
||||
fn queued_requests<'a>(
|
||||
&'a self, outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
return Box::new(
|
||||
self.servernameevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
|
||||
);
|
||||
}
|
||||
|
||||
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
|
||||
for (e, key) in events {
|
||||
let value = if let SendingEventType::Edu(value) = &e {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
self.servercurrentevent_data.insert(key, value)?;
|
||||
self.servernameevent_data.remove(key)?;
|
||||
}
|
||||
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
|
||||
for (e, key) in events {
|
||||
let value = if let SendingEventType::Edu(value) = &e {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
self.servercurrentevent_data.insert(key, value)?;
|
||||
self.servernameevent_data.remove(key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||
self.servername_educount
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
}
|
||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||
self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||
self.servername_educount
|
||||
.get(server_name.as_bytes())?
|
||||
.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})
|
||||
}
|
||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||
self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(key))]
|
||||
fn parse_servercurrentevent(
|
||||
key: &[u8],
|
||||
value: Vec<u8>,
|
||||
) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
||||
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
|
||||
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id = UserId::parse(user_string)
|
||||
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id =
|
||||
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
|
||||
let pushkey = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||
let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Push(user_id, pushkey_string),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
// I'm pretty sure this should never be called
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
(
|
||||
OutgoingKind::Push(user_id, pushkey_string),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
// I'm pretty sure this should never be called
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xFF);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
||||
})?),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
})
|
||||
(
|
||||
OutgoingKind::Normal(
|
||||
ServerName::parse(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
|
||||
),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,37 +3,30 @@ use ruma::{DeviceId, TransactionId, UserId};
|
||||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
impl service::transaction_ids::Data for KeyValueDatabase {
|
||||
fn add_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
data: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
fn add_txnid(
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn existing_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
fn existing_txnid(
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
// If there's no entry, this is a new transaction
|
||||
self.userdevicetxnid_response.get(&key)
|
||||
}
|
||||
// If there's no entry, this is a new transaction
|
||||
self.userdevicetxnid_response.get(&key)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,89 +1,64 @@
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, uiaa::UiaaInfo},
|
||||
CanonicalJsonValue, DeviceId, UserId,
|
||||
api::client::{error::ErrorKind, uiaa::UiaaInfo},
|
||||
CanonicalJsonValue, DeviceId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, Error, Result};
|
||||
|
||||
impl service::uiaa::Data for KeyValueDatabase {
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
fn set_uiaa_request(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest.write().unwrap().insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
|
||||
.map(std::borrow::ToOwned::to_owned)
|
||||
}
|
||||
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
|
||||
.map(std::borrow::ToOwned::to_owned)
|
||||
}
|
||||
|
||||
fn update_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
fn update_uiaa_session(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo
|
||||
.remove(&userdevicesessionid)?;
|
||||
}
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"UIAA session does not exist.",
|
||||
))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "UIAA session does not exist."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
}
|
||||
|
||||
+817
-979
File diff suppressed because it is too large
Load Diff
+1094
-1161
File diff suppressed because it is too large
Load Diff
+1
-4
@@ -15,8 +15,5 @@ pub use utils::error::{Error, Result};
|
||||
pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
|
||||
|
||||
pub fn services() -> &'static Services<'static> {
|
||||
SERVICES
|
||||
.read()
|
||||
.unwrap()
|
||||
.expect("SERVICES should be initialized when this is called")
|
||||
SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called")
|
||||
}
|
||||
|
||||
+651
-665
File diff suppressed because it is too large
Load Diff
@@ -1,35 +1,28 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()>;
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
fn update(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()>;
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>>;
|
||||
/// Searches the account data for a specific kind.
|
||||
fn get(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>>;
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
fn changes_since(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
|
||||
}
|
||||
|
||||
@@ -1,53 +1,44 @@
|
||||
mod data;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
pub fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
self.db.update(room_id, user_id, event_type, data)
|
||||
}
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
pub fn update(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
self.db.update(room_id, user_id, event_type, data)
|
||||
}
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
|
||||
pub fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
self.db.get(room_id, user_id, event_type)
|
||||
}
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
|
||||
pub fn get(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
self.db.get(room_id, user_id, event_type)
|
||||
}
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
pub fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
self.db.changes_since(room_id, user_id, since)
|
||||
}
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
pub fn changes_since(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
self.db.changes_since(room_id, user_id, since)
|
||||
}
|
||||
}
|
||||
|
||||
+2236
-2284
File diff suppressed because it is too large
Load Diff
@@ -3,19 +3,19 @@ use ruma::api::appservice::Registration;
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String>;
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String>;
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()>;
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()>;
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>>;
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>>;
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
|
||||
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>>;
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>>;
|
||||
}
|
||||
|
||||
@@ -6,33 +6,25 @@ use ruma::api::appservice::Registration;
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
self.db.register_appservice(yaml)
|
||||
}
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub fn register_appservice(&self, yaml: Registration) -> Result<String> { self.db.register_appservice(yaml) }
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
pub fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.db.unregister_appservice(service_name)
|
||||
}
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
pub fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.db.unregister_appservice(service_name)
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.db.get_registration(id)
|
||||
}
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) }
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> {
|
||||
self.db.iter_ids()
|
||||
}
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() }
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
self.db.all()
|
||||
}
|
||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() }
|
||||
}
|
||||
|
||||
+21
-25
@@ -2,36 +2,32 @@ use std::collections::BTreeMap;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ruma::{
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, OwnedServerSigningKeyId, ServerName, UserId,
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, OwnedServerSigningKeyId, ServerName, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Data: Send + Sync {
|
||||
fn next_count(&self) -> Result<u64>;
|
||||
fn current_count(&self) -> Result<u64>;
|
||||
fn last_check_for_updates_id(&self) -> Result<u64>;
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()>;
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()>;
|
||||
fn memory_usage(&self) -> String;
|
||||
fn clear_caches(&self, amount: u32);
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
|
||||
fn remove_keypair(&self) -> Result<()>;
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
fn next_count(&self) -> Result<u64>;
|
||||
fn current_count(&self) -> Result<u64>;
|
||||
fn last_check_for_updates_id(&self) -> Result<u64>;
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()>;
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()>;
|
||||
fn memory_usage(&self) -> String;
|
||||
fn clear_caches(&self, amount: u32);
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
|
||||
fn remove_keypair(&self) -> Result<()>;
|
||||
fn add_signing_key(
|
||||
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
fn database_version(&self) -> Result<u64>;
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()>;
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||
/// for the server.
|
||||
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
fn database_version(&self) -> Result<u64>;
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()>;
|
||||
}
|
||||
|
||||
+462
-586
File diff suppressed because it is too large
Load Diff
@@ -1,78 +1,47 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String>;
|
||||
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
|
||||
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String>;
|
||||
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
|
||||
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
|
||||
|
||||
fn get_latest_backup(&self, user_id: &UserId)
|
||||
-> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
||||
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
||||
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
|
||||
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()>;
|
||||
fn add_key(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()>;
|
||||
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
|
||||
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
|
||||
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
|
||||
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
|
||||
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
|
||||
fn get_room(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
|
||||
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>>;
|
||||
fn get_session(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>>;
|
||||
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
|
||||
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()>;
|
||||
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>;
|
||||
}
|
||||
|
||||
+59
-105
@@ -1,127 +1,81 @@
|
||||
mod data;
|
||||
pub(crate) use data::Data;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
self.db.create_backup(user_id, backup_metadata)
|
||||
}
|
||||
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
self.db.create_backup(user_id, backup_metadata)
|
||||
}
|
||||
|
||||
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
self.db.delete_backup(user_id, version)
|
||||
}
|
||||
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
self.db.delete_backup(user_id, version)
|
||||
}
|
||||
|
||||
pub fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
self.db.update_backup(user_id, version, backup_metadata)
|
||||
}
|
||||
pub fn update_backup(
|
||||
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
self.db.update_backup(user_id, version, backup_metadata)
|
||||
}
|
||||
|
||||
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.db.get_latest_backup_version(user_id)
|
||||
}
|
||||
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.db.get_latest_backup_version(user_id)
|
||||
}
|
||||
|
||||
pub fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
self.db.get_latest_backup(user_id)
|
||||
}
|
||||
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
self.db.get_latest_backup(user_id)
|
||||
}
|
||||
|
||||
pub fn get_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
self.db.get_backup(user_id, version)
|
||||
}
|
||||
pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
self.db.get_backup(user_id, version)
|
||||
}
|
||||
|
||||
pub fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.add_key(user_id, version, room_id, session_id, key_data)
|
||||
}
|
||||
pub fn add_key(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
self.db.add_key(user_id, version, room_id, session_id, key_data)
|
||||
}
|
||||
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
self.db.count_keys(user_id, version)
|
||||
}
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
|
||||
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
self.db.get_etag(user_id, version)
|
||||
}
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) }
|
||||
|
||||
pub fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
self.db.get_all(user_id, version)
|
||||
}
|
||||
pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
self.db.get_all(user_id, version)
|
||||
}
|
||||
|
||||
pub fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
self.db.get_room(user_id, version, room_id)
|
||||
}
|
||||
pub fn get_room(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
self.db.get_room(user_id, version, room_id)
|
||||
}
|
||||
|
||||
pub fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
self.db.get_session(user_id, version, room_id, session_id)
|
||||
}
|
||||
pub fn get_session(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
self.db.get_session(user_id, version, room_id, session_id)
|
||||
}
|
||||
|
||||
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
self.db.delete_all_keys(user_id, version)
|
||||
}
|
||||
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
self.db.delete_all_keys(user_id, version)
|
||||
}
|
||||
|
||||
pub fn delete_room_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
self.db.delete_room_keys(user_id, version, room_id)
|
||||
}
|
||||
pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
self.db.delete_room_keys(user_id, version, room_id)
|
||||
}
|
||||
|
||||
pub fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.delete_room_key(user_id, version, room_id, session_id)
|
||||
}
|
||||
pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||
self.db.delete_room_key(user_id, version, room_id, session_id)
|
||||
}
|
||||
}
|
||||
|
||||
+13
-26
@@ -1,37 +1,24 @@
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>>;
|
||||
fn create_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>>;
|
||||
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()>;
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()>;
|
||||
|
||||
/// Returns content_disposition, content_type and the metadata key.
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
|
||||
/// Returns content_disposition, content_type and the metadata key.
|
||||
fn search_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
|
||||
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
|
||||
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>>;
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>>;
|
||||
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()>;
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()>;
|
||||
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &super::UrlPreviewData,
|
||||
timestamp: std::time::Duration,
|
||||
) -> Result<()>;
|
||||
fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
|
||||
|
||||
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
|
||||
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
|
||||
}
|
||||
|
||||
+389
-488
@@ -1,579 +1,480 @@
|
||||
mod data;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::Cursor,
|
||||
sync::{Arc, RwLock},
|
||||
time::SystemTime,
|
||||
collections::HashMap,
|
||||
io::Cursor,
|
||||
sync::{Arc, RwLock},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use image::imageops::FilterType;
|
||||
use ruma::OwnedMxcUri;
|
||||
use serde::Serialize;
|
||||
use tokio::{
|
||||
fs::{self, File},
|
||||
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
sync::Mutex,
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::{services, utils, Error, Result};
|
||||
use image::imageops::FilterType;
|
||||
|
||||
use tokio::{
|
||||
fs::{self, File},
|
||||
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileMeta {
|
||||
pub content_disposition: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub file: Vec<u8>,
|
||||
pub content_disposition: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub file: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct UrlPreviewData {
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:title")
|
||||
)]
|
||||
pub title: Option<String>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:description")
|
||||
)]
|
||||
pub description: Option<String>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:image")
|
||||
)]
|
||||
pub image: Option<String>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "matrix:image:size")
|
||||
)]
|
||||
pub image_size: Option<usize>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:image:width")
|
||||
)]
|
||||
pub image_width: Option<u32>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:image:height")
|
||||
)]
|
||||
pub image_height: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
|
||||
pub title: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
|
||||
pub image: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
|
||||
pub image_size: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
|
||||
pub image_width: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
|
||||
pub image_height: Option<u32>,
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
|
||||
pub db: &'static dyn Data,
|
||||
pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
mxc: String,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
// Width, Height = 0 if it's not a thumbnail
|
||||
let key = self
|
||||
.db
|
||||
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8],
|
||||
) -> Result<()> {
|
||||
// Width, Height = 0 if it's not a thumbnail
|
||||
let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
||||
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
Ok(())
|
||||
}
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Deletes a file in the database and from the media directory via an MXC
|
||||
pub async fn delete(&self, mxc: String) -> Result<()> {
|
||||
if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) {
|
||||
for key in keys {
|
||||
let file_path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
debug!("Got local file path: {:?}", file_path);
|
||||
/// Deletes a file in the database and from the media directory via an MXC
|
||||
pub async fn delete(&self, mxc: String) -> Result<()> {
|
||||
if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) {
|
||||
for key in keys {
|
||||
let file_path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
debug!("Got local file path: {:?}", file_path);
|
||||
|
||||
debug!(
|
||||
"Deleting local file {:?} from filesystem, original MXC: {}",
|
||||
file_path, mxc
|
||||
);
|
||||
tokio::fs::remove_file(file_path).await?;
|
||||
debug!("Deleting local file {:?} from filesystem, original MXC: {}", file_path, mxc);
|
||||
tokio::fs::remove_file(file_path).await?;
|
||||
|
||||
debug!("Deleting MXC {mxc} from database");
|
||||
self.db.delete_file_mxc(mxc.clone())?;
|
||||
}
|
||||
debug!("Deleting MXC {mxc} from database");
|
||||
self.db.delete_file_mxc(mxc.clone())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)");
|
||||
Err(Error::bad_database("Failed to find any media keys for the provided MXC in our database (MXC does not exist)"))
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)");
|
||||
Err(Error::bad_database(
|
||||
"Failed to find any media keys for the provided MXC in our database (MXC does not exist)",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Uploads or replaces a file thumbnail.
|
||||
pub async fn upload_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
width: u32,
|
||||
height: u32,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let key =
|
||||
self.db
|
||||
.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
||||
/// Uploads or replaces a file thumbnail.
|
||||
pub async fn upload_thumbnail(
|
||||
&self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
||||
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
|
||||
if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc, 0, 0)
|
||||
{
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
|
||||
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) {
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
|
||||
let mut file = Vec::new();
|
||||
BufReader::new(File::open(path).await?)
|
||||
.read_to_end(&mut file)
|
||||
.await?;
|
||||
let mut file = Vec::new();
|
||||
BufReader::new(File::open(path).await?).read_to_end(&mut file).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes all remote only media files in the given at or after time/duration. Returns a u32
|
||||
/// with the amount of media files deleted.
|
||||
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
|
||||
if let Ok(all_keys) = self.db.get_all_media_keys() {
|
||||
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
|
||||
Ok(duration) => {
|
||||
debug!("Parsed duration: {:?}", duration);
|
||||
debug!("System time now: {:?}", SystemTime::now());
|
||||
SystemTime::now() - duration
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse user-specified time duration: {}", e);
|
||||
return Err(Error::bad_database(
|
||||
"Failed to parse user-specified time duration.",
|
||||
));
|
||||
}
|
||||
};
|
||||
/// Deletes all remote only media files in the given at or after
|
||||
/// time/duration. Returns a u32 with the amount of media files deleted.
|
||||
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
|
||||
if let Ok(all_keys) = self.db.get_all_media_keys() {
|
||||
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
|
||||
Ok(duration) => {
|
||||
debug!("Parsed duration: {:?}", duration);
|
||||
debug!("System time now: {:?}", SystemTime::now());
|
||||
SystemTime::now() - duration
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to parse user-specified time duration: {}", e);
|
||||
return Err(Error::bad_database("Failed to parse user-specified time duration."));
|
||||
},
|
||||
};
|
||||
|
||||
let mut remote_mxcs: Vec<String> = vec![];
|
||||
let mut remote_mxcs: Vec<String> = vec![];
|
||||
|
||||
for key in all_keys {
|
||||
debug!("Full MXC key from database: {:?}", key);
|
||||
for key in all_keys {
|
||||
debug!("Full MXC key from database: {:?}", key);
|
||||
|
||||
// we need to get the MXC URL from the first part of the key (the first 0xff / 255 push)
|
||||
// this code does look kinda crazy but blame conduit for using magic keys
|
||||
let mut parts = key.split(|&b| b == 0xff);
|
||||
let mxc = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|e| {
|
||||
error!("Failed to parse MXC unicode bytes from our database: {}", e);
|
||||
Error::bad_database(
|
||||
"Failed to parse MXC unicode bytes from our database",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
// we need to get the MXC URL from the first part of the key (the first 0xff /
|
||||
// 255 push) this code does look kinda crazy but blame conduit for using magic
|
||||
// keys
|
||||
let mut parts = key.split(|&b| b == 0xFF);
|
||||
let mxc = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|e| {
|
||||
error!("Failed to parse MXC unicode bytes from our database: {}", e);
|
||||
Error::bad_database("Failed to parse MXC unicode bytes from our database")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let mxc_s = match mxc {
|
||||
Some(mxc) => mxc,
|
||||
None => {
|
||||
return Err(Error::bad_database(
|
||||
"Parsed MXC URL unicode bytes from database but still is None",
|
||||
));
|
||||
}
|
||||
};
|
||||
let mxc_s = match mxc {
|
||||
Some(mxc) => mxc,
|
||||
None => {
|
||||
return Err(Error::bad_database(
|
||||
"Parsed MXC URL unicode bytes from database but still is None",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
debug!("Parsed MXC key to URL: {}", mxc_s);
|
||||
debug!("Parsed MXC key to URL: {}", mxc_s);
|
||||
|
||||
let mxc = OwnedMxcUri::from(mxc_s);
|
||||
if mxc.server_name() == Ok(services().globals.server_name()) {
|
||||
debug!("Ignoring local media MXC: {}", mxc);
|
||||
// ignore our own MXC URLs as this would be local media.
|
||||
continue;
|
||||
}
|
||||
let mxc = OwnedMxcUri::from(mxc_s);
|
||||
if mxc.server_name() == Ok(services().globals.server_name()) {
|
||||
debug!("Ignoring local media MXC: {}", mxc);
|
||||
// ignore our own MXC URLs as this would be local media.
|
||||
continue;
|
||||
}
|
||||
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
|
||||
debug!("MXC path: {:?}", path);
|
||||
debug!("MXC path: {:?}", path);
|
||||
|
||||
let file_metadata = fs::metadata(path.clone()).await?;
|
||||
debug!("File metadata: {:?}", file_metadata);
|
||||
let file_metadata = fs::metadata(path.clone()).await?;
|
||||
debug!("File metadata: {:?}", file_metadata);
|
||||
|
||||
let file_created_at = file_metadata.created()?;
|
||||
debug!("File created at: {:?}", file_created_at);
|
||||
let file_created_at = file_metadata.created()?;
|
||||
debug!("File created at: {:?}", file_created_at);
|
||||
|
||||
if file_created_at >= user_duration {
|
||||
debug!("File is within user duration, pushing to list of file paths and keys to delete.");
|
||||
remote_mxcs.push(mxc.to_string());
|
||||
}
|
||||
}
|
||||
if file_created_at >= user_duration {
|
||||
debug!("File is within user duration, pushing to list of file paths and keys to delete.");
|
||||
remote_mxcs.push(mxc.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Finished going through all our media in database for eligible keys to delete, checking if these are empty");
|
||||
debug!(
|
||||
"Finished going through all our media in database for eligible keys to delete, checking if these are \
|
||||
empty"
|
||||
);
|
||||
|
||||
if remote_mxcs.is_empty() {
|
||||
return Err(Error::bad_database(
|
||||
"Did not found any eligible MXCs to delete.",
|
||||
));
|
||||
}
|
||||
if remote_mxcs.is_empty() {
|
||||
return Err(Error::bad_database("Did not found any eligible MXCs to delete."));
|
||||
}
|
||||
|
||||
debug!("Deleting media now in the past \"{:?}\".", user_duration);
|
||||
debug!("Deleting media now in the past \"{:?}\".", user_duration);
|
||||
|
||||
let mut deletion_count = 0;
|
||||
let mut deletion_count = 0;
|
||||
|
||||
for mxc in remote_mxcs {
|
||||
debug!("Deleting MXC {mxc} from database and filesystem");
|
||||
self.delete(mxc).await?;
|
||||
deletion_count += 1;
|
||||
}
|
||||
for mxc in remote_mxcs {
|
||||
debug!("Deleting MXC {mxc} from database and filesystem");
|
||||
self.delete(mxc).await?;
|
||||
deletion_count += 1;
|
||||
}
|
||||
|
||||
Ok(deletion_count)
|
||||
} else {
|
||||
Err(Error::bad_database(
|
||||
"Failed to get all our media keys (filesystem or database issue?).",
|
||||
))
|
||||
}
|
||||
}
|
||||
Ok(deletion_count)
|
||||
} else {
|
||||
Err(Error::bad_database(
|
||||
"Failed to get all our media keys (filesystem or database issue?).",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
|
||||
/// the server should send the original file.
|
||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||
match (width, height) {
|
||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||
(0..=96, 0..=96) => Some((96, 96, true)),
|
||||
(0..=320, 0..=240) => Some((320, 240, false)),
|
||||
(0..=640, 0..=480) => Some((640, 480, false)),
|
||||
(0..=800, 0..=600) => Some((800, 600, false)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped.
|
||||
/// Returns None when the server should send the original file.
|
||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||
match (width, height) {
|
||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||
(0..=96, 0..=96) => Some((96, 96, true)),
|
||||
(0..=320, 0..=240) => Some((320, 240, false)),
|
||||
(0..=640, 0..=480) => Some((640, 480, false)),
|
||||
(0..=800, 0..=600) => Some((800, 600, false)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a file's thumbnail.
|
||||
///
|
||||
/// Here's an example on how it works:
|
||||
///
|
||||
/// - Client requests an image with width=567, height=567
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
|
||||
/// - Server creates the thumbnail and sends it to the user
|
||||
///
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
|
||||
pub async fn get_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self
|
||||
.thumbnail_properties(width, height)
|
||||
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
/// Downloads a file's thumbnail.
|
||||
///
|
||||
/// Here's an example on how it works:
|
||||
///
|
||||
/// - Client requests an image with width=567, height=567
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too
|
||||
/// many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio
|
||||
/// (only for width,height>96)
|
||||
/// - Server creates the thumbnail and sends it to the user
|
||||
///
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm
|
||||
/// which crops the image afterwards.
|
||||
pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
|
||||
if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc.clone(), width, height)
|
||||
{
|
||||
// Using saved thumbnail
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) {
|
||||
// Using saved thumbnail
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.clone(),
|
||||
}))
|
||||
} else if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc.clone(), 0, 0)
|
||||
{
|
||||
// Generate a thumbnail
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.clone(),
|
||||
}))
|
||||
} else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0) {
|
||||
// Generate a thumbnail
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&key)
|
||||
};
|
||||
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
if let Ok(image) = image::load_from_memory(&file) {
|
||||
let original_width = image.width();
|
||||
let original_height = image.height();
|
||||
if width > original_width || height > original_height {
|
||||
return Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.clone(),
|
||||
}));
|
||||
}
|
||||
if let Ok(image) = image::load_from_memory(&file) {
|
||||
let original_width = image.width();
|
||||
let original_height = image.height();
|
||||
if width > original_width || height > original_height {
|
||||
return Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
let thumbnail = if crop {
|
||||
image.resize_to_fill(width, height, FilterType::CatmullRom)
|
||||
} else {
|
||||
let (exact_width, exact_height) = {
|
||||
// Copied from image::dynimage::resize_dimensions
|
||||
let ratio = u64::from(original_width) * u64::from(height);
|
||||
let nratio = u64::from(width) * u64::from(original_height);
|
||||
let thumbnail = if crop {
|
||||
image.resize_to_fill(width, height, FilterType::CatmullRom)
|
||||
} else {
|
||||
let (exact_width, exact_height) = {
|
||||
// Copied from image::dynimage::resize_dimensions
|
||||
let ratio = u64::from(original_width) * u64::from(height);
|
||||
let nratio = u64::from(width) * u64::from(original_height);
|
||||
|
||||
let use_width = nratio <= ratio;
|
||||
let intermediate = if use_width {
|
||||
u64::from(original_height) * u64::from(width)
|
||||
/ u64::from(original_width)
|
||||
} else {
|
||||
u64::from(original_width) * u64::from(height)
|
||||
/ u64::from(original_height)
|
||||
};
|
||||
if use_width {
|
||||
if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(width, intermediate as u32)
|
||||
} else {
|
||||
(
|
||||
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
::std::u32::MAX,
|
||||
)
|
||||
}
|
||||
} else if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(intermediate as u32, height)
|
||||
} else {
|
||||
(
|
||||
::std::u32::MAX,
|
||||
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
)
|
||||
}
|
||||
};
|
||||
let use_width = nratio <= ratio;
|
||||
let intermediate = if use_width {
|
||||
u64::from(original_height) * u64::from(width) / u64::from(original_width)
|
||||
} else {
|
||||
u64::from(original_width) * u64::from(height) / u64::from(original_height)
|
||||
};
|
||||
if use_width {
|
||||
if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(width, intermediate as u32)
|
||||
} else {
|
||||
(
|
||||
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate) as u32,
|
||||
::std::u32::MAX,
|
||||
)
|
||||
}
|
||||
} else if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(intermediate as u32, height)
|
||||
} else {
|
||||
(
|
||||
::std::u32::MAX,
|
||||
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate) as u32,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
image.thumbnail_exact(exact_width, exact_height)
|
||||
};
|
||||
image.thumbnail_exact(exact_width, exact_height)
|
||||
};
|
||||
|
||||
let mut thumbnail_bytes = Vec::new();
|
||||
thumbnail.write_to(
|
||||
&mut Cursor::new(&mut thumbnail_bytes),
|
||||
image::ImageOutputFormat::Png,
|
||||
)?;
|
||||
let mut thumbnail_bytes = Vec::new();
|
||||
thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageOutputFormat::Png)?;
|
||||
|
||||
// Save thumbnail in database so we don't have to generate it again next time
|
||||
let thumbnail_key = self.db.create_file_metadata(
|
||||
mxc,
|
||||
width,
|
||||
height,
|
||||
content_disposition.as_deref(),
|
||||
content_type.as_deref(),
|
||||
)?;
|
||||
// Save thumbnail in database so we don't have to generate it again next time
|
||||
let thumbnail_key = self.db.create_file_metadata(
|
||||
mxc,
|
||||
width,
|
||||
height,
|
||||
content_disposition.as_deref(),
|
||||
content_type.as_deref(),
|
||||
)?;
|
||||
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&thumbnail_key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&thumbnail_key)
|
||||
};
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&thumbnail_key)
|
||||
} else {
|
||||
#[allow(deprecated)]
|
||||
services().globals.get_media_file(&thumbnail_key)
|
||||
};
|
||||
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(&thumbnail_bytes).await?;
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(&thumbnail_bytes).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: thumbnail_bytes.clone(),
|
||||
}))
|
||||
} else {
|
||||
// Couldn't parse file to generate thumbnail, send original
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.clone(),
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: thumbnail_bytes.clone(),
|
||||
}))
|
||||
} else {
|
||||
// Couldn't parse file to generate thumbnail, send original
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.clone(),
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||
self.db.get_url_preview(url)
|
||||
}
|
||||
pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }
|
||||
|
||||
pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
|
||||
// TODO: also remove the downloaded image
|
||||
self.db.remove_url_preview(url)
|
||||
}
|
||||
pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
|
||||
// TODO: also remove the downloaded image
|
||||
self.db.remove_url_preview(url)
|
||||
}
|
||||
|
||||
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.expect("valid system time");
|
||||
self.db.set_url_preview(url, data, now)
|
||||
}
|
||||
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
|
||||
let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time");
|
||||
self.db.set_url_preview(url, data, now)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use sha2::Digest;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use sha2::Digest;
|
||||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use super::*;
|
||||
|
||||
use super::*;
|
||||
struct MockedKVDatabase;
|
||||
|
||||
struct MockedKVDatabase;
|
||||
impl Data for MockedKVDatabase {
|
||||
fn create_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
// copied from src/database/key_value/media.rs
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
|
||||
|
||||
impl Data for MockedKVDatabase {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
// copied from src/database/key_value/media.rs
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
fn delete_file_mxc(&self, _mxc: String) -> Result<()> { todo!() }
|
||||
|
||||
fn delete_file_mxc(&self, _mxc: String) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> { todo!() }
|
||||
|
||||
fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> {
|
||||
todo!()
|
||||
}
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> { todo!() }
|
||||
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||
todo!()
|
||||
}
|
||||
fn search_file_metadata(
|
||||
&self, _mxc: String, _width: u32, _height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
_mxc: String,
|
||||
_width: u32,
|
||||
_height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
todo!()
|
||||
}
|
||||
fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() }
|
||||
|
||||
fn remove_url_preview(&self, _url: &str) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
fn set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
_url: &str,
|
||||
_data: &UrlPreviewData,
|
||||
_timestamp: std::time::Duration,
|
||||
) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> { todo!() }
|
||||
}
|
||||
|
||||
fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn long_file_names_works() {
|
||||
static DB: MockedKVDatabase = MockedKVDatabase;
|
||||
let media = Service {
|
||||
db: &DB,
|
||||
url_preview_mutex: RwLock::new(HashMap::new()),
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn long_file_names_works() {
|
||||
static DB: MockedKVDatabase = MockedKVDatabase;
|
||||
let media = Service {
|
||||
db: &DB,
|
||||
url_preview_mutex: RwLock::new(HashMap::new()),
|
||||
};
|
||||
|
||||
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
|
||||
let width = 100;
|
||||
let height = 100;
|
||||
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special characters like äöüß and even emoji like 🦀.png\"";
|
||||
let content_type = "image/png";
|
||||
let key = media
|
||||
.db
|
||||
.create_file_metadata(
|
||||
mxc,
|
||||
width,
|
||||
height,
|
||||
Some(content_disposition),
|
||||
Some(content_type),
|
||||
)
|
||||
.unwrap();
|
||||
let mut r = PathBuf::new();
|
||||
r.push("/tmp");
|
||||
r.push("media");
|
||||
// r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
|
||||
// use the sha256 hash of the key as the file name instead of the key itself
|
||||
// this is because the base64 encoded key can be longer than 255 characters.
|
||||
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
|
||||
// Check that the file path is not longer than 255 characters
|
||||
// (255 is the maximum length of a file path on most file systems)
|
||||
assert!(
|
||||
r.to_str().unwrap().len() <= 255,
|
||||
"File path is too long: {}",
|
||||
r.to_str().unwrap().len()
|
||||
);
|
||||
}
|
||||
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
|
||||
let width = 100;
|
||||
let height = 100;
|
||||
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \
|
||||
characters like äöüß and even emoji like 🦀.png\"";
|
||||
let content_type = "image/png";
|
||||
let key =
|
||||
media.db.create_file_metadata(mxc, width, height, Some(content_disposition), Some(content_type)).unwrap();
|
||||
let mut r = PathBuf::new();
|
||||
r.push("/tmp");
|
||||
r.push("media");
|
||||
// r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
|
||||
// use the sha256 hash of the key as the file name instead of the key itself
|
||||
// this is because the base64 encoded key can be longer than 255 characters.
|
||||
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
|
||||
// Check that the file path is not longer than 255 characters
|
||||
// (255 is the maximum length of a file path on most file systems)
|
||||
assert!(
|
||||
r.to_str().unwrap().len() <= 255,
|
||||
"File path is too long: {}",
|
||||
r.to_str().unwrap().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+174
-198
@@ -1,6 +1,6 @@
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
collections::{BTreeMap, HashMap},
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
};
|
||||
|
||||
use lru_cache::LruCache;
|
||||
@@ -22,210 +22,186 @@ pub(crate) mod uiaa;
|
||||
pub(crate) mod users;
|
||||
|
||||
pub struct Services<'a> {
|
||||
pub appservice: appservice::Service,
|
||||
pub pusher: pusher::Service,
|
||||
pub rooms: rooms::Service,
|
||||
pub transaction_ids: transaction_ids::Service,
|
||||
pub uiaa: uiaa::Service,
|
||||
pub users: users::Service,
|
||||
pub account_data: account_data::Service,
|
||||
pub admin: Arc<admin::Service>,
|
||||
pub globals: globals::Service<'a>,
|
||||
pub key_backups: key_backups::Service,
|
||||
pub media: media::Service,
|
||||
pub sending: Arc<sending::Service>,
|
||||
pub appservice: appservice::Service,
|
||||
pub pusher: pusher::Service,
|
||||
pub rooms: rooms::Service,
|
||||
pub transaction_ids: transaction_ids::Service,
|
||||
pub uiaa: uiaa::Service,
|
||||
pub users: users::Service,
|
||||
pub account_data: account_data::Service,
|
||||
pub admin: Arc<admin::Service>,
|
||||
pub globals: globals::Service<'a>,
|
||||
pub key_backups: key_backups::Service,
|
||||
pub media: media::Service,
|
||||
pub sending: Arc<sending::Service>,
|
||||
}
|
||||
|
||||
impl Services<'_> {
|
||||
pub fn build<
|
||||
D: appservice::Data
|
||||
+ pusher::Data
|
||||
+ rooms::Data
|
||||
+ transaction_ids::Data
|
||||
+ uiaa::Data
|
||||
+ users::Data
|
||||
+ account_data::Data
|
||||
+ globals::Data
|
||||
+ key_backups::Data
|
||||
+ media::Data
|
||||
+ sending::Data
|
||||
+ 'static,
|
||||
>(
|
||||
db: &'static D,
|
||||
config: Config,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
appservice: appservice::Service { db },
|
||||
pusher: pusher::Service { db },
|
||||
rooms: rooms::Service {
|
||||
alias: rooms::alias::Service { db },
|
||||
auth_chain: rooms::auth_chain::Service { db },
|
||||
directory: rooms::directory::Service { db },
|
||||
edus: rooms::edus::Service {
|
||||
presence: rooms::edus::presence::Service { db },
|
||||
read_receipt: rooms::edus::read_receipt::Service { db },
|
||||
typing: rooms::edus::typing::Service { db },
|
||||
},
|
||||
event_handler: rooms::event_handler::Service,
|
||||
lazy_loading: rooms::lazy_loading::Service {
|
||||
db,
|
||||
lazy_load_waiting: Mutex::new(HashMap::new()),
|
||||
},
|
||||
metadata: rooms::metadata::Service { db },
|
||||
outlier: rooms::outlier::Service { db },
|
||||
pdu_metadata: rooms::pdu_metadata::Service { db },
|
||||
search: rooms::search::Service { db },
|
||||
short: rooms::short::Service { db },
|
||||
state: rooms::state::Service { db },
|
||||
state_accessor: rooms::state_accessor::Service {
|
||||
db,
|
||||
server_visibility_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
user_visibility_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
state_cache: rooms::state_cache::Service { db },
|
||||
state_compressor: rooms::state_compressor::Service {
|
||||
db,
|
||||
stateinfo_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
timeline: rooms::timeline::Service {
|
||||
db,
|
||||
lasttimelinecount_cache: Mutex::new(HashMap::new()),
|
||||
},
|
||||
threads: rooms::threads::Service { db },
|
||||
spaces: rooms::spaces::Service {
|
||||
roomid_spacechunk_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
user: rooms::user::Service { db },
|
||||
},
|
||||
transaction_ids: transaction_ids::Service { db },
|
||||
uiaa: uiaa::Service { db },
|
||||
users: users::Service {
|
||||
db,
|
||||
connections: Mutex::new(BTreeMap::new()),
|
||||
},
|
||||
account_data: account_data::Service { db },
|
||||
admin: admin::Service::build(),
|
||||
key_backups: key_backups::Service { db },
|
||||
media: media::Service {
|
||||
db,
|
||||
url_preview_mutex: RwLock::new(HashMap::new()),
|
||||
},
|
||||
sending: sending::Service::build(db, &config),
|
||||
pub fn build<
|
||||
D: appservice::Data
|
||||
+ pusher::Data
|
||||
+ rooms::Data
|
||||
+ transaction_ids::Data
|
||||
+ uiaa::Data
|
||||
+ users::Data
|
||||
+ account_data::Data
|
||||
+ globals::Data
|
||||
+ key_backups::Data
|
||||
+ media::Data
|
||||
+ sending::Data
|
||||
+ 'static,
|
||||
>(
|
||||
db: &'static D, config: Config,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
appservice: appservice::Service {
|
||||
db,
|
||||
},
|
||||
pusher: pusher::Service {
|
||||
db,
|
||||
},
|
||||
rooms: rooms::Service {
|
||||
alias: rooms::alias::Service {
|
||||
db,
|
||||
},
|
||||
auth_chain: rooms::auth_chain::Service {
|
||||
db,
|
||||
},
|
||||
directory: rooms::directory::Service {
|
||||
db,
|
||||
},
|
||||
edus: rooms::edus::Service {
|
||||
presence: rooms::edus::presence::Service {
|
||||
db,
|
||||
},
|
||||
read_receipt: rooms::edus::read_receipt::Service {
|
||||
db,
|
||||
},
|
||||
typing: rooms::edus::typing::Service {
|
||||
db,
|
||||
},
|
||||
},
|
||||
event_handler: rooms::event_handler::Service,
|
||||
lazy_loading: rooms::lazy_loading::Service {
|
||||
db,
|
||||
lazy_load_waiting: Mutex::new(HashMap::new()),
|
||||
},
|
||||
metadata: rooms::metadata::Service {
|
||||
db,
|
||||
},
|
||||
outlier: rooms::outlier::Service {
|
||||
db,
|
||||
},
|
||||
pdu_metadata: rooms::pdu_metadata::Service {
|
||||
db,
|
||||
},
|
||||
search: rooms::search::Service {
|
||||
db,
|
||||
},
|
||||
short: rooms::short::Service {
|
||||
db,
|
||||
},
|
||||
state: rooms::state::Service {
|
||||
db,
|
||||
},
|
||||
state_accessor: rooms::state_accessor::Service {
|
||||
db,
|
||||
server_visibility_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
user_visibility_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
state_cache: rooms::state_cache::Service {
|
||||
db,
|
||||
},
|
||||
state_compressor: rooms::state_compressor::Service {
|
||||
db,
|
||||
stateinfo_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
timeline: rooms::timeline::Service {
|
||||
db,
|
||||
lasttimelinecount_cache: Mutex::new(HashMap::new()),
|
||||
},
|
||||
threads: rooms::threads::Service {
|
||||
db,
|
||||
},
|
||||
spaces: rooms::spaces::Service {
|
||||
roomid_spacechunk_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
user: rooms::user::Service {
|
||||
db,
|
||||
},
|
||||
},
|
||||
transaction_ids: transaction_ids::Service {
|
||||
db,
|
||||
},
|
||||
uiaa: uiaa::Service {
|
||||
db,
|
||||
},
|
||||
users: users::Service {
|
||||
db,
|
||||
connections: Mutex::new(BTreeMap::new()),
|
||||
},
|
||||
account_data: account_data::Service {
|
||||
db,
|
||||
},
|
||||
admin: admin::Service::build(),
|
||||
key_backups: key_backups::Service {
|
||||
db,
|
||||
},
|
||||
media: media::Service {
|
||||
db,
|
||||
url_preview_mutex: RwLock::new(HashMap::new()),
|
||||
},
|
||||
sending: sending::Service::build(db, &config),
|
||||
|
||||
globals: globals::Service::load(db, config)?,
|
||||
})
|
||||
}
|
||||
fn memory_usage(&self) -> String {
|
||||
let lazy_load_waiting = self
|
||||
.rooms
|
||||
.lazy_loading
|
||||
.lazy_load_waiting
|
||||
.lock()
|
||||
.unwrap()
|
||||
.len();
|
||||
let server_visibility_cache = self
|
||||
.rooms
|
||||
.state_accessor
|
||||
.server_visibility_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.len();
|
||||
let user_visibility_cache = self
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_visibility_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.len();
|
||||
let stateinfo_cache = self
|
||||
.rooms
|
||||
.state_compressor
|
||||
.stateinfo_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.len();
|
||||
let lasttimelinecount_cache = self
|
||||
.rooms
|
||||
.timeline
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.len();
|
||||
let roomid_spacechunk_cache = self
|
||||
.rooms
|
||||
.spaces
|
||||
.roomid_spacechunk_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.len();
|
||||
globals: globals::Service::load(db, config)?,
|
||||
})
|
||||
}
|
||||
|
||||
format!(
|
||||
"\
|
||||
fn memory_usage(&self) -> String {
|
||||
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len();
|
||||
let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len();
|
||||
let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len();
|
||||
let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
|
||||
let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len();
|
||||
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len();
|
||||
|
||||
format!(
|
||||
"\
|
||||
lazy_load_waiting: {lazy_load_waiting}
|
||||
server_visibility_cache: {server_visibility_cache}
|
||||
user_visibility_cache: {user_visibility_cache}
|
||||
stateinfo_cache: {stateinfo_cache}
|
||||
lasttimelinecount_cache: {lasttimelinecount_cache}
|
||||
roomid_spacechunk_cache: {roomid_spacechunk_cache}\
|
||||
"
|
||||
)
|
||||
}
|
||||
fn clear_caches(&self, amount: u32) {
|
||||
if amount > 0 {
|
||||
self.rooms
|
||||
.lazy_loading
|
||||
.lazy_load_waiting
|
||||
.lock()
|
||||
.unwrap()
|
||||
.clear();
|
||||
}
|
||||
if amount > 1 {
|
||||
self.rooms
|
||||
.state_accessor
|
||||
.server_visibility_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.clear();
|
||||
}
|
||||
if amount > 2 {
|
||||
self.rooms
|
||||
.state_accessor
|
||||
.user_visibility_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.clear();
|
||||
}
|
||||
if amount > 3 {
|
||||
self.rooms
|
||||
.state_compressor
|
||||
.stateinfo_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.clear();
|
||||
}
|
||||
if amount > 4 {
|
||||
self.rooms
|
||||
.timeline
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.clear();
|
||||
}
|
||||
if amount > 5 {
|
||||
self.rooms
|
||||
.spaces
|
||||
.roomid_spacechunk_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.clear();
|
||||
}
|
||||
}
|
||||
roomid_spacechunk_cache: {roomid_spacechunk_cache}"
|
||||
)
|
||||
}
|
||||
|
||||
fn clear_caches(&self, amount: u32) {
|
||||
if amount > 0 {
|
||||
self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear();
|
||||
}
|
||||
if amount > 1 {
|
||||
self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear();
|
||||
}
|
||||
if amount > 2 {
|
||||
self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear();
|
||||
}
|
||||
if amount > 3 {
|
||||
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
|
||||
}
|
||||
if amount > 4 {
|
||||
self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear();
|
||||
}
|
||||
if amount > 5 {
|
||||
self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+278
-316
@@ -1,410 +1,372 @@
|
||||
use crate::Error;
|
||||
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
canonical_json::redact_content_in_place,
|
||||
events::{
|
||||
room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent,
|
||||
AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent,
|
||||
AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType,
|
||||
},
|
||||
serde::Raw,
|
||||
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
|
||||
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
|
||||
canonical_json::redact_content_in_place,
|
||||
events::{
|
||||
room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent,
|
||||
AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent,
|
||||
AnyTimelineEvent, StateEvent, TimelineEventType,
|
||||
},
|
||||
serde::Raw,
|
||||
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
|
||||
OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{
|
||||
json,
|
||||
value::{to_raw_value, RawValue as RawJsonValue},
|
||||
json,
|
||||
value::{to_raw_value, RawValue as RawJsonValue},
|
||||
};
|
||||
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::Error;
|
||||
|
||||
/// Content hashes of a PDU.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct EventHash {
|
||||
/// The SHA-256 hash.
|
||||
pub sha256: String,
|
||||
/// The SHA-256 hash.
|
||||
pub sha256: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
pub struct PduEvent {
|
||||
pub event_id: Arc<EventId>,
|
||||
pub room_id: OwnedRoomId,
|
||||
pub sender: OwnedUserId,
|
||||
pub origin_server_ts: UInt,
|
||||
#[serde(rename = "type")]
|
||||
pub kind: TimelineEventType,
|
||||
pub content: Box<RawJsonValue>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub state_key: Option<String>,
|
||||
pub prev_events: Vec<Arc<EventId>>,
|
||||
pub depth: UInt,
|
||||
pub auth_events: Vec<Arc<EventId>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub redacts: Option<Arc<EventId>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub unsigned: Option<Box<RawJsonValue>>,
|
||||
pub hashes: EventHash,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub signatures: Option<Box<RawJsonValue>>, // BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, String>>
|
||||
pub event_id: Arc<EventId>,
|
||||
pub room_id: OwnedRoomId,
|
||||
pub sender: OwnedUserId,
|
||||
pub origin_server_ts: UInt,
|
||||
#[serde(rename = "type")]
|
||||
pub kind: TimelineEventType,
|
||||
pub content: Box<RawJsonValue>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub state_key: Option<String>,
|
||||
pub prev_events: Vec<Arc<EventId>>,
|
||||
pub depth: UInt,
|
||||
pub auth_events: Vec<Arc<EventId>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub redacts: Option<Arc<EventId>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub unsigned: Option<Box<RawJsonValue>>,
|
||||
pub hashes: EventHash,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub signatures: Option<Box<RawJsonValue>>, // BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, String>>
|
||||
}
|
||||
|
||||
impl PduEvent {
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn redact(
|
||||
&mut self,
|
||||
room_version_id: RoomVersionId,
|
||||
reason: &PduEvent,
|
||||
) -> crate::Result<()> {
|
||||
self.unsigned = None;
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> {
|
||||
self.unsigned = None;
|
||||
|
||||
let mut content = serde_json::from_str(self.content.get())
|
||||
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?;
|
||||
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
|
||||
.map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?;
|
||||
let mut content = serde_json::from_str(self.content.get())
|
||||
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?;
|
||||
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
|
||||
.map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?;
|
||||
|
||||
self.unsigned = Some(to_raw_value(&json!({
|
||||
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
|
||||
})).expect("to string always works"));
|
||||
self.unsigned = Some(
|
||||
to_raw_value(&json!({
|
||||
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
|
||||
}))
|
||||
.expect("to string always works"),
|
||||
);
|
||||
|
||||
self.content = to_raw_value(&content).expect("to string always works");
|
||||
self.content = to_raw_value(&content).expect("to string always works");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
|
||||
serde_json::from_str(unsigned.get())
|
||||
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
|
||||
unsigned.remove("transaction_id");
|
||||
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
|
||||
}
|
||||
pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get())
|
||||
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
|
||||
unsigned.remove("transaction_id");
|
||||
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn add_age(&mut self) -> crate::Result<()> {
|
||||
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
|
||||
.unsigned
|
||||
.as_ref()
|
||||
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
|
||||
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
|
||||
pub fn add_age(&mut self) -> crate::Result<()> {
|
||||
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
|
||||
.unsigned
|
||||
.as_ref()
|
||||
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
|
||||
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
|
||||
|
||||
unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap());
|
||||
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
|
||||
unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap());
|
||||
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
/// This only works for events that are also AnyRoomEvents.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
});
|
||||
/// This only works for events that are also AnyRoomEvents.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(state_key) = &self.state_key {
|
||||
json["state_key"] = json!(state_key);
|
||||
}
|
||||
if let Some(redacts) = &self.redacts {
|
||||
json["redacts"] = json!(redacts);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_state_event(&self) -> Raw<AnyStateEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_state_event(&self) -> Raw<AnyStateEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"room_id": self.room_id,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> {
|
||||
let json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"sender": self.sender,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> {
|
||||
let json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"sender": self.sender,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> {
|
||||
let json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"sender": self.sender,
|
||||
"state_key": self.state_key,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> {
|
||||
let json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"sender": self.sender,
|
||||
"state_key": self.state_key,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
});
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"redacts": self.redacts,
|
||||
"room_id": self.room_id,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> {
|
||||
let mut json = json!({
|
||||
"content": self.content,
|
||||
"type": self.kind,
|
||||
"event_id": self.event_id,
|
||||
"sender": self.sender,
|
||||
"origin_server_ts": self.origin_server_ts,
|
||||
"redacts": self.redacts,
|
||||
"room_id": self.room_id,
|
||||
"state_key": self.state_key,
|
||||
});
|
||||
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
if let Some(unsigned) = &self.unsigned {
|
||||
json["unsigned"] = json!(unsigned);
|
||||
}
|
||||
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
serde_json::from_value(json).expect("Raw::from_value always works")
|
||||
}
|
||||
|
||||
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
|
||||
#[tracing::instrument]
|
||||
pub fn convert_to_outgoing_federation_event(
|
||||
mut pdu_json: CanonicalJsonObject,
|
||||
) -> Box<RawJsonValue> {
|
||||
if let Some(unsigned) = pdu_json
|
||||
.get_mut("unsigned")
|
||||
.and_then(|val| val.as_object_mut())
|
||||
{
|
||||
unsigned.remove("transaction_id");
|
||||
}
|
||||
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
|
||||
#[tracing::instrument]
|
||||
pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
|
||||
if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) {
|
||||
unsigned.remove("transaction_id");
|
||||
}
|
||||
|
||||
pdu_json.remove("event_id");
|
||||
pdu_json.remove("event_id");
|
||||
|
||||
// TODO: another option would be to convert it to a canonical string to validate size
|
||||
// and return a Result<Raw<...>>
|
||||
// serde_json::from_str::<Raw<_>>(
|
||||
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is valid serde_json::Value"),
|
||||
// )
|
||||
// .expect("Raw::from_value always works")
|
||||
// TODO: another option would be to convert it to a canonical string to validate
|
||||
// size and return a Result<Raw<...>>
|
||||
// serde_json::from_str::<Raw<_>>(
|
||||
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
|
||||
// valid serde_json::Value"), )
|
||||
// .expect("Raw::from_value always works")
|
||||
|
||||
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
|
||||
}
|
||||
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
|
||||
}
|
||||
|
||||
pub fn from_id_val(
|
||||
event_id: &EventId,
|
||||
mut json: CanonicalJsonObject,
|
||||
) -> Result<Self, serde_json::Error> {
|
||||
json.insert(
|
||||
"event_id".to_owned(),
|
||||
CanonicalJsonValue::String(event_id.as_str().to_owned()),
|
||||
);
|
||||
pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> {
|
||||
json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
|
||||
|
||||
serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
|
||||
}
|
||||
serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
|
||||
}
|
||||
}
|
||||
|
||||
impl state_res::Event for PduEvent {
|
||||
type Id = Arc<EventId>;
|
||||
type Id = Arc<EventId>;
|
||||
|
||||
fn event_id(&self) -> &Self::Id {
|
||||
&self.event_id
|
||||
}
|
||||
fn event_id(&self) -> &Self::Id { &self.event_id }
|
||||
|
||||
fn room_id(&self) -> &RoomId {
|
||||
&self.room_id
|
||||
}
|
||||
fn room_id(&self) -> &RoomId { &self.room_id }
|
||||
|
||||
fn sender(&self) -> &UserId {
|
||||
&self.sender
|
||||
}
|
||||
fn sender(&self) -> &UserId { &self.sender }
|
||||
|
||||
fn event_type(&self) -> &TimelineEventType {
|
||||
&self.kind
|
||||
}
|
||||
fn event_type(&self) -> &TimelineEventType { &self.kind }
|
||||
|
||||
fn content(&self) -> &RawJsonValue {
|
||||
&self.content
|
||||
}
|
||||
fn content(&self) -> &RawJsonValue { &self.content }
|
||||
|
||||
fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch {
|
||||
MilliSecondsSinceUnixEpoch(self.origin_server_ts)
|
||||
}
|
||||
fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) }
|
||||
|
||||
fn state_key(&self) -> Option<&str> {
|
||||
self.state_key.as_deref()
|
||||
}
|
||||
fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }
|
||||
|
||||
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
|
||||
Box::new(self.prev_events.iter())
|
||||
}
|
||||
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) }
|
||||
|
||||
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
|
||||
Box::new(self.auth_events.iter())
|
||||
}
|
||||
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) }
|
||||
|
||||
fn redacts(&self) -> Option<&Self::Id> {
|
||||
self.redacts.as_ref()
|
||||
}
|
||||
fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
|
||||
}
|
||||
|
||||
// These impl's allow us to dedup state snapshots when resolving state
|
||||
// for incoming events (federation/send/{txn}).
|
||||
impl Eq for PduEvent {}
|
||||
impl PartialEq for PduEvent {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.event_id == other.event_id
|
||||
}
|
||||
fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id }
|
||||
}
|
||||
impl PartialOrd for PduEvent {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
|
||||
}
|
||||
impl Ord for PduEvent {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.event_id.cmp(&other.event_id)
|
||||
}
|
||||
fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) }
|
||||
}
|
||||
|
||||
/// Generates a correct eventId for the incoming pdu.
|
||||
///
|
||||
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
|
||||
/// CanonicalJsonValue>`.
|
||||
pub(crate) fn gen_event_id_canonical_json(
|
||||
pdu: &RawJsonValue,
|
||||
room_version_id: &RoomVersionId,
|
||||
pdu: &RawJsonValue, room_version_id: &RoomVersionId,
|
||||
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
|
||||
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
|
||||
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
|
||||
Error::BadServerResponse("Invalid PDU in server response")
|
||||
})?;
|
||||
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
|
||||
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
|
||||
Error::BadServerResponse("Invalid PDU in server response")
|
||||
})?;
|
||||
|
||||
let event_id = format!(
|
||||
"${}",
|
||||
// Anything higher than version3 behaves the same
|
||||
ruma::signatures::reference_hash(&value, room_version_id)
|
||||
.expect("ruma can calculate reference hashes")
|
||||
)
|
||||
.try_into()
|
||||
.expect("ruma's reference hashes are valid event ids");
|
||||
let event_id = format!(
|
||||
"${}",
|
||||
// Anything higher than version3 behaves the same
|
||||
ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes")
|
||||
)
|
||||
.try_into()
|
||||
.expect("ruma's reference hashes are valid event ids");
|
||||
|
||||
Ok((event_id, value))
|
||||
Ok((event_id, value))
|
||||
}
|
||||
|
||||
/// Build the start of a PDU in order to add it to the Database.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PduBuilder {
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: TimelineEventType,
|
||||
pub content: Box<RawJsonValue>,
|
||||
pub unsigned: Option<BTreeMap<String, serde_json::Value>>,
|
||||
pub state_key: Option<String>,
|
||||
pub redacts: Option<Arc<EventId>>,
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: TimelineEventType,
|
||||
pub content: Box<RawJsonValue>,
|
||||
pub unsigned: Option<BTreeMap<String, serde_json::Value>>,
|
||||
pub state_key: Option<String>,
|
||||
pub redacts: Option<Arc<EventId>>,
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::push::{set_pusher, Pusher},
|
||||
UserId,
|
||||
api::client::push::{set_pusher, Pusher},
|
||||
UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
|
||||
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>;
|
||||
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
|
||||
|
||||
fn get_pushkeys<'a>(&'a self, sender: &UserId)
|
||||
-> Box<dyn Iterator<Item = Result<String>> + 'a>;
|
||||
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
|
||||
}
|
||||
|
||||
+193
-249
@@ -1,292 +1,236 @@
|
||||
mod data;
|
||||
pub use data::Data;
|
||||
use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx};
|
||||
|
||||
use crate::{services, Error, PduEvent, Result};
|
||||
use bytes::BytesMut;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::push::{set_pusher, Pusher, PusherKind},
|
||||
push_gateway::send_event_notification::{
|
||||
self,
|
||||
v1::{Device, Notification, NotificationCounts, NotificationPriority},
|
||||
},
|
||||
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType},
|
||||
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
serde::Raw,
|
||||
uint, RoomId, UInt, UserId,
|
||||
};
|
||||
|
||||
use std::{fmt::Debug, mem};
|
||||
|
||||
use bytes::BytesMut;
|
||||
pub use data::Data;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::push::{set_pusher, Pusher, PusherKind},
|
||||
push_gateway::send_event_notification::{
|
||||
self,
|
||||
v1::{Device, Notification, NotificationCounts, NotificationPriority},
|
||||
},
|
||||
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
events::{
|
||||
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType,
|
||||
},
|
||||
push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
serde::Raw,
|
||||
uint, RoomId, UInt, UserId,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{services, Error, PduEvent, Result};
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
|
||||
self.db.set_pusher(sender, pusher)
|
||||
}
|
||||
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
|
||||
self.db.set_pusher(sender, pusher)
|
||||
}
|
||||
|
||||
pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
self.db.get_pusher(sender, pushkey)
|
||||
}
|
||||
pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
self.db.get_pusher(sender, pushkey)
|
||||
}
|
||||
|
||||
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||
self.db.get_pushers(sender)
|
||||
}
|
||||
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }
|
||||
|
||||
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
|
||||
self.db.get_pushkeys(sender)
|
||||
}
|
||||
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
|
||||
self.db.get_pushkeys(sender)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, destination, request))]
|
||||
pub async fn send_request<T>(
|
||||
&self,
|
||||
destination: &str,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: OutgoingRequest + Debug,
|
||||
{
|
||||
let destination = destination.replace(services().globals.notification_push_path(), "");
|
||||
#[tracing::instrument(skip(self, destination, request))]
|
||||
pub async fn send_request<T>(&self, destination: &str, request: T) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: OutgoingRequest + Debug,
|
||||
{
|
||||
let destination = destination.replace(services().globals.notification_push_path(), "");
|
||||
|
||||
let http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(""),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})?
|
||||
.map(bytes::BytesMut::freeze);
|
||||
let http_request = request
|
||||
.try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_0])
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})?
|
||||
.map(bytes::BytesMut::freeze);
|
||||
|
||||
let reqwest_request = reqwest::Request::try_from(http_request)?;
|
||||
let reqwest_request = reqwest::Request::try_from(http_request)?;
|
||||
|
||||
// TODO: we could keep this very short and let expo backoff do it's thing...
|
||||
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
||||
// TODO: we could keep this very short and let expo backoff do it's thing...
|
||||
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let response = services()
|
||||
.globals
|
||||
.default_client()
|
||||
.execute(reqwest_request)
|
||||
.await;
|
||||
let url = reqwest_request.url().clone();
|
||||
let response = services().globals.default_client().execute(reqwest_request).await;
|
||||
|
||||
match response {
|
||||
Ok(mut response) => {
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder()
|
||||
.status(status)
|
||||
.version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder
|
||||
.headers_mut()
|
||||
.expect("http::response::Builder is usable"),
|
||||
);
|
||||
match response {
|
||||
Ok(mut response) => {
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if !status.is_success() {
|
||||
info!(
|
||||
"Push gateway returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
crate::utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
if !status.is_success() {
|
||||
info!(
|
||||
"Push gateway returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
crate::utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder
|
||||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
response.map_err(|_| {
|
||||
info!(
|
||||
"Push gateway returned invalid response bytes {}\n{}",
|
||||
destination, url
|
||||
);
|
||||
Error::BadServerResponse("Push gateway returned bad response.")
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Could not send request to pusher {}: {}", destination, e);
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||
);
|
||||
response.map_err(|_| {
|
||||
info!("Push gateway returned invalid response bytes {}\n{}", destination, url);
|
||||
Error::BadServerResponse("Push gateway returned bad response.")
|
||||
})
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Could not send request to pusher {}: {}", destination, e);
|
||||
Err(e.into())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
|
||||
pub async fn send_push_notice(
|
||||
&self,
|
||||
user: &UserId,
|
||||
unread: UInt,
|
||||
pusher: &Pusher,
|
||||
ruleset: Ruleset,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
let mut notify = None;
|
||||
let mut tweaks = Vec::new();
|
||||
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
|
||||
pub async fn send_push_notice(
|
||||
&self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
let mut notify = None;
|
||||
let mut tweaks = Vec::new();
|
||||
|
||||
let power_levels: RoomPowerLevelsEventContent = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
|
||||
.map(|ev| {
|
||||
serde_json::from_str(ev.content.get())
|
||||
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
let power_levels: RoomPowerLevelsEventContent = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
|
||||
.map(|ev| {
|
||||
serde_json::from_str(ev.content.get())
|
||||
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
for action in self.get_actions(
|
||||
user,
|
||||
&ruleset,
|
||||
&power_levels,
|
||||
&pdu.to_sync_room_event(),
|
||||
&pdu.room_id,
|
||||
)? {
|
||||
let n = match action {
|
||||
Action::Notify => true,
|
||||
Action::SetTweak(tweak) => {
|
||||
tweaks.push(tweak.clone());
|
||||
continue;
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? {
|
||||
let n = match action {
|
||||
Action::Notify => true,
|
||||
Action::SetTweak(tweak) => {
|
||||
tweaks.push(tweak.clone());
|
||||
continue;
|
||||
},
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if notify.is_some() {
|
||||
return Err(Error::bad_database(
|
||||
r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#,
|
||||
));
|
||||
}
|
||||
if notify.is_some() {
|
||||
return Err(Error::bad_database(
|
||||
r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#,
|
||||
));
|
||||
}
|
||||
|
||||
notify = Some(n);
|
||||
}
|
||||
notify = Some(n);
|
||||
}
|
||||
|
||||
if notify == Some(true) {
|
||||
self.send_notice(unread, pusher, tweaks, pdu).await?;
|
||||
}
|
||||
// Else the event triggered no actions
|
||||
if notify == Some(true) {
|
||||
self.send_notice(unread, pusher, tweaks, pdu).await?;
|
||||
}
|
||||
// Else the event triggered no actions
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user, ruleset, pdu))]
|
||||
pub fn get_actions<'a>(
|
||||
&self,
|
||||
user: &UserId,
|
||||
ruleset: &'a Ruleset,
|
||||
power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncTimelineEvent>,
|
||||
room_id: &RoomId,
|
||||
) -> Result<&'a [Action]> {
|
||||
let power_levels = PushConditionPowerLevelsCtx {
|
||||
users: power_levels.users.clone(),
|
||||
users_default: power_levels.users_default,
|
||||
notifications: power_levels.notifications.clone(),
|
||||
};
|
||||
#[tracing::instrument(skip(self, user, ruleset, pdu))]
|
||||
pub fn get_actions<'a>(
|
||||
&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
|
||||
) -> Result<&'a [Action]> {
|
||||
let power_levels = PushConditionPowerLevelsCtx {
|
||||
users: power_levels.users.clone(),
|
||||
users_default: power_levels.users_default,
|
||||
notifications: power_levels.notifications.clone(),
|
||||
};
|
||||
|
||||
let ctx = PushConditionRoomCtx {
|
||||
room_id: room_id.to_owned(),
|
||||
member_count: UInt::from(
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_joined_count(room_id)?
|
||||
.unwrap_or(1) as u32,
|
||||
),
|
||||
user_id: user.to_owned(),
|
||||
user_display_name: services()
|
||||
.users
|
||||
.displayname(user)?
|
||||
.unwrap_or_else(|| user.localpart().to_owned()),
|
||||
power_levels: Some(power_levels),
|
||||
};
|
||||
let ctx = PushConditionRoomCtx {
|
||||
room_id: room_id.to_owned(),
|
||||
member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32),
|
||||
user_id: user.to_owned(),
|
||||
user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()),
|
||||
power_levels: Some(power_levels),
|
||||
};
|
||||
|
||||
Ok(ruleset.get_actions(pdu, &ctx))
|
||||
}
|
||||
Ok(ruleset.get_actions(pdu, &ctx))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
|
||||
async fn send_notice(
|
||||
&self,
|
||||
unread: UInt,
|
||||
pusher: &Pusher,
|
||||
tweaks: Vec<Tweak>,
|
||||
event: &PduEvent,
|
||||
) -> Result<()> {
|
||||
// TODO: email
|
||||
match &pusher.kind {
|
||||
PusherKind::Http(http) => {
|
||||
// TODO:
|
||||
// Two problems with this
|
||||
// 1. if "event_id_only" is the only format kind it seems we should never add more info
|
||||
// 2. can pusher/devices have conflicting formats
|
||||
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
|
||||
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
|
||||
async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, event: &PduEvent) -> Result<()> {
|
||||
// TODO: email
|
||||
match &pusher.kind {
|
||||
PusherKind::Http(http) => {
|
||||
// TODO:
|
||||
// Two problems with this
|
||||
// 1. if "event_id_only" is the only format kind it seems we should never add
|
||||
// more info
|
||||
// 2. can pusher/devices have conflicting formats
|
||||
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
|
||||
|
||||
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
|
||||
device.data.default_payload = http.default_payload.clone();
|
||||
device.data.format = http.format.clone();
|
||||
let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone());
|
||||
device.data.default_payload = http.default_payload.clone();
|
||||
device.data.format = http.format.clone();
|
||||
|
||||
// Tweaks are only added if the format is NOT event_id_only
|
||||
if !event_id_only {
|
||||
device.tweaks = tweaks.clone();
|
||||
}
|
||||
// Tweaks are only added if the format is NOT event_id_only
|
||||
if !event_id_only {
|
||||
device.tweaks = tweaks.clone();
|
||||
}
|
||||
|
||||
let d = vec![device];
|
||||
let mut notifi = Notification::new(d);
|
||||
let d = vec![device];
|
||||
let mut notifi = Notification::new(d);
|
||||
|
||||
notifi.prio = NotificationPriority::Low;
|
||||
notifi.event_id = Some((*event.event_id).to_owned());
|
||||
notifi.room_id = Some((*event.room_id).to_owned());
|
||||
// TODO: missed calls
|
||||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||
notifi.prio = NotificationPriority::Low;
|
||||
notifi.event_id = Some((*event.event_id).to_owned());
|
||||
notifi.room_id = Some((*event.room_id).to_owned());
|
||||
// TODO: missed calls
|
||||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||
|
||||
if event.kind == TimelineEventType::RoomEncrypted
|
||||
|| tweaks
|
||||
.iter()
|
||||
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
||||
{
|
||||
notifi.prio = NotificationPriority::High;
|
||||
}
|
||||
if event.kind == TimelineEventType::RoomEncrypted
|
||||
|| tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
||||
{
|
||||
notifi.prio = NotificationPriority::High;
|
||||
}
|
||||
|
||||
if event_id_only {
|
||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
|
||||
.await?;
|
||||
} else {
|
||||
notifi.sender = Some(event.sender.clone());
|
||||
notifi.event_type = Some(event.kind.clone());
|
||||
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
|
||||
if event_id_only {
|
||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
|
||||
} else {
|
||||
notifi.sender = Some(event.sender.clone());
|
||||
notifi.event_type = Some(event.kind.clone());
|
||||
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
|
||||
|
||||
if event.kind == TimelineEventType::RoomMember {
|
||||
notifi.user_is_target =
|
||||
event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
}
|
||||
if event.kind == TimelineEventType::RoomMember {
|
||||
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
}
|
||||
|
||||
notifi.sender_display_name = services().users.displayname(&event.sender)?;
|
||||
notifi.sender_display_name = services().users.displayname(&event.sender)?;
|
||||
|
||||
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
|
||||
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
|
||||
|
||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
|
||||
.await?;
|
||||
}
|
||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
// TODO: Handle email
|
||||
//PusherKind::Email(_) => Ok(()),
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
// TODO: Handle email
|
||||
//PusherKind::Email(_) => Ok(()),
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
use crate::Result;
|
||||
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Creates or updates the alias to the given room id.
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
|
||||
/// Creates or updates the alias to the given room id.
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
|
||||
|
||||
/// Forgets about an alias. Returns an error if the alias did not exist.
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>;
|
||||
/// Forgets about an alias. Returns an error if the alias did not exist.
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>;
|
||||
|
||||
/// Looks up the roomid for the given alias.
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>;
|
||||
/// Looks up the roomid for the given alias.
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>;
|
||||
|
||||
/// Returns all local aliases that point to the given room
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
|
||||
/// Returns all local aliases that point to the given room
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self, room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
|
||||
|
||||
/// Returns all local aliases on the server
|
||||
fn all_local_aliases<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
|
||||
/// Returns all local aliases on the server
|
||||
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
|
||||
}
|
||||
|
||||
@@ -1,42 +1,35 @@
|
||||
mod data;
|
||||
|
||||
pub use data::Data;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.db.set_alias(alias, room_id)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
self.db.remove_alias(alias)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
self.db.resolve_local_alias(alias)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
self.db.resolve_local_alias(alias)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
self.db.local_aliases_for_room(room_id)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn local_aliases_for_room<'a>(
|
||||
&'a self, room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
self.db.local_aliases_for_room(room_id)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn all_local_aliases<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
self.db.all_local_aliases()
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
self.db.all_local_aliases()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
use crate::Result;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn get_cached_eventid_authchain(
|
||||
&self,
|
||||
shorteventid: &[u64],
|
||||
) -> Result<Option<Arc<HashSet<u64>>>>;
|
||||
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>)
|
||||
-> Result<()>;
|
||||
fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<HashSet<u64>>>>;
|
||||
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()>;
|
||||
}
|
||||
|
||||
+110
-131
@@ -1,7 +1,7 @@
|
||||
mod data;
|
||||
use std::{
|
||||
collections::{BTreeSet, HashSet},
|
||||
sync::Arc,
|
||||
collections::{BTreeSet, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub use data::Data;
|
||||
@@ -11,151 +11,130 @@ use tracing::{debug, error, warn};
|
||||
use crate::{services, Error, Result};
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
self.db.get_cached_eventid_authchain(key)
|
||||
}
|
||||
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
self.db.get_cached_eventid_authchain(key)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
self.db.cache_auth_chain(key, auth_chain)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
self.db.cache_auth_chain(key, auth_chain)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, starting_events))]
|
||||
pub async fn get_auth_chain<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
starting_events: Vec<Arc<EventId>>,
|
||||
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
|
||||
const NUM_BUCKETS: usize = 50;
|
||||
#[tracing::instrument(skip(self, starting_events))]
|
||||
pub async fn get_auth_chain<'a>(
|
||||
&self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>,
|
||||
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
|
||||
const NUM_BUCKETS: usize = 50;
|
||||
|
||||
let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
|
||||
let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
|
||||
|
||||
let mut i = 0;
|
||||
for id in starting_events {
|
||||
let short = services().rooms.short.get_or_create_shorteventid(&id)?;
|
||||
let bucket_id = (short % NUM_BUCKETS as u64) as usize;
|
||||
buckets[bucket_id].insert((short, id.clone()));
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
let mut i = 0;
|
||||
for id in starting_events {
|
||||
let short = services().rooms.short.get_or_create_shorteventid(&id)?;
|
||||
let bucket_id = (short % NUM_BUCKETS as u64) as usize;
|
||||
buckets[bucket_id].insert((short, id.clone()));
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
let mut full_auth_chain = HashSet::new();
|
||||
let mut full_auth_chain = HashSet::new();
|
||||
|
||||
let mut hits = 0;
|
||||
let mut misses = 0;
|
||||
for chunk in buckets {
|
||||
if chunk.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let mut hits = 0;
|
||||
let mut misses = 0;
|
||||
for chunk in buckets {
|
||||
if chunk.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
|
||||
if let Some(cached) = services()
|
||||
.rooms
|
||||
.auth_chain
|
||||
.get_cached_eventid_authchain(&chunk_key)?
|
||||
{
|
||||
hits += 1;
|
||||
full_auth_chain.extend(cached.iter().copied());
|
||||
continue;
|
||||
}
|
||||
misses += 1;
|
||||
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
|
||||
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
|
||||
hits += 1;
|
||||
full_auth_chain.extend(cached.iter().copied());
|
||||
continue;
|
||||
}
|
||||
misses += 1;
|
||||
|
||||
let mut chunk_cache = HashSet::new();
|
||||
let mut hits2 = 0;
|
||||
let mut misses2 = 0;
|
||||
let mut i = 0;
|
||||
for (sevent_id, event_id) in chunk {
|
||||
if let Some(cached) = services()
|
||||
.rooms
|
||||
.auth_chain
|
||||
.get_cached_eventid_authchain(&[sevent_id])?
|
||||
{
|
||||
hits2 += 1;
|
||||
chunk_cache.extend(cached.iter().copied());
|
||||
} else {
|
||||
misses2 += 1;
|
||||
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
|
||||
services()
|
||||
.rooms
|
||||
.auth_chain
|
||||
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
|
||||
debug!(
|
||||
event_id = ?event_id,
|
||||
chain_length = ?auth_chain.len(),
|
||||
"Cache missed event"
|
||||
);
|
||||
chunk_cache.extend(auth_chain.iter());
|
||||
let mut chunk_cache = HashSet::new();
|
||||
let mut hits2 = 0;
|
||||
let mut misses2 = 0;
|
||||
let mut i = 0;
|
||||
for (sevent_id, event_id) in chunk {
|
||||
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
|
||||
hits2 += 1;
|
||||
chunk_cache.extend(cached.iter().copied());
|
||||
} else {
|
||||
misses2 += 1;
|
||||
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
|
||||
services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
|
||||
debug!(
|
||||
event_id = ?event_id,
|
||||
chain_length = ?auth_chain.len(),
|
||||
"Cache missed event"
|
||||
);
|
||||
chunk_cache.extend(auth_chain.iter());
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
};
|
||||
}
|
||||
debug!(
|
||||
chunk_cache_length = ?chunk_cache.len(),
|
||||
hits = ?hits2,
|
||||
misses = ?misses2,
|
||||
"Chunk missed",
|
||||
);
|
||||
let chunk_cache = Arc::new(chunk_cache);
|
||||
services()
|
||||
.rooms
|
||||
.auth_chain
|
||||
.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
|
||||
full_auth_chain.extend(chunk_cache.iter());
|
||||
}
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
};
|
||||
}
|
||||
debug!(
|
||||
chunk_cache_length = ?chunk_cache.len(),
|
||||
hits = ?hits2,
|
||||
misses = ?misses2,
|
||||
"Chunk missed",
|
||||
);
|
||||
let chunk_cache = Arc::new(chunk_cache);
|
||||
services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
|
||||
full_auth_chain.extend(chunk_cache.iter());
|
||||
}
|
||||
|
||||
debug!(
|
||||
chain_length = ?full_auth_chain.len(),
|
||||
hits = ?hits,
|
||||
misses = ?misses,
|
||||
"Auth chain stats",
|
||||
);
|
||||
debug!(
|
||||
chain_length = ?full_auth_chain.len(),
|
||||
hits = ?hits,
|
||||
misses = ?misses,
|
||||
"Auth chain stats",
|
||||
);
|
||||
|
||||
Ok(full_auth_chain
|
||||
.into_iter()
|
||||
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
|
||||
}
|
||||
Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, event_id))]
|
||||
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
|
||||
let mut todo = vec![Arc::from(event_id)];
|
||||
let mut found = HashSet::new();
|
||||
#[tracing::instrument(skip(self, event_id))]
|
||||
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
|
||||
let mut todo = vec![Arc::from(event_id)];
|
||||
let mut found = HashSet::new();
|
||||
|
||||
while let Some(event_id) = todo.pop() {
|
||||
match services().rooms.timeline.get_pdu(&event_id) {
|
||||
Ok(Some(pdu)) => {
|
||||
if pdu.room_id != room_id {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
|
||||
}
|
||||
for auth_event in &pdu.auth_events {
|
||||
let sauthevent = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_or_create_shorteventid(auth_event)?;
|
||||
while let Some(event_id) = todo.pop() {
|
||||
match services().rooms.timeline.get_pdu(&event_id) {
|
||||
Ok(Some(pdu)) => {
|
||||
if pdu.room_id != room_id {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
|
||||
}
|
||||
for auth_event in &pdu.auth_events {
|
||||
let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?;
|
||||
|
||||
if !found.contains(&sauthevent) {
|
||||
found.insert(sauthevent);
|
||||
todo.push(auth_event.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
warn!(?event_id, "Could not find pdu mentioned in auth events");
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?event_id, ?error, "Could not load event in auth chain");
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found.contains(&sauthevent) {
|
||||
found.insert(sauthevent);
|
||||
todo.push(auth_event.clone());
|
||||
}
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
warn!(?event_id, "Could not find pdu mentioned in auth events");
|
||||
},
|
||||
Err(error) => {
|
||||
error!(?event_id, ?error, "Could not load event in auth chain");
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Ok(found)
|
||||
}
|
||||
Ok(found)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
use crate::Result;
|
||||
use ruma::{OwnedRoomId, RoomId};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Adds the room to the public room directory
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()>;
|
||||
/// Adds the room to the public room directory
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()>;
|
||||
|
||||
/// Removes the room from the public room directory.
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()>;
|
||||
/// Removes the room from the public room directory.
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()>;
|
||||
|
||||
/// Returns true if the room is in the public room directory.
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool>;
|
||||
/// Returns true if the room is in the public room directory.
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool>;
|
||||
|
||||
/// Returns the unsorted public room directory
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
|
||||
/// Returns the unsorted public room directory
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>;
|
||||
}
|
||||
|
||||
@@ -6,27 +6,19 @@ use ruma::{OwnedRoomId, RoomId};
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.db.set_public(room_id)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.db.set_not_public(room_id)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
self.db.is_public_room(room_id)
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
|
||||
self.db.public_rooms()
|
||||
}
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ pub mod typing;
|
||||
pub trait Data: presence::Data + read_receipt::Data + typing::Data + 'static {}
|
||||
|
||||
pub struct Service {
|
||||
pub presence: presence::Service,
|
||||
pub read_receipt: read_receipt::Service,
|
||||
pub typing: typing::Service,
|
||||
pub presence: presence::Service,
|
||||
pub read_receipt: read_receipt::Service,
|
||||
pub typing: typing::Service,
|
||||
}
|
||||
|
||||
@@ -1,33 +1,27 @@
|
||||
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
|
||||
};
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Returns the latest presence event for the given user in the given room.
|
||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>;
|
||||
/// Returns the latest presence event for the given user in the given room.
|
||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>;
|
||||
|
||||
/// Pings the presence of the given user in the given room, setting the specified state.
|
||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>;
|
||||
/// Pings the presence of the given user in the given room, setting the
|
||||
/// specified state.
|
||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>;
|
||||
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
fn set_presence(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
presence_state: PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
) -> Result<()>;
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
fn set_presence(
|
||||
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||
) -> Result<()>;
|
||||
|
||||
/// Removes the presence record for the given user from the database.
|
||||
fn remove_presence(&self, user_id: &UserId) -> Result<()>;
|
||||
/// Removes the presence record for the given user from the database.
|
||||
fn remove_presence(&self, user_id: &UserId) -> Result<()>;
|
||||
|
||||
/// Returns the most recent presence updates that happened after the event with id `since`.
|
||||
fn presence_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
|
||||
/// Returns the most recent presence updates that happened after the event
|
||||
/// with id `since`.
|
||||
fn presence_since<'a>(
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ use std::time::Duration;
|
||||
pub use data::Data;
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use ruma::{
|
||||
events::presence::{PresenceEvent, PresenceEventContent},
|
||||
presence::PresenceState,
|
||||
OwnedUserId, RoomId, UInt, UserId,
|
||||
events::presence::{PresenceEvent, PresenceEventContent},
|
||||
presence::PresenceState,
|
||||
OwnedUserId, RoomId, UInt, UserId,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{sync::mpsc, time::sleep};
|
||||
@@ -15,197 +15,164 @@ use tracing::debug;
|
||||
|
||||
use crate::{services, utils, Error, Result};
|
||||
|
||||
/// Represents data required to be kept in order to implement the presence specification.
|
||||
/// Represents data required to be kept in order to implement the presence
|
||||
/// specification.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Presence {
|
||||
pub state: PresenceState,
|
||||
pub currently_active: bool,
|
||||
pub last_active_ts: u64,
|
||||
pub last_count: u64,
|
||||
pub status_msg: Option<String>,
|
||||
pub state: PresenceState,
|
||||
pub currently_active: bool,
|
||||
pub last_active_ts: u64,
|
||||
pub last_count: u64,
|
||||
pub status_msg: Option<String>,
|
||||
}
|
||||
|
||||
impl Presence {
|
||||
pub fn new(
|
||||
state: PresenceState,
|
||||
currently_active: bool,
|
||||
last_active_ts: u64,
|
||||
last_count: u64,
|
||||
status_msg: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state,
|
||||
currently_active,
|
||||
last_active_ts,
|
||||
last_count,
|
||||
status_msg,
|
||||
}
|
||||
}
|
||||
pub fn new(
|
||||
state: PresenceState, currently_active: bool, last_active_ts: u64, last_count: u64, status_msg: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state,
|
||||
currently_active,
|
||||
last_active_ts,
|
||||
last_count,
|
||||
status_msg,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid presence data in database"))
|
||||
}
|
||||
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
|
||||
}
|
||||
|
||||
pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
|
||||
serde_json::to_vec(self)
|
||||
.map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
|
||||
}
|
||||
pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
|
||||
serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
|
||||
}
|
||||
|
||||
/// Creates a PresenceEvent from available data.
|
||||
pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ago = if self.currently_active {
|
||||
None
|
||||
} else {
|
||||
Some(UInt::new_saturating(
|
||||
now.saturating_sub(self.last_active_ts),
|
||||
))
|
||||
};
|
||||
/// Creates a PresenceEvent from available data.
|
||||
pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ago = if self.currently_active {
|
||||
None
|
||||
} else {
|
||||
Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts)))
|
||||
};
|
||||
|
||||
Ok(PresenceEvent {
|
||||
sender: user_id.to_owned(),
|
||||
content: PresenceEventContent {
|
||||
presence: self.state.clone(),
|
||||
status_msg: self.status_msg.clone(),
|
||||
currently_active: Some(self.currently_active),
|
||||
last_active_ago,
|
||||
displayname: services().users.displayname(user_id)?,
|
||||
avatar_url: services().users.avatar_url(user_id)?,
|
||||
},
|
||||
})
|
||||
}
|
||||
Ok(PresenceEvent {
|
||||
sender: user_id.to_owned(),
|
||||
content: PresenceEventContent {
|
||||
presence: self.state.clone(),
|
||||
status_msg: self.status_msg.clone(),
|
||||
currently_active: Some(self.currently_active),
|
||||
last_active_ago,
|
||||
displayname: services().users.displayname(user_id)?,
|
||||
avatar_url: services().users.avatar_url(user_id)?,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Returns the latest presence event for the given user in the given room.
|
||||
pub fn get_presence(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<PresenceEvent>> {
|
||||
self.db.get_presence(room_id, user_id)
|
||||
}
|
||||
/// Returns the latest presence event for the given user in the given room.
|
||||
pub fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
|
||||
self.db.get_presence(room_id, user_id)
|
||||
}
|
||||
|
||||
/// Pings the presence of the given user in the given room, setting the specified state.
|
||||
pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||
self.db.ping_presence(user_id, new_state)
|
||||
}
|
||||
/// Pings the presence of the given user in the given room, setting the
|
||||
/// specified state.
|
||||
pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||
self.db.ping_presence(user_id, new_state)
|
||||
}
|
||||
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
pub fn set_presence(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
presence_state: PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
self.db.set_presence(
|
||||
room_id,
|
||||
user_id,
|
||||
presence_state,
|
||||
currently_active,
|
||||
last_active_ago,
|
||||
status_msg,
|
||||
)
|
||||
}
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
pub fn set_presence(
|
||||
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
self.db.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg)
|
||||
}
|
||||
|
||||
/// Removes the presence record for the given user from the database.
|
||||
pub fn remove_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
self.db.remove_presence(user_id)
|
||||
}
|
||||
/// Removes the presence record for the given user from the database.
|
||||
pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) }
|
||||
|
||||
/// Returns the most recent presence updates that happened after the event with id `since`.
|
||||
pub fn presence_since(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)>> {
|
||||
self.db.presence_since(room_id, since)
|
||||
}
|
||||
/// Returns the most recent presence updates that happened after the event
|
||||
/// with id `since`.
|
||||
pub fn presence_since(
|
||||
&self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)>> {
|
||||
self.db.presence_since(room_id, since)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn presence_handler(
|
||||
mut presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>,
|
||||
mut presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>,
|
||||
) -> Result<()> {
|
||||
let mut presence_timers = FuturesUnordered::new();
|
||||
let mut presence_timers = FuturesUnordered::new();
|
||||
|
||||
loop {
|
||||
debug!("Number of presence timers: {}", presence_timers.len());
|
||||
loop {
|
||||
debug!("Number of presence timers: {}", presence_timers.len());
|
||||
|
||||
tokio::select! {
|
||||
Some((user_id, timeout)) = presence_timer_receiver.recv() => {
|
||||
debug!("Adding timer for user '{user_id}': Timeout {timeout:?}");
|
||||
presence_timers.push(presence_timer(user_id, timeout));
|
||||
}
|
||||
tokio::select! {
|
||||
Some((user_id, timeout)) = presence_timer_receiver.recv() => {
|
||||
debug!("Adding timer for user '{user_id}': Timeout {timeout:?}");
|
||||
presence_timers.push(presence_timer(user_id, timeout));
|
||||
}
|
||||
|
||||
Some(user_id) = presence_timers.next() => {
|
||||
process_presence_timer(user_id)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(user_id) = presence_timers.next() => {
|
||||
process_presence_timer(user_id)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId {
|
||||
sleep(timeout).await;
|
||||
sleep(timeout).await;
|
||||
|
||||
user_id
|
||||
user_id
|
||||
}
|
||||
|
||||
fn process_presence_timer(user_id: OwnedUserId) -> Result<()> {
|
||||
let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000;
|
||||
let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000;
|
||||
let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000;
|
||||
let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000;
|
||||
|
||||
let mut presence_state = PresenceState::Offline;
|
||||
let mut last_active_ago = None;
|
||||
let mut status_msg = None;
|
||||
let mut presence_state = PresenceState::Offline;
|
||||
let mut last_active_ago = None;
|
||||
let mut status_msg = None;
|
||||
|
||||
for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
|
||||
let presence_event = services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.get_presence(&room_id?, &user_id)?;
|
||||
for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
|
||||
let presence_event = services().rooms.edus.presence.get_presence(&room_id?, &user_id)?;
|
||||
|
||||
if let Some(presence_event) = presence_event {
|
||||
presence_state = presence_event.content.presence;
|
||||
last_active_ago = presence_event.content.last_active_ago;
|
||||
status_msg = presence_event.content.status_msg;
|
||||
if let Some(presence_event) = presence_event {
|
||||
presence_state = presence_event.content.presence;
|
||||
last_active_ago = presence_event.content.last_active_ago;
|
||||
status_msg = presence_event.content.status_msg;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
|
||||
(PresenceState::Online, Some(ago)) if ago >= idle_timeout => {
|
||||
Some(PresenceState::Unavailable)
|
||||
}
|
||||
(PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => {
|
||||
Some(PresenceState::Offline)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
|
||||
(PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable),
|
||||
(PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}");
|
||||
debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}");
|
||||
|
||||
if let Some(new_state) = new_state {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
|
||||
services().rooms.edus.presence.set_presence(
|
||||
&room_id?,
|
||||
&user_id,
|
||||
new_state.clone(),
|
||||
Some(false),
|
||||
last_active_ago,
|
||||
status_msg.clone(),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
if let Some(new_state) = new_state {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
|
||||
services().rooms.edus.presence.set_presence(
|
||||
&room_id?,
|
||||
&user_id,
|
||||
new_state.clone(),
|
||||
Some(false),
|
||||
last_active_ago,
|
||||
status_msg.clone(),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user