feat: Store threepid validation sessions in memory instead of the database

This commit is contained in:
Ginger
2026-03-22 17:42:35 -04:00
committed by Ellis Git
parent 955da3a74f
commit 9d06208a7a
7 changed files with 257 additions and 282 deletions
Generated
-1
View File
@@ -1012,7 +1012,6 @@ dependencies = [
"hyper",
"ipaddress",
"itertools 0.14.0",
"lettre",
"log",
"rand 0.10.0",
"reqwest",
+3 -2
View File
@@ -2050,8 +2050,9 @@
#
# For most modern mail servers, format the URI like this:
# `smtps://username:password@hostname:port`
# Note that you will need to URL-encode the username and password. If your username _is_
# your email address, you will need to replace the `@` with `%40`.
# Note that you will need to URL-encode the username and password. If your
# username _is_ your email address, you will need to replace the `@` with
# `%40`.
#
# For a guide on the accepted URI syntax, consult Lettre's documentation:
# https://docs.rs/lettre/latest/lettre/transport/smtp/struct.AsyncSmtpTransport.html#method.from_url
-12
View File
@@ -49,10 +49,6 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "bannedroomids",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "clientsecret_validationsessionid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "disabledroomids",
..descriptor::RANDOM_SMALL
@@ -470,12 +466,4 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userroomid_invitesender",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "validationsessionid_session",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "validationsessionid_token",
..descriptor::RANDOM_SMALL
},
];
-159
View File
@@ -1,159 +0,0 @@
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use conduwuit::utils;
use database::{Database, Deserialized, Map};
use lettre::Address;
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId};
use serde::{Deserialize, Serialize};
pub(super) struct Data {
// note: the column names of these maps use `validationsession` instead of `session`
clientsecret_sessionid: Arc<Map>,
sessionid_session: Arc<Map>,
sessionid_token: Arc<Map>,
pub(super) localpart_email: Arc<Map>,
pub(super) email_localpart: Arc<Map>,
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct ValidationSession {
/// The session's ID
pub session_id: OwnedSessionId,
/// The email address which is being validated
pub email: Address,
/// The client's supplied client secret
pub client_secret: OwnedClientSecret,
/// Whether the email address has been validated successfully yet
pub(super) has_been_validated: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct ValidationToken {
pub token: String,
pub issued_at: SystemTime,
}
impl ValidationToken {
// one hour
const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60);
const RANDOM_TOKEN_LENGTH: usize = 16;
pub(super) fn new_random() -> Self {
Self {
token: utils::random_string(Self::RANDOM_TOKEN_LENGTH),
issued_at: SystemTime::now(),
}
}
pub(crate) fn is_valid(&self) -> bool {
let now = SystemTime::now();
now.duration_since(self.issued_at)
.is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE)
}
}
impl PartialEq<str> for ValidationToken {
fn eq(&self, other: &str) -> bool { self.token == other }
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
clientsecret_sessionid: db["clientsecret_validationsessionid"].clone(),
sessionid_session: db["validationsessionid_session"].clone(),
sessionid_token: db["validationsessionid_token"].clone(),
localpart_email: db["localpart_email"].clone(),
email_localpart: db["email_localpart"].clone(),
}
}
/// Create a validation session.
pub(super) fn create_session(
&self,
email: Address,
session_id: OwnedSessionId,
client_secret: OwnedClientSecret,
token: ValidationToken,
) {
let session = ValidationSession {
session_id,
client_secret,
email,
has_been_validated: false,
};
self.clientsecret_sessionid
.insert(&session.session_id, &session.client_secret);
self.sessionid_token.raw_put(&session.session_id, token);
self.sessionid_session
.raw_put(session.session_id.clone(), session);
}
/// Get a validation session.
pub(super) async fn get_session(&self, session_id: &str) -> Option<ValidationSession> {
self.sessionid_session
.get(session_id)
.await
.deserialized()
.ok()
}
/// Get a validation session by client secret.
pub(super) async fn get_session_by_secret(
&self,
client_secret: &ClientSecret,
) -> Option<ValidationSession> {
let session_id: String = self
.clientsecret_sessionid
.get(client_secret)
.await
.deserialized()
.ok()?;
self.get_session(&session_id).await
}
/// Get the validation token for a validation session, or None if the
/// session does not exist.
pub(super) async fn get_session_validation_token(
&self,
session: &ValidationSession,
) -> Option<ValidationToken> {
self.sessionid_token
.get(&session.session_id)
.await
.deserialized()
.ok()
}
/// Update a session's validation token.
pub(super) fn update_session_validation_token(
&self,
session: &ValidationSession,
token: ValidationToken,
) {
self.sessionid_token.raw_put(&session.session_id, token);
}
/// Mark a validation session as valid.
pub(super) async fn mark_session_as_valid(&self, mut session: ValidationSession) {
self.sessionid_token.remove(&session.session_id);
session.has_been_validated = true;
self.sessionid_session
.raw_put(session.session_id.clone(), session);
}
/// Remove a validation session.
pub(super) async fn remove_session(
&self,
ValidationSession { session_id, .. }: ValidationSession,
) {
self.clientsecret_sessionid.remove(&session_id);
self.sessionid_token.remove(&session_id);
self.sessionid_session.remove(&session_id);
}
}
+102 -97
View File
@@ -1,25 +1,28 @@
use std::{
borrow::Cow,
collections::HashMap,
sync::{Arc, Mutex},
};
use std::{borrow::Cow, collections::HashMap, sync::Arc};
use conduwuit::{Err, Result, result::FlatOk};
use database::{Deserialized, Map};
use lettre::{Address, message::Mailbox};
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId};
mod session;
use crate::{
Args, Dep, config,
mailer::{self, messages::MessageTemplate},
threepid::session::{ValidationSessions, ValidationState, ValidationToken},
};
mod data;
use conduwuit::{Err, Result, result::FlatOk, utils};
use data::{Data, ValidationToken};
use database::Deserialized;
use lettre::{Address, message::Mailbox};
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId};
pub struct Service {
db: Data,
services: Services,
send_attempts: Mutex<HashMap<(OwnedClientSecret, Address), usize>>,
sessions: tokio::sync::Mutex<ValidationSessions>,
send_attempts: std::sync::Mutex<HashMap<(OwnedClientSecret, Address), usize>>,
}
struct Data {
localpart_email: Arc<Map>,
email_localpart: Arc<Map>,
}
struct Services {
@@ -30,12 +33,16 @@ struct Services {
impl crate::Service for Service {
fn build(args: Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data {
email_localpart: args.db["email_localpart"].clone(),
localpart_email: args.db["localpart_email"].clone(),
},
services: Services {
config: args.depend("config"),
mailer: args.depend("mailer"),
},
send_attempts: Mutex::new(HashMap::new()),
sessions: tokio::sync::Mutex::default(),
send_attempts: std::sync::Mutex::default(),
}))
}
@@ -43,14 +50,8 @@ impl crate::Service for Service {
}
impl Service {
const RANDOM_SID_LENGTH: usize = 16;
const VALIDATION_URL_PATH: &str = "/_continuwuity/3pid/email/validate";
#[must_use]
pub fn generate_session_id() -> OwnedSessionId {
OwnedSessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap()
}
/// Send a validation message to an email address.
///
/// Returns the validation session ID on success.
@@ -63,62 +64,57 @@ impl Service {
send_attempt: usize,
) -> Result<OwnedSessionId> {
let mailer = self.services.mailer.expect_mailer()?;
let mut sessions = self.sessions.lock().await;
let (session_id, ValidationToken { token, .. }) =
match self.db.get_session_by_secret(client_secret).await {
// If a validation session already exists for this client secret, we can either
// reuse it with a new token or return early because it's already valid.
| Some(session) => {
// If the existing session is already valid, don't send an email.
if session.has_been_validated {
return Ok(session.session_id);
}
let session = match sessions.get_session_by_client_secret(client_secret) {
// If a validation session already exists for this client secret, we can either
// reuse it with a new token or return early because it's already valid.
| Some(session) => {
match session.validation_state {
| ValidationState::Validated => {
// If the existing session is already valid, don't send an email.
return Ok(session.session_id.clone());
},
| ValidationState::Pending(ref mut token) => {
// Check the send attempt for this session
let mut send_attempts = self.send_attempts.lock().unwrap();
let mut send_attempts = self.send_attempts.lock().unwrap();
match send_attempts
.get_mut(&(session.client_secret.clone(), session.email.clone()))
{
| Some(last_send_attempt) => {
if send_attempt <= *last_send_attempt {
// If the supplied send attempt isn't higher than the last one,
// don't send an email.
return Ok(session.session_id);
}
match send_attempts
.get_mut(&(session.client_secret.clone(), session.email.clone()))
{
| Some(last_send_attempt) => {
if send_attempt <= *last_send_attempt {
// If the supplied send attempt isn't higher than the last
// one, don't send an email.
return Ok(session.session_id.clone());
}
// Otherwise save the supplied send attempt.
*last_send_attempt = send_attempt;
},
| None => {
// Default to sending an email if no previous
// attempt could be found. This can happen if
// the server was restarted, which clears the send
// attempt tracker.
},
}
drop(send_attempts);
// Otherwise save the supplied send attempt.
*last_send_attempt = send_attempt;
},
| None => {
// Default to sending an email if no previous
// attempt could be found. This can happen if
// the server was restarted, which clears the
// send attempt tracker.
},
}
drop(send_attempts);
// Create a new token for the existing session.
let token = ValidationToken::new_random();
self.db
.update_session_validation_token(&session, token.clone());
// Create a new token for the existing session.
*token = ValidationToken::new_random();
(session.session_id, token)
},
// If no session exists, create a new one.
| None => {
let session_id = Self::generate_session_id();
let token = ValidationToken::new_random();
session
},
}
},
// If no session exists, create a new one.
| None => sessions.create_session(recipient.email.clone(), client_secret.to_owned()),
};
self.db.create_session(
recipient.email.clone(),
session_id.clone(),
client_secret.to_owned(),
token.clone(),
);
(session_id, token)
},
};
let ValidationState::Pending(token) = &session.validation_state else {
unreachable!("session should be pending")
};
let mut validation_url = self
.services
@@ -129,41 +125,46 @@ impl Service {
validation_url
.query_pairs_mut()
.append_pair("session_id", session_id.as_ref())
.append_pair("token", &token);
.append_pair("session_id", session.session_id.as_ref())
.append_pair("token", &token.token);
let message = prepare_body(validation_url.to_string());
mailer.send(recipient, message).await?;
Ok(session_id)
Ok(session.session_id.clone())
}
/// Attempt to mark a validation session as valid using a validation token.
pub async fn try_validate_session(
&self,
session_id: &str,
session_id: &SessionId,
supplied_token: &str,
) -> Result<(), Cow<'static, str>> {
let Some(session) = self.db.get_session(session_id).await else {
let mut sessions = self.sessions.lock().await;
let Some(session) = sessions.get_session(session_id) else {
return Err("Validation session does not exist".into());
};
if session.has_been_validated {
return Ok(());
}
session.validation_state = match &session.validation_state {
| ValidationState::Validated => {
// If the session is already validated, do nothing.
let token = self
.db
.get_session_validation_token(&session)
.await
.expect("valid session should have a token");
return Ok(());
},
| ValidationState::Pending(token) => {
// Otherwise check the token and mark the session as valid.
if token != *supplied_token || !token.is_valid() {
return Err("Validation token is invalid or expired, please request a new one".into());
}
if *token != *supplied_token || !token.is_valid() {
return Err("Validation token is invalid or expired, please request a new \
one"
.into());
}
self.db.mark_session_as_valid(session).await;
ValidationState::Validated
},
};
Ok(())
}
@@ -172,17 +173,21 @@ impl Service {
/// and returning the newly validated email address.
pub async fn consume_valid_session(
&self,
session_id: &str,
session_id: &SessionId,
client_secret: &ClientSecret,
) -> Result<Address, Cow<'static, str>> {
let Some(session) = self.db.get_session(session_id).await else {
let mut sessions = self.sessions.lock().await;
let Some(session) = sessions.get_session(session_id) else {
return Err("Validation session does not exist".into());
};
if session.client_secret == client_secret && session.has_been_validated {
let email = session.email.clone();
self.db.remove_session(session).await;
Ok(email)
if session.client_secret == client_secret
&& matches!(session.validation_state, ValidationState::Validated)
{
let session = sessions.remove_session(session_id);
Ok(session.email)
} else {
Err("This email address has not been validated. Did you use the link that was sent \
to you?"
@@ -198,17 +203,17 @@ impl Service {
) -> Result<()> {
match self.get_localpart_for_email(email).await {
| Some(existing_localpart) if existing_localpart != localpart => {
// Another account is already using the supplied email
// Another account is already using the supplied email.
Err!(Request(ThreepidInUse("This email address is already in use.")))
},
| Some(_) => {
// The supplied localpart is already associated with the supplied email,
// no changes are necessary
// no changes are necessary.
Ok(())
},
| None => {
// The supplied email is not already in use
// The supplied email is not already in use.
let email: &str = email.as_ref();
self.db.localpart_email.insert(localpart, email);
+128
View File
@@ -0,0 +1,128 @@
use std::{
collections::HashMap,
time::{Duration, SystemTime},
};
use conduwuit::utils;
use lettre::Address;
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId};
#[derive(Default)]
pub(super) struct ValidationSessions {
sessions: HashMap<OwnedSessionId, ValidationSession>,
client_secrets: HashMap<OwnedClientSecret, OwnedSessionId>,
}
/// A pending or completed email validation session.
#[derive(Debug)]
pub(crate) struct ValidationSession {
/// The session's ID
pub session_id: OwnedSessionId,
/// The client's supplied client secret
pub client_secret: OwnedClientSecret,
/// The email address which is being validated
pub email: Address,
/// The session's validation state
pub validation_state: ValidationState,
}
/// The state of an email validation session.
#[derive(Debug)]
pub(crate) enum ValidationState {
/// The session is waiting for this validation token to be provided
Pending(ValidationToken),
/// The session has been validated
Validated,
}
#[derive(Clone, Debug)]
pub(crate) struct ValidationToken {
pub token: String,
pub issued_at: SystemTime,
}
impl ValidationToken {
// one hour
const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60);
const RANDOM_TOKEN_LENGTH: usize = 16;
pub(super) fn new_random() -> Self {
Self {
token: utils::random_string(Self::RANDOM_TOKEN_LENGTH),
issued_at: SystemTime::now(),
}
}
pub(crate) fn is_valid(&self) -> bool {
let now = SystemTime::now();
now.duration_since(self.issued_at)
.is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE)
}
}
impl PartialEq<str> for ValidationToken {
fn eq(&self, other: &str) -> bool { self.token == other }
}
impl ValidationSessions {
const RANDOM_SID_LENGTH: usize = 16;
#[must_use]
pub(super) fn generate_session_id() -> OwnedSessionId {
OwnedSessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap()
}
pub(super) fn create_session(
&mut self,
email: Address,
client_secret: OwnedClientSecret,
) -> &mut ValidationSession {
let session = ValidationSession {
session_id: Self::generate_session_id(),
client_secret,
email,
validation_state: ValidationState::Pending(ValidationToken::new_random()),
};
self.client_secrets
.insert(session.client_secret.clone(), session.session_id.clone());
self.sessions
.entry(session.session_id.clone())
.insert_entry(session)
.into_mut()
}
pub(super) fn get_session(
&mut self,
session_id: &SessionId,
) -> Option<&mut ValidationSession> {
self.sessions.get_mut(session_id)
}
pub(super) fn get_session_by_client_secret(
&mut self,
client_secret: &ClientSecret,
) -> Option<&mut ValidationSession> {
let session_id = self.client_secrets.get(client_secret)?;
let session = self
.sessions
.get_mut(session_id)
.expect("session should exist with session id");
Some(session)
}
pub(super) fn remove_session(&mut self, session_id: &SessionId) -> ValidationSession {
let session = self
.sessions
.remove(session_id)
.expect("session ID should exist");
self.client_secrets
.remove(&session.client_secret)
.expect("session should have an associated client secret");
session
}
}
+24 -11
View File
@@ -1,4 +1,5 @@
use std::{
borrow::Cow,
collections::{HashMap, HashSet, hash_map::Entry},
sync::Arc,
};
@@ -133,7 +134,7 @@ impl Service {
/// flows provide different values for known identity information.
///
/// Returns the info of the newly created session.
pub async fn create_session(
async fn create_session(
&self,
flows: Vec<AuthFlow>,
params: Box<RawValue>,
@@ -154,11 +155,7 @@ impl Service {
}
/// Proceed with UIAA authentication given a client's authorization data.
pub async fn continue_session(&self, auth: &AuthData) -> Result<UiaaStatus> {
let Some(session) = auth.session() else {
return Err!(Request(MissingParam("No session provided")));
};
async fn continue_session(&self, auth: &AuthData, session: &str) -> Result<UiaaStatus> {
// Hold this lock for the entire function to make sure that, if try_auth()
// is called concurrently with the same session, only one call will succeed
let mut uiaa_sessions = self.uiaa_sessions.lock().await;
@@ -238,7 +235,6 @@ impl Service {
/// Perform the full UIAA authentication sequence for a route given its
/// authentication data.
#[inline]
pub async fn authenticate(
&self,
auth: &Option<AuthData>,
@@ -252,9 +248,26 @@ impl Service {
Err(Error::Uiaa(info))
},
| Some(auth) => match self.continue_session(auth).await? {
| UiaaStatus::Retry(info) => Err(Error::Uiaa(info)),
| UiaaStatus::Success(identity) => Ok(identity),
| Some(auth) => {
let session: Cow<'_, str> = match auth.session() {
| Some(session) => session.into(),
| None => {
// Clients are allowed to send UIAA requests with an auth dict and no
// session if they want to start the UIAA exchange with existing
// authentication data. If that happens, we create a new session
// here.
self.create_session(flows, params, identity)
.await
.session
.unwrap()
.into()
},
};
match self.continue_session(auth, &session).await? {
| UiaaStatus::Retry(info) => Err(Error::Uiaa(info)),
| UiaaStatus::Success(identity) => Ok(identity),
}
},
}
}
@@ -301,7 +314,7 @@ impl Service {
match self
.services
.threepid
.consume_valid_session(sid.as_str(), client_secret)
.consume_valid_session(sid, client_secret)
.await
{
| Ok(email) => {