diff --git a/Cargo.lock b/Cargo.lock index 006fd2fa5..cb68967fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1012,7 +1012,6 @@ dependencies = [ "hyper", "ipaddress", "itertools 0.14.0", - "lettre", "log", "rand 0.10.0", "reqwest", diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 8ee7a5a1e..f201cb215 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -2050,8 +2050,9 @@ # # For most modern mail servers, format the URI like this: # `smtps://username:password@hostname:port` -# Note that you will need to URL-encode the username and password. If your username _is_ -# your email address, you will need to replace the `@` with `%40`. +# Note that you will need to URL-encode the username and password. If your +# username _is_ your email address, you will need to replace the `@` with +# `%40`. # # For a guide on the accepted URI syntax, consult Lettre's documentation: # https://docs.rs/lettre/latest/lettre/transport/smtp/struct.AsyncSmtpTransport.html#method.from_url diff --git a/src/database/maps.rs b/src/database/maps.rs index e113c26a7..cbde6223e 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -49,10 +49,6 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "bannedroomids", ..descriptor::RANDOM_SMALL }, - Descriptor { - name: "clientsecret_validationsessionid", - ..descriptor::RANDOM_SMALL - }, Descriptor { name: "disabledroomids", ..descriptor::RANDOM_SMALL @@ -470,12 +466,4 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "userroomid_invitesender", ..descriptor::RANDOM_SMALL }, - Descriptor { - name: "validationsessionid_session", - ..descriptor::RANDOM_SMALL - }, - Descriptor { - name: "validationsessionid_token", - ..descriptor::RANDOM_SMALL - }, ]; diff --git a/src/service/threepid/data.rs b/src/service/threepid/data.rs deleted file mode 100644 index 9bc8663b8..000000000 --- a/src/service/threepid/data.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::{ - sync::Arc, - time::{Duration, SystemTime}, -}; - -use conduwuit::utils; -use database::{Database, Deserialized, Map}; -use lettre::Address; -use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId}; -use serde::{Deserialize, Serialize}; - -pub(super) struct Data { - // note: the column names of these maps use `validationsession` instead of `session` - clientsecret_sessionid: Arc, - sessionid_session: Arc, - sessionid_token: Arc, - pub(super) localpart_email: Arc, - pub(super) email_localpart: Arc, -} - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct ValidationSession { - /// The session's ID - pub session_id: OwnedSessionId, - /// The email address which is being validated - pub email: Address, - /// The client's supplied client secret - pub client_secret: OwnedClientSecret, - /// Whether the email address has been validated successfully yet - pub(super) has_been_validated: bool, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct ValidationToken { - pub token: String, - pub issued_at: SystemTime, -} - -impl ValidationToken { - // one hour - const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60); - const RANDOM_TOKEN_LENGTH: usize = 16; - - pub(super) fn new_random() -> Self { - Self { - token: utils::random_string(Self::RANDOM_TOKEN_LENGTH), - issued_at: SystemTime::now(), - } - } - - pub(crate) fn is_valid(&self) -> bool { - let now = SystemTime::now(); - - now.duration_since(self.issued_at) - .is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE) - } -} - -impl PartialEq for ValidationToken { - fn eq(&self, other: &str) -> bool { self.token == other } -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - clientsecret_sessionid: db["clientsecret_validationsessionid"].clone(), - sessionid_session: db["validationsessionid_session"].clone(), - sessionid_token: db["validationsessionid_token"].clone(), - localpart_email: db["localpart_email"].clone(), - email_localpart: db["email_localpart"].clone(), - } - } - - /// Create a validation session. - pub(super) fn create_session( - &self, - email: Address, - session_id: OwnedSessionId, - client_secret: OwnedClientSecret, - token: ValidationToken, - ) { - let session = ValidationSession { - session_id, - client_secret, - email, - has_been_validated: false, - }; - self.clientsecret_sessionid - .insert(&session.session_id, &session.client_secret); - self.sessionid_token.raw_put(&session.session_id, token); - self.sessionid_session - .raw_put(session.session_id.clone(), session); - } - - /// Get a validation session. - pub(super) async fn get_session(&self, session_id: &str) -> Option { - self.sessionid_session - .get(session_id) - .await - .deserialized() - .ok() - } - - /// Get a validation session by client secret. - pub(super) async fn get_session_by_secret( - &self, - client_secret: &ClientSecret, - ) -> Option { - let session_id: String = self - .clientsecret_sessionid - .get(client_secret) - .await - .deserialized() - .ok()?; - - self.get_session(&session_id).await - } - - /// Get the validation token for a validation session, or None if the - /// session does not exist. - pub(super) async fn get_session_validation_token( - &self, - session: &ValidationSession, - ) -> Option { - self.sessionid_token - .get(&session.session_id) - .await - .deserialized() - .ok() - } - - /// Update a session's validation token. - pub(super) fn update_session_validation_token( - &self, - session: &ValidationSession, - token: ValidationToken, - ) { - self.sessionid_token.raw_put(&session.session_id, token); - } - - /// Mark a validation session as valid. - pub(super) async fn mark_session_as_valid(&self, mut session: ValidationSession) { - self.sessionid_token.remove(&session.session_id); - - session.has_been_validated = true; - self.sessionid_session - .raw_put(session.session_id.clone(), session); - } - - /// Remove a validation session. - pub(super) async fn remove_session( - &self, - ValidationSession { session_id, .. }: ValidationSession, - ) { - self.clientsecret_sessionid.remove(&session_id); - self.sessionid_token.remove(&session_id); - self.sessionid_session.remove(&session_id); - } -} diff --git a/src/service/threepid/mod.rs b/src/service/threepid/mod.rs index 00654ab8f..fcd2f1480 100644 --- a/src/service/threepid/mod.rs +++ b/src/service/threepid/mod.rs @@ -1,25 +1,28 @@ -use std::{ - borrow::Cow, - collections::HashMap, - sync::{Arc, Mutex}, -}; +use std::{borrow::Cow, collections::HashMap, sync::Arc}; + +use conduwuit::{Err, Result, result::FlatOk}; +use database::{Deserialized, Map}; +use lettre::{Address, message::Mailbox}; +use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId}; + +mod session; use crate::{ Args, Dep, config, mailer::{self, messages::MessageTemplate}, + threepid::session::{ValidationSessions, ValidationState, ValidationToken}, }; -mod data; -use conduwuit::{Err, Result, result::FlatOk, utils}; -use data::{Data, ValidationToken}; -use database::Deserialized; -use lettre::{Address, message::Mailbox}; -use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId}; - pub struct Service { db: Data, services: Services, - send_attempts: Mutex>, + sessions: tokio::sync::Mutex, + send_attempts: std::sync::Mutex>, +} + +struct Data { + localpart_email: Arc, + email_localpart: Arc, } struct Services { @@ -30,12 +33,16 @@ struct Services { impl crate::Service for Service { fn build(args: Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + email_localpart: args.db["email_localpart"].clone(), + localpart_email: args.db["localpart_email"].clone(), + }, services: Services { config: args.depend("config"), mailer: args.depend("mailer"), }, - send_attempts: Mutex::new(HashMap::new()), + sessions: tokio::sync::Mutex::default(), + send_attempts: std::sync::Mutex::default(), })) } @@ -43,14 +50,8 @@ impl crate::Service for Service { } impl Service { - const RANDOM_SID_LENGTH: usize = 16; const VALIDATION_URL_PATH: &str = "/_continuwuity/3pid/email/validate"; - #[must_use] - pub fn generate_session_id() -> OwnedSessionId { - OwnedSessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap() - } - /// Send a validation message to an email address. /// /// Returns the validation session ID on success. @@ -63,62 +64,57 @@ impl Service { send_attempt: usize, ) -> Result { let mailer = self.services.mailer.expect_mailer()?; + let mut sessions = self.sessions.lock().await; - let (session_id, ValidationToken { token, .. }) = - match self.db.get_session_by_secret(client_secret).await { - // If a validation session already exists for this client secret, we can either - // reuse it with a new token or return early because it's already valid. - | Some(session) => { - // If the existing session is already valid, don't send an email. - if session.has_been_validated { - return Ok(session.session_id); - } + let session = match sessions.get_session_by_client_secret(client_secret) { + // If a validation session already exists for this client secret, we can either + // reuse it with a new token or return early because it's already valid. + | Some(session) => { + match session.validation_state { + | ValidationState::Validated => { + // If the existing session is already valid, don't send an email. + return Ok(session.session_id.clone()); + }, + | ValidationState::Pending(ref mut token) => { + // Check the send attempt for this session + let mut send_attempts = self.send_attempts.lock().unwrap(); - let mut send_attempts = self.send_attempts.lock().unwrap(); - match send_attempts - .get_mut(&(session.client_secret.clone(), session.email.clone())) - { - | Some(last_send_attempt) => { - if send_attempt <= *last_send_attempt { - // If the supplied send attempt isn't higher than the last one, - // don't send an email. - return Ok(session.session_id); - } + match send_attempts + .get_mut(&(session.client_secret.clone(), session.email.clone())) + { + | Some(last_send_attempt) => { + if send_attempt <= *last_send_attempt { + // If the supplied send attempt isn't higher than the last + // one, don't send an email. + return Ok(session.session_id.clone()); + } - // Otherwise save the supplied send attempt. - *last_send_attempt = send_attempt; - }, - | None => { - // Default to sending an email if no previous - // attempt could be found. This can happen if - // the server was restarted, which clears the send - // attempt tracker. - }, - } - drop(send_attempts); + // Otherwise save the supplied send attempt. + *last_send_attempt = send_attempt; + }, + | None => { + // Default to sending an email if no previous + // attempt could be found. This can happen if + // the server was restarted, which clears the + // send attempt tracker. + }, + } + drop(send_attempts); - // Create a new token for the existing session. - let token = ValidationToken::new_random(); - self.db - .update_session_validation_token(&session, token.clone()); + // Create a new token for the existing session. + *token = ValidationToken::new_random(); - (session.session_id, token) - }, - // If no session exists, create a new one. - | None => { - let session_id = Self::generate_session_id(); - let token = ValidationToken::new_random(); + session + }, + } + }, + // If no session exists, create a new one. + | None => sessions.create_session(recipient.email.clone(), client_secret.to_owned()), + }; - self.db.create_session( - recipient.email.clone(), - session_id.clone(), - client_secret.to_owned(), - token.clone(), - ); - - (session_id, token) - }, - }; + let ValidationState::Pending(token) = &session.validation_state else { + unreachable!("session should be pending") + }; let mut validation_url = self .services @@ -129,41 +125,46 @@ impl Service { validation_url .query_pairs_mut() - .append_pair("session_id", session_id.as_ref()) - .append_pair("token", &token); + .append_pair("session_id", session.session_id.as_ref()) + .append_pair("token", &token.token); let message = prepare_body(validation_url.to_string()); mailer.send(recipient, message).await?; - Ok(session_id) + Ok(session.session_id.clone()) } /// Attempt to mark a validation session as valid using a validation token. pub async fn try_validate_session( &self, - session_id: &str, + session_id: &SessionId, supplied_token: &str, ) -> Result<(), Cow<'static, str>> { - let Some(session) = self.db.get_session(session_id).await else { + let mut sessions = self.sessions.lock().await; + + let Some(session) = sessions.get_session(session_id) else { return Err("Validation session does not exist".into()); }; - if session.has_been_validated { - return Ok(()); - } + session.validation_state = match &session.validation_state { + | ValidationState::Validated => { + // If the session is already validated, do nothing. - let token = self - .db - .get_session_validation_token(&session) - .await - .expect("valid session should have a token"); + return Ok(()); + }, + | ValidationState::Pending(token) => { + // Otherwise check the token and mark the session as valid. - if token != *supplied_token || !token.is_valid() { - return Err("Validation token is invalid or expired, please request a new one".into()); - } + if *token != *supplied_token || !token.is_valid() { + return Err("Validation token is invalid or expired, please request a new \ + one" + .into()); + } - self.db.mark_session_as_valid(session).await; + ValidationState::Validated + }, + }; Ok(()) } @@ -172,17 +173,21 @@ impl Service { /// and returning the newly validated email address. pub async fn consume_valid_session( &self, - session_id: &str, + session_id: &SessionId, client_secret: &ClientSecret, ) -> Result> { - let Some(session) = self.db.get_session(session_id).await else { + let mut sessions = self.sessions.lock().await; + + let Some(session) = sessions.get_session(session_id) else { return Err("Validation session does not exist".into()); }; - if session.client_secret == client_secret && session.has_been_validated { - let email = session.email.clone(); - self.db.remove_session(session).await; - Ok(email) + if session.client_secret == client_secret + && matches!(session.validation_state, ValidationState::Validated) + { + let session = sessions.remove_session(session_id); + + Ok(session.email) } else { Err("This email address has not been validated. Did you use the link that was sent \ to you?" @@ -198,17 +203,17 @@ impl Service { ) -> Result<()> { match self.get_localpart_for_email(email).await { | Some(existing_localpart) if existing_localpart != localpart => { - // Another account is already using the supplied email + // Another account is already using the supplied email. Err!(Request(ThreepidInUse("This email address is already in use."))) }, | Some(_) => { // The supplied localpart is already associated with the supplied email, - // no changes are necessary + // no changes are necessary. Ok(()) }, | None => { - // The supplied email is not already in use + // The supplied email is not already in use. let email: &str = email.as_ref(); self.db.localpart_email.insert(localpart, email); diff --git a/src/service/threepid/session.rs b/src/service/threepid/session.rs new file mode 100644 index 000000000..dd1ca4674 --- /dev/null +++ b/src/service/threepid/session.rs @@ -0,0 +1,128 @@ +use std::{ + collections::HashMap, + time::{Duration, SystemTime}, +}; + +use conduwuit::utils; +use lettre::Address; +use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId}; + +#[derive(Default)] +pub(super) struct ValidationSessions { + sessions: HashMap, + client_secrets: HashMap, +} + +/// A pending or completed email validation session. +#[derive(Debug)] +pub(crate) struct ValidationSession { + /// The session's ID + pub session_id: OwnedSessionId, + /// The client's supplied client secret + pub client_secret: OwnedClientSecret, + /// The email address which is being validated + pub email: Address, + /// The session's validation state + pub validation_state: ValidationState, +} + +/// The state of an email validation session. +#[derive(Debug)] +pub(crate) enum ValidationState { + /// The session is waiting for this validation token to be provided + Pending(ValidationToken), + /// The session has been validated + Validated, +} + +#[derive(Clone, Debug)] +pub(crate) struct ValidationToken { + pub token: String, + pub issued_at: SystemTime, +} + +impl ValidationToken { + // one hour + const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60); + const RANDOM_TOKEN_LENGTH: usize = 16; + + pub(super) fn new_random() -> Self { + Self { + token: utils::random_string(Self::RANDOM_TOKEN_LENGTH), + issued_at: SystemTime::now(), + } + } + + pub(crate) fn is_valid(&self) -> bool { + let now = SystemTime::now(); + + now.duration_since(self.issued_at) + .is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE) + } +} + +impl PartialEq for ValidationToken { + fn eq(&self, other: &str) -> bool { self.token == other } +} + +impl ValidationSessions { + const RANDOM_SID_LENGTH: usize = 16; + + #[must_use] + pub(super) fn generate_session_id() -> OwnedSessionId { + OwnedSessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap() + } + + pub(super) fn create_session( + &mut self, + email: Address, + client_secret: OwnedClientSecret, + ) -> &mut ValidationSession { + let session = ValidationSession { + session_id: Self::generate_session_id(), + client_secret, + email, + validation_state: ValidationState::Pending(ValidationToken::new_random()), + }; + + self.client_secrets + .insert(session.client_secret.clone(), session.session_id.clone()); + self.sessions + .entry(session.session_id.clone()) + .insert_entry(session) + .into_mut() + } + + pub(super) fn get_session( + &mut self, + session_id: &SessionId, + ) -> Option<&mut ValidationSession> { + self.sessions.get_mut(session_id) + } + + pub(super) fn get_session_by_client_secret( + &mut self, + client_secret: &ClientSecret, + ) -> Option<&mut ValidationSession> { + let session_id = self.client_secrets.get(client_secret)?; + let session = self + .sessions + .get_mut(session_id) + .expect("session should exist with session id"); + + Some(session) + } + + pub(super) fn remove_session(&mut self, session_id: &SessionId) -> ValidationSession { + let session = self + .sessions + .remove(session_id) + .expect("session ID should exist"); + + self.client_secrets + .remove(&session.client_secret) + .expect("session should have an associated client secret"); + + session + } +} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index c81eb2fbb..b74197a44 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::{HashMap, HashSet, hash_map::Entry}, sync::Arc, }; @@ -133,7 +134,7 @@ impl Service { /// flows provide different values for known identity information. /// /// Returns the info of the newly created session. - pub async fn create_session( + async fn create_session( &self, flows: Vec, params: Box, @@ -154,11 +155,7 @@ impl Service { } /// Proceed with UIAA authentication given a client's authorization data. - pub async fn continue_session(&self, auth: &AuthData) -> Result { - let Some(session) = auth.session() else { - return Err!(Request(MissingParam("No session provided"))); - }; - + async fn continue_session(&self, auth: &AuthData, session: &str) -> Result { // Hold this lock for the entire function to make sure that, if try_auth() // is called concurrently with the same session, only one call will succeed let mut uiaa_sessions = self.uiaa_sessions.lock().await; @@ -238,7 +235,6 @@ impl Service { /// Perform the full UIAA authentication sequence for a route given its /// authentication data. - #[inline] pub async fn authenticate( &self, auth: &Option, @@ -252,9 +248,26 @@ impl Service { Err(Error::Uiaa(info)) }, - | Some(auth) => match self.continue_session(auth).await? { - | UiaaStatus::Retry(info) => Err(Error::Uiaa(info)), - | UiaaStatus::Success(identity) => Ok(identity), + | Some(auth) => { + let session: Cow<'_, str> = match auth.session() { + | Some(session) => session.into(), + | None => { + // Clients are allowed to send UIAA requests with an auth dict and no + // session if they want to start the UIAA exchange with existing + // authentication data. If that happens, we create a new session + // here. + self.create_session(flows, params, identity) + .await + .session + .unwrap() + .into() + }, + }; + + match self.continue_session(auth, &session).await? { + | UiaaStatus::Retry(info) => Err(Error::Uiaa(info)), + | UiaaStatus::Success(identity) => Ok(identity), + } }, } } @@ -301,7 +314,7 @@ impl Service { match self .services .threepid - .consume_valid_session(sid.as_str(), client_secret) + .consume_valid_session(sid, client_secret) .await { | Ok(email) => {