From 54fd1d313fa5da4647f3c2e03c959edf58151da4 Mon Sep 17 00:00:00 2001 From: Ginger Date: Sat, 21 Mar 2026 20:59:00 -0400 Subject: [PATCH] feat: Implement threepid service --- src/service/mod.rs | 1 + src/service/registration_tokens/mod.rs | 6 +- src/service/services.rs | 4 +- src/service/threepid/data.rs | 158 ++++++++++++++++++ src/service/threepid/mod.rs | 223 +++++++++++++++++++++++++ 5 files changed, 388 insertions(+), 4 deletions(-) create mode 100644 src/service/threepid/data.rs create mode 100644 src/service/threepid/mod.rs diff --git a/src/service/mod.rs b/src/service/mod.rs index 0b14667e3..271614276 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -33,6 +33,7 @@ pub mod rooms; pub mod sending; pub mod server_keys; pub mod sync; +pub mod threepid; pub mod transactions; pub mod uiaa; pub mod users; diff --git a/src/service/registration_tokens/mod.rs b/src/service/registration_tokens/mod.rs index 0d72b2d52..7cccda4d4 100644 --- a/src/service/registration_tokens/mod.rs +++ b/src/service/registration_tokens/mod.rs @@ -13,8 +13,6 @@ use ruma::OwnedUserId; use crate::{Dep, config, firstrun}; -const RANDOM_TOKEN_LENGTH: usize = 16; - pub struct Service { db: Data, services: Services, @@ -103,9 +101,11 @@ impl crate::Service for Service { } impl Service { + const RANDOM_TOKEN_LENGTH: usize = 16; + /// Generate a random string suitable to be used as a registration token. #[must_use] - pub fn generate_token_string() -> String { utils::random_string(RANDOM_TOKEN_LENGTH) } + pub fn generate_token_string() -> String { utils::random_string(Self::RANDOM_TOKEN_LENGTH) } /// Issue a new registration token and save it in the database. pub fn issue_token( diff --git a/src/service/services.rs b/src/service/services.rs index a162c61a1..4b81c4612 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -14,7 +14,7 @@ use crate::{ media, moderation, password_reset, presence, pusher, registration_tokens, resolver, rooms, sending, server_keys, service::{self, Args, Map, Service}, - sync, transactions, uiaa, users, + sync, threepid, transactions, uiaa, users, }; pub struct Services { @@ -40,6 +40,7 @@ pub struct Services { pub server_keys: Arc, pub sync: Arc, pub transactions: Arc, + pub threepid: Arc, pub uiaa: Arc, pub users: Arc, pub moderation: Arc, @@ -114,6 +115,7 @@ impl Services { sending: build!(sending::Service), server_keys: build!(server_keys::Service), sync: build!(sync::Service), + threepid: build!(threepid::Service), transactions: build!(transactions::Service), uiaa: build!(uiaa::Service), users: build!(users::Service), diff --git a/src/service/threepid/data.rs b/src/service/threepid/data.rs new file mode 100644 index 000000000..c99b2798a --- /dev/null +++ b/src/service/threepid/data.rs @@ -0,0 +1,158 @@ +use std::{ + sync::Arc, + time::{Duration, SystemTime}, +}; + +use conduwuit::utils; +use database::{Database, Deserialized, Map}; +use lettre::Address; +use serde::{Deserialize, Serialize}; + +pub(super) struct Data { + // note: the column names of these maps use `validationsession` instead of `session` + clientsecret_sessionid: Arc, + sessionid_session: Arc, + sessionid_token: Arc, + pub(super) localpart_email: Arc, + pub(super) email_localpart: Arc, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct ValidationSession { + /// The session's ID + pub session_id: String, + /// The email address which is being validated + pub email: Address, + /// The client's supplied client secret + pub client_secret: String, + /// Whether the email address has been validated successfully yet + pub(super) has_been_validated: bool, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct ValidationToken { + pub token: String, + pub issued_at: SystemTime, +} + +impl ValidationToken { + // one hour + const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60); + const RANDOM_TOKEN_LENGTH: usize = 16; + + pub(super) fn new_random() -> Self { + Self { + token: utils::random_string(Self::RANDOM_TOKEN_LENGTH), + issued_at: SystemTime::now(), + } + } + + pub(crate) fn is_valid(&self) -> bool { + let now = SystemTime::now(); + + now.duration_since(self.issued_at) + .is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE) + } +} + +impl PartialEq for ValidationToken { + fn eq(&self, other: &str) -> bool { self.token == other } +} + +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + clientsecret_sessionid: db["clientsecret_sessionid"].clone(), + sessionid_session: db["sessionid_session"].clone(), + sessionid_token: db["sessionid_token"].clone(), + localpart_email: db["localpart_email"].clone(), + email_localpart: db["email_localpart"].clone(), + } + } + + /// Create a validation session. + pub(super) fn create_session( + &self, + email: Address, + session_id: String, + client_secret: String, + token: ValidationToken, + ) { + let session = ValidationSession { + session_id, + client_secret, + email, + has_been_validated: false, + }; + self.clientsecret_sessionid + .insert(&session.session_id, &session.client_secret); + self.sessionid_token.raw_put(&session.session_id, token); + self.sessionid_session + .raw_put(session.session_id.clone(), session); + } + + /// Get a validation session. + pub(super) async fn get_session(&self, session_id: &str) -> Option { + self.sessionid_session + .get(session_id) + .await + .deserialized() + .ok() + } + + /// Get a validation session by client secret. + pub(super) async fn get_session_by_secret( + &self, + client_secret: &str, + ) -> Option { + let session_id: String = self + .clientsecret_sessionid + .get(client_secret) + .await + .deserialized() + .ok()?; + + self.get_session(&session_id).await + } + + /// Get the validation token for a validation session, or None if the + /// session does not exist. + pub(super) async fn get_session_validation_token( + &self, + session: &ValidationSession, + ) -> Option { + self.sessionid_token + .get(&session.session_id) + .await + .deserialized() + .ok() + } + + /// Update a session's validation token. + pub(super) fn update_session_validation_token( + &self, + session: &ValidationSession, + token: ValidationToken, + ) { + self.sessionid_token.raw_put(&session.session_id, token); + } + + /// Mark a validation session as valid. + pub(super) async fn mark_session_as_valid(&self, mut session: ValidationSession) { + self.sessionid_token.remove(&session.session_id); + + session.has_been_validated = true; + self.sessionid_session + .raw_put(session.session_id.clone(), session); + } + + /// Remove a validation session. + pub(super) async fn remove_session( + &self, + ValidationSession { session_id, .. }: ValidationSession, + ) { + self.clientsecret_sessionid.remove(&session_id); + self.sessionid_token.remove(&session_id); + self.sessionid_session.remove(&session_id); + } +} diff --git a/src/service/threepid/mod.rs b/src/service/threepid/mod.rs new file mode 100644 index 000000000..ef9976267 --- /dev/null +++ b/src/service/threepid/mod.rs @@ -0,0 +1,223 @@ +use std::{ + borrow::Cow, + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use crate::{ + Args, Dep, config, + mailer::{self, messages::MessageTemplate}, +}; + +mod data; +use conduwuit::{Err, Result, utils}; +use data::{Data, ValidationToken}; +use database::Deserialized; +use lettre::{Address, message::Mailbox}; + +pub struct Service { + db: Data, + services: Services, + send_attempts: Mutex>, +} + +struct Services { + config: Dep, + mailer: Dep, +} + +impl crate::Service for Service { + fn build(args: Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + services: Services { + config: args.depend("config"), + mailer: args.depend("mailer"), + }, + send_attempts: Mutex::new(HashMap::new()), + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + const RANDOM_SID_LENGTH: usize = 16; + const VALIDATION_URL_PATH: &str = "/_continuwuity/3pid/email/validate"; + + /// Send a validation message to an email address. + /// + /// Returns the validation session ID on success. + #[allow(clippy::impl_trait_in_params)] + pub async fn send_validation_email( + &self, + recipient: Address, + prepare_body: impl FnOnce(String) -> Template, + client_secret: &str, + send_attempt: usize, + ) -> Result { + let Some(mailer) = self.services.mailer.mailer() else { + return Err!("SMTP is not configured"); + }; + + let (session_id, ValidationToken { token, .. }) = + match self.db.get_session_by_secret(client_secret).await { + // If a validation session already exists for this client secret, we can either + // reuse it with a new token or return early because it's already valid. + | Some(session) => { + // If the existing session is already valid, don't send an email. + if session.has_been_validated { + return Ok(session.session_id); + } + + let mut send_attempts = self.send_attempts.lock().unwrap(); + match send_attempts + .get_mut(&(session.client_secret.clone(), session.email.clone())) + { + | Some(last_send_attempt) => { + if send_attempt <= *last_send_attempt { + // If the supplied send attempt isn't higher than the last one, + // don't send an email. + return Ok(session.session_id); + } + + // Otherwise save the supplied send attempt. + *last_send_attempt = send_attempt; + }, + | None => { + // Default to sending an email if no previous + // attempt could be found. This can happen if + // the server was restarted, which clears the send + // attempt tracker. + }, + } + drop(send_attempts); + + // Create a new token for the existing session. + let token = ValidationToken::new_random(); + self.db + .update_session_validation_token(&session, token.clone()); + + (session.session_id, token) + }, + // If no session exists, create a new one. + | None => { + let session_id = utils::random_string(Self::RANDOM_SID_LENGTH); + let token = ValidationToken::new_random(); + + self.db.create_session( + recipient.clone(), + session_id.clone(), + client_secret.to_owned(), + token.clone(), + ); + + (session_id, token) + }, + }; + + let mut validation_url = self + .services + .config + .get_client_domain() + .join(Self::VALIDATION_URL_PATH) + .unwrap(); + + validation_url + .query_pairs_mut() + .append_pair("session_id", &session_id) + .append_pair("client_secret", client_secret) + .append_pair("token", &token); + + let recipient = Mailbox::new(None, recipient); + let message = prepare_body(validation_url.to_string()); + + mailer.send(recipient, message).await?; + + Ok(session_id) + } + + pub async fn try_validate_session( + &self, + session_id: &str, + client_secret: &str, + supplied_token: &str, + ) -> Result<(), Cow<'static, str>> { + let Some(session) = self.db.get_session(session_id).await else { + return Err("Validation session does not exist".into()); + }; + + if session.has_been_validated { + return Ok(()); + } + + if session.client_secret != client_secret { + return Err("Invalid client secret for session".into()); + } + + let token = self + .db + .get_session_validation_token(&session) + .await + .expect("valid session should have a token"); + if token != *supplied_token || !token.is_valid() { + return Err("Validation token is invalid or expired, please request a new one".into()); + } + + self.db.mark_session_as_valid(session).await; + + Ok(()) + } + + pub async fn consume_valid_session( + &self, + session_id: &str, + client_secret: &str, + ) -> Result> { + let Some(session) = self.db.get_session(session_id).await else { + return Err("Validation session does not exist".into()); + }; + + if session.client_secret == client_secret && session.has_been_validated { + let email = session.email.clone(); + self.db.remove_session(session).await; + Ok(email) + } else { + Err("Validation failed. Did you use the link that was sent to you?".into()) + } + } + + /// Associate a localpart with an email address. + pub fn associate_localpart_email(&self, localpart: &str, email: Address) { + self.db.localpart_email.raw_put(localpart, &email); + self.db.email_localpart.put_raw(email, localpart); + } + + /// Given a localpart, remove its corresponding email address. + /// + /// [`Self::get_localpart_for_email`] may be used if only the email is + /// known. + pub async fn disassociate_localpart_email(&self, localpart: &str) { + let email = self + .get_email_for_localpart(localpart) + .await + .expect("localpart has no email associated"); + self.db.localpart_email.remove(localpart); + self.db.email_localpart.del(&email); + } + + /// Get the email associated with a localpart, if one exists. + pub async fn get_email_for_localpart(&self, localpart: &str) -> Option
{ + self.db + .localpart_email + .get(localpart) + .await + .deserialized() + .ok() + } + + /// Get the localpart associated with an email, if one exists. + pub async fn get_localpart_for_email(&self, email: &Address) -> Option { + self.db.email_localpart.qry(email).await.deserialized().ok() + } +}