mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
feat: Store threepid validation sessions in memory instead of the database
This commit is contained in:
Generated
-1
@@ -1012,7 +1012,6 @@ dependencies = [
|
||||
"hyper",
|
||||
"ipaddress",
|
||||
"itertools 0.14.0",
|
||||
"lettre",
|
||||
"log",
|
||||
"rand 0.10.0",
|
||||
"reqwest",
|
||||
|
||||
@@ -2050,8 +2050,9 @@
|
||||
#
|
||||
# For most modern mail servers, format the URI like this:
|
||||
# `smtps://username:password@hostname:port`
|
||||
# Note that you will need to URL-encode the username and password. If your username _is_
|
||||
# your email address, you will need to replace the `@` with `%40`.
|
||||
# Note that you will need to URL-encode the username and password. If your
|
||||
# username _is_ your email address, you will need to replace the `@` with
|
||||
# `%40`.
|
||||
#
|
||||
# For a guide on the accepted URI syntax, consult Lettre's documentation:
|
||||
# https://docs.rs/lettre/latest/lettre/transport/smtp/struct.AsyncSmtpTransport.html#method.from_url
|
||||
|
||||
@@ -49,10 +49,6 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "bannedroomids",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "clientsecret_validationsessionid",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "disabledroomids",
|
||||
..descriptor::RANDOM_SMALL
|
||||
@@ -470,12 +466,4 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "userroomid_invitesender",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "validationsessionid_session",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "validationsessionid_token",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use conduwuit::utils;
|
||||
use database::{Database, Deserialized, Map};
|
||||
use lettre::Address;
|
||||
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(super) struct Data {
|
||||
// note: the column names of these maps use `validationsession` instead of `session`
|
||||
clientsecret_sessionid: Arc<Map>,
|
||||
sessionid_session: Arc<Map>,
|
||||
sessionid_token: Arc<Map>,
|
||||
pub(super) localpart_email: Arc<Map>,
|
||||
pub(super) email_localpart: Arc<Map>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct ValidationSession {
|
||||
/// The session's ID
|
||||
pub session_id: OwnedSessionId,
|
||||
/// The email address which is being validated
|
||||
pub email: Address,
|
||||
/// The client's supplied client secret
|
||||
pub client_secret: OwnedClientSecret,
|
||||
/// Whether the email address has been validated successfully yet
|
||||
pub(super) has_been_validated: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct ValidationToken {
|
||||
pub token: String,
|
||||
pub issued_at: SystemTime,
|
||||
}
|
||||
|
||||
impl ValidationToken {
|
||||
// one hour
|
||||
const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60);
|
||||
const RANDOM_TOKEN_LENGTH: usize = 16;
|
||||
|
||||
pub(super) fn new_random() -> Self {
|
||||
Self {
|
||||
token: utils::random_string(Self::RANDOM_TOKEN_LENGTH),
|
||||
issued_at: SystemTime::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_valid(&self) -> bool {
|
||||
let now = SystemTime::now();
|
||||
|
||||
now.duration_since(self.issued_at)
|
||||
.is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<str> for ValidationToken {
|
||||
fn eq(&self, other: &str) -> bool { self.token == other }
|
||||
}
|
||||
|
||||
impl Data {
|
||||
pub(super) fn new(db: &Arc<Database>) -> Self {
|
||||
Self {
|
||||
clientsecret_sessionid: db["clientsecret_validationsessionid"].clone(),
|
||||
sessionid_session: db["validationsessionid_session"].clone(),
|
||||
sessionid_token: db["validationsessionid_token"].clone(),
|
||||
localpart_email: db["localpart_email"].clone(),
|
||||
email_localpart: db["email_localpart"].clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a validation session.
|
||||
pub(super) fn create_session(
|
||||
&self,
|
||||
email: Address,
|
||||
session_id: OwnedSessionId,
|
||||
client_secret: OwnedClientSecret,
|
||||
token: ValidationToken,
|
||||
) {
|
||||
let session = ValidationSession {
|
||||
session_id,
|
||||
client_secret,
|
||||
email,
|
||||
has_been_validated: false,
|
||||
};
|
||||
self.clientsecret_sessionid
|
||||
.insert(&session.session_id, &session.client_secret);
|
||||
self.sessionid_token.raw_put(&session.session_id, token);
|
||||
self.sessionid_session
|
||||
.raw_put(session.session_id.clone(), session);
|
||||
}
|
||||
|
||||
/// Get a validation session.
|
||||
pub(super) async fn get_session(&self, session_id: &str) -> Option<ValidationSession> {
|
||||
self.sessionid_session
|
||||
.get(session_id)
|
||||
.await
|
||||
.deserialized()
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// Get a validation session by client secret.
|
||||
pub(super) async fn get_session_by_secret(
|
||||
&self,
|
||||
client_secret: &ClientSecret,
|
||||
) -> Option<ValidationSession> {
|
||||
let session_id: String = self
|
||||
.clientsecret_sessionid
|
||||
.get(client_secret)
|
||||
.await
|
||||
.deserialized()
|
||||
.ok()?;
|
||||
|
||||
self.get_session(&session_id).await
|
||||
}
|
||||
|
||||
/// Get the validation token for a validation session, or None if the
|
||||
/// session does not exist.
|
||||
pub(super) async fn get_session_validation_token(
|
||||
&self,
|
||||
session: &ValidationSession,
|
||||
) -> Option<ValidationToken> {
|
||||
self.sessionid_token
|
||||
.get(&session.session_id)
|
||||
.await
|
||||
.deserialized()
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// Update a session's validation token.
|
||||
pub(super) fn update_session_validation_token(
|
||||
&self,
|
||||
session: &ValidationSession,
|
||||
token: ValidationToken,
|
||||
) {
|
||||
self.sessionid_token.raw_put(&session.session_id, token);
|
||||
}
|
||||
|
||||
/// Mark a validation session as valid.
|
||||
pub(super) async fn mark_session_as_valid(&self, mut session: ValidationSession) {
|
||||
self.sessionid_token.remove(&session.session_id);
|
||||
|
||||
session.has_been_validated = true;
|
||||
self.sessionid_session
|
||||
.raw_put(session.session_id.clone(), session);
|
||||
}
|
||||
|
||||
/// Remove a validation session.
|
||||
pub(super) async fn remove_session(
|
||||
&self,
|
||||
ValidationSession { session_id, .. }: ValidationSession,
|
||||
) {
|
||||
self.clientsecret_sessionid.remove(&session_id);
|
||||
self.sessionid_token.remove(&session_id);
|
||||
self.sessionid_session.remove(&session_id);
|
||||
}
|
||||
}
|
||||
+102
-97
@@ -1,25 +1,28 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use std::{borrow::Cow, collections::HashMap, sync::Arc};
|
||||
|
||||
use conduwuit::{Err, Result, result::FlatOk};
|
||||
use database::{Deserialized, Map};
|
||||
use lettre::{Address, message::Mailbox};
|
||||
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId};
|
||||
|
||||
mod session;
|
||||
|
||||
use crate::{
|
||||
Args, Dep, config,
|
||||
mailer::{self, messages::MessageTemplate},
|
||||
threepid::session::{ValidationSessions, ValidationState, ValidationToken},
|
||||
};
|
||||
|
||||
mod data;
|
||||
use conduwuit::{Err, Result, result::FlatOk, utils};
|
||||
use data::{Data, ValidationToken};
|
||||
use database::Deserialized;
|
||||
use lettre::{Address, message::Mailbox};
|
||||
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId};
|
||||
|
||||
pub struct Service {
|
||||
db: Data,
|
||||
services: Services,
|
||||
send_attempts: Mutex<HashMap<(OwnedClientSecret, Address), usize>>,
|
||||
sessions: tokio::sync::Mutex<ValidationSessions>,
|
||||
send_attempts: std::sync::Mutex<HashMap<(OwnedClientSecret, Address), usize>>,
|
||||
}
|
||||
|
||||
struct Data {
|
||||
localpart_email: Arc<Map>,
|
||||
email_localpart: Arc<Map>,
|
||||
}
|
||||
|
||||
struct Services {
|
||||
@@ -30,12 +33,16 @@ struct Services {
|
||||
impl crate::Service for Service {
|
||||
fn build(args: Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
db: Data::new(args.db),
|
||||
db: Data {
|
||||
email_localpart: args.db["email_localpart"].clone(),
|
||||
localpart_email: args.db["localpart_email"].clone(),
|
||||
},
|
||||
services: Services {
|
||||
config: args.depend("config"),
|
||||
mailer: args.depend("mailer"),
|
||||
},
|
||||
send_attempts: Mutex::new(HashMap::new()),
|
||||
sessions: tokio::sync::Mutex::default(),
|
||||
send_attempts: std::sync::Mutex::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -43,14 +50,8 @@ impl crate::Service for Service {
|
||||
}
|
||||
|
||||
impl Service {
|
||||
const RANDOM_SID_LENGTH: usize = 16;
|
||||
const VALIDATION_URL_PATH: &str = "/_continuwuity/3pid/email/validate";
|
||||
|
||||
#[must_use]
|
||||
pub fn generate_session_id() -> OwnedSessionId {
|
||||
OwnedSessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap()
|
||||
}
|
||||
|
||||
/// Send a validation message to an email address.
|
||||
///
|
||||
/// Returns the validation session ID on success.
|
||||
@@ -63,62 +64,57 @@ impl Service {
|
||||
send_attempt: usize,
|
||||
) -> Result<OwnedSessionId> {
|
||||
let mailer = self.services.mailer.expect_mailer()?;
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
|
||||
let (session_id, ValidationToken { token, .. }) =
|
||||
match self.db.get_session_by_secret(client_secret).await {
|
||||
// If a validation session already exists for this client secret, we can either
|
||||
// reuse it with a new token or return early because it's already valid.
|
||||
| Some(session) => {
|
||||
// If the existing session is already valid, don't send an email.
|
||||
if session.has_been_validated {
|
||||
return Ok(session.session_id);
|
||||
}
|
||||
let session = match sessions.get_session_by_client_secret(client_secret) {
|
||||
// If a validation session already exists for this client secret, we can either
|
||||
// reuse it with a new token or return early because it's already valid.
|
||||
| Some(session) => {
|
||||
match session.validation_state {
|
||||
| ValidationState::Validated => {
|
||||
// If the existing session is already valid, don't send an email.
|
||||
return Ok(session.session_id.clone());
|
||||
},
|
||||
| ValidationState::Pending(ref mut token) => {
|
||||
// Check the send attempt for this session
|
||||
let mut send_attempts = self.send_attempts.lock().unwrap();
|
||||
|
||||
let mut send_attempts = self.send_attempts.lock().unwrap();
|
||||
match send_attempts
|
||||
.get_mut(&(session.client_secret.clone(), session.email.clone()))
|
||||
{
|
||||
| Some(last_send_attempt) => {
|
||||
if send_attempt <= *last_send_attempt {
|
||||
// If the supplied send attempt isn't higher than the last one,
|
||||
// don't send an email.
|
||||
return Ok(session.session_id);
|
||||
}
|
||||
match send_attempts
|
||||
.get_mut(&(session.client_secret.clone(), session.email.clone()))
|
||||
{
|
||||
| Some(last_send_attempt) => {
|
||||
if send_attempt <= *last_send_attempt {
|
||||
// If the supplied send attempt isn't higher than the last
|
||||
// one, don't send an email.
|
||||
return Ok(session.session_id.clone());
|
||||
}
|
||||
|
||||
// Otherwise save the supplied send attempt.
|
||||
*last_send_attempt = send_attempt;
|
||||
},
|
||||
| None => {
|
||||
// Default to sending an email if no previous
|
||||
// attempt could be found. This can happen if
|
||||
// the server was restarted, which clears the send
|
||||
// attempt tracker.
|
||||
},
|
||||
}
|
||||
drop(send_attempts);
|
||||
// Otherwise save the supplied send attempt.
|
||||
*last_send_attempt = send_attempt;
|
||||
},
|
||||
| None => {
|
||||
// Default to sending an email if no previous
|
||||
// attempt could be found. This can happen if
|
||||
// the server was restarted, which clears the
|
||||
// send attempt tracker.
|
||||
},
|
||||
}
|
||||
drop(send_attempts);
|
||||
|
||||
// Create a new token for the existing session.
|
||||
let token = ValidationToken::new_random();
|
||||
self.db
|
||||
.update_session_validation_token(&session, token.clone());
|
||||
// Create a new token for the existing session.
|
||||
*token = ValidationToken::new_random();
|
||||
|
||||
(session.session_id, token)
|
||||
},
|
||||
// If no session exists, create a new one.
|
||||
| None => {
|
||||
let session_id = Self::generate_session_id();
|
||||
let token = ValidationToken::new_random();
|
||||
session
|
||||
},
|
||||
}
|
||||
},
|
||||
// If no session exists, create a new one.
|
||||
| None => sessions.create_session(recipient.email.clone(), client_secret.to_owned()),
|
||||
};
|
||||
|
||||
self.db.create_session(
|
||||
recipient.email.clone(),
|
||||
session_id.clone(),
|
||||
client_secret.to_owned(),
|
||||
token.clone(),
|
||||
);
|
||||
|
||||
(session_id, token)
|
||||
},
|
||||
};
|
||||
let ValidationState::Pending(token) = &session.validation_state else {
|
||||
unreachable!("session should be pending")
|
||||
};
|
||||
|
||||
let mut validation_url = self
|
||||
.services
|
||||
@@ -129,41 +125,46 @@ impl Service {
|
||||
|
||||
validation_url
|
||||
.query_pairs_mut()
|
||||
.append_pair("session_id", session_id.as_ref())
|
||||
.append_pair("token", &token);
|
||||
.append_pair("session_id", session.session_id.as_ref())
|
||||
.append_pair("token", &token.token);
|
||||
|
||||
let message = prepare_body(validation_url.to_string());
|
||||
|
||||
mailer.send(recipient, message).await?;
|
||||
|
||||
Ok(session_id)
|
||||
Ok(session.session_id.clone())
|
||||
}
|
||||
|
||||
/// Attempt to mark a validation session as valid using a validation token.
|
||||
pub async fn try_validate_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
session_id: &SessionId,
|
||||
supplied_token: &str,
|
||||
) -> Result<(), Cow<'static, str>> {
|
||||
let Some(session) = self.db.get_session(session_id).await else {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
|
||||
let Some(session) = sessions.get_session(session_id) else {
|
||||
return Err("Validation session does not exist".into());
|
||||
};
|
||||
|
||||
if session.has_been_validated {
|
||||
return Ok(());
|
||||
}
|
||||
session.validation_state = match &session.validation_state {
|
||||
| ValidationState::Validated => {
|
||||
// If the session is already validated, do nothing.
|
||||
|
||||
let token = self
|
||||
.db
|
||||
.get_session_validation_token(&session)
|
||||
.await
|
||||
.expect("valid session should have a token");
|
||||
return Ok(());
|
||||
},
|
||||
| ValidationState::Pending(token) => {
|
||||
// Otherwise check the token and mark the session as valid.
|
||||
|
||||
if token != *supplied_token || !token.is_valid() {
|
||||
return Err("Validation token is invalid or expired, please request a new one".into());
|
||||
}
|
||||
if *token != *supplied_token || !token.is_valid() {
|
||||
return Err("Validation token is invalid or expired, please request a new \
|
||||
one"
|
||||
.into());
|
||||
}
|
||||
|
||||
self.db.mark_session_as_valid(session).await;
|
||||
ValidationState::Validated
|
||||
},
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -172,17 +173,21 @@ impl Service {
|
||||
/// and returning the newly validated email address.
|
||||
pub async fn consume_valid_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
session_id: &SessionId,
|
||||
client_secret: &ClientSecret,
|
||||
) -> Result<Address, Cow<'static, str>> {
|
||||
let Some(session) = self.db.get_session(session_id).await else {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
|
||||
let Some(session) = sessions.get_session(session_id) else {
|
||||
return Err("Validation session does not exist".into());
|
||||
};
|
||||
|
||||
if session.client_secret == client_secret && session.has_been_validated {
|
||||
let email = session.email.clone();
|
||||
self.db.remove_session(session).await;
|
||||
Ok(email)
|
||||
if session.client_secret == client_secret
|
||||
&& matches!(session.validation_state, ValidationState::Validated)
|
||||
{
|
||||
let session = sessions.remove_session(session_id);
|
||||
|
||||
Ok(session.email)
|
||||
} else {
|
||||
Err("This email address has not been validated. Did you use the link that was sent \
|
||||
to you?"
|
||||
@@ -198,17 +203,17 @@ impl Service {
|
||||
) -> Result<()> {
|
||||
match self.get_localpart_for_email(email).await {
|
||||
| Some(existing_localpart) if existing_localpart != localpart => {
|
||||
// Another account is already using the supplied email
|
||||
// Another account is already using the supplied email.
|
||||
|
||||
Err!(Request(ThreepidInUse("This email address is already in use.")))
|
||||
},
|
||||
| Some(_) => {
|
||||
// The supplied localpart is already associated with the supplied email,
|
||||
// no changes are necessary
|
||||
// no changes are necessary.
|
||||
Ok(())
|
||||
},
|
||||
| None => {
|
||||
// The supplied email is not already in use
|
||||
// The supplied email is not already in use.
|
||||
|
||||
let email: &str = email.as_ref();
|
||||
self.db.localpart_email.insert(localpart, email);
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use conduwuit::utils;
|
||||
use lettre::Address;
|
||||
use ruma::{ClientSecret, OwnedClientSecret, OwnedSessionId, SessionId};
|
||||
|
||||
#[derive(Default)]
|
||||
pub(super) struct ValidationSessions {
|
||||
sessions: HashMap<OwnedSessionId, ValidationSession>,
|
||||
client_secrets: HashMap<OwnedClientSecret, OwnedSessionId>,
|
||||
}
|
||||
|
||||
/// A pending or completed email validation session.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ValidationSession {
|
||||
/// The session's ID
|
||||
pub session_id: OwnedSessionId,
|
||||
/// The client's supplied client secret
|
||||
pub client_secret: OwnedClientSecret,
|
||||
/// The email address which is being validated
|
||||
pub email: Address,
|
||||
/// The session's validation state
|
||||
pub validation_state: ValidationState,
|
||||
}
|
||||
|
||||
/// The state of an email validation session.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ValidationState {
|
||||
/// The session is waiting for this validation token to be provided
|
||||
Pending(ValidationToken),
|
||||
/// The session has been validated
|
||||
Validated,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ValidationToken {
|
||||
pub token: String,
|
||||
pub issued_at: SystemTime,
|
||||
}
|
||||
|
||||
impl ValidationToken {
|
||||
// one hour
|
||||
const MAX_TOKEN_AGE: Duration = Duration::from_secs(60 * 60);
|
||||
const RANDOM_TOKEN_LENGTH: usize = 16;
|
||||
|
||||
pub(super) fn new_random() -> Self {
|
||||
Self {
|
||||
token: utils::random_string(Self::RANDOM_TOKEN_LENGTH),
|
||||
issued_at: SystemTime::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_valid(&self) -> bool {
|
||||
let now = SystemTime::now();
|
||||
|
||||
now.duration_since(self.issued_at)
|
||||
.is_ok_and(|duration| duration < Self::MAX_TOKEN_AGE)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<str> for ValidationToken {
|
||||
fn eq(&self, other: &str) -> bool { self.token == other }
|
||||
}
|
||||
|
||||
impl ValidationSessions {
|
||||
const RANDOM_SID_LENGTH: usize = 16;
|
||||
|
||||
#[must_use]
|
||||
pub(super) fn generate_session_id() -> OwnedSessionId {
|
||||
OwnedSessionId::parse(utils::random_string(Self::RANDOM_SID_LENGTH)).unwrap()
|
||||
}
|
||||
|
||||
pub(super) fn create_session(
|
||||
&mut self,
|
||||
email: Address,
|
||||
client_secret: OwnedClientSecret,
|
||||
) -> &mut ValidationSession {
|
||||
let session = ValidationSession {
|
||||
session_id: Self::generate_session_id(),
|
||||
client_secret,
|
||||
email,
|
||||
validation_state: ValidationState::Pending(ValidationToken::new_random()),
|
||||
};
|
||||
|
||||
self.client_secrets
|
||||
.insert(session.client_secret.clone(), session.session_id.clone());
|
||||
self.sessions
|
||||
.entry(session.session_id.clone())
|
||||
.insert_entry(session)
|
||||
.into_mut()
|
||||
}
|
||||
|
||||
pub(super) fn get_session(
|
||||
&mut self,
|
||||
session_id: &SessionId,
|
||||
) -> Option<&mut ValidationSession> {
|
||||
self.sessions.get_mut(session_id)
|
||||
}
|
||||
|
||||
pub(super) fn get_session_by_client_secret(
|
||||
&mut self,
|
||||
client_secret: &ClientSecret,
|
||||
) -> Option<&mut ValidationSession> {
|
||||
let session_id = self.client_secrets.get(client_secret)?;
|
||||
let session = self
|
||||
.sessions
|
||||
.get_mut(session_id)
|
||||
.expect("session should exist with session id");
|
||||
|
||||
Some(session)
|
||||
}
|
||||
|
||||
pub(super) fn remove_session(&mut self, session_id: &SessionId) -> ValidationSession {
|
||||
let session = self
|
||||
.sessions
|
||||
.remove(session_id)
|
||||
.expect("session ID should exist");
|
||||
|
||||
self.client_secrets
|
||||
.remove(&session.client_secret)
|
||||
.expect("session should have an associated client secret");
|
||||
|
||||
session
|
||||
}
|
||||
}
|
||||
+24
-11
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{HashMap, HashSet, hash_map::Entry},
|
||||
sync::Arc,
|
||||
};
|
||||
@@ -133,7 +134,7 @@ impl Service {
|
||||
/// flows provide different values for known identity information.
|
||||
///
|
||||
/// Returns the info of the newly created session.
|
||||
pub async fn create_session(
|
||||
async fn create_session(
|
||||
&self,
|
||||
flows: Vec<AuthFlow>,
|
||||
params: Box<RawValue>,
|
||||
@@ -154,11 +155,7 @@ impl Service {
|
||||
}
|
||||
|
||||
/// Proceed with UIAA authentication given a client's authorization data.
|
||||
pub async fn continue_session(&self, auth: &AuthData) -> Result<UiaaStatus> {
|
||||
let Some(session) = auth.session() else {
|
||||
return Err!(Request(MissingParam("No session provided")));
|
||||
};
|
||||
|
||||
async fn continue_session(&self, auth: &AuthData, session: &str) -> Result<UiaaStatus> {
|
||||
// Hold this lock for the entire function to make sure that, if try_auth()
|
||||
// is called concurrently with the same session, only one call will succeed
|
||||
let mut uiaa_sessions = self.uiaa_sessions.lock().await;
|
||||
@@ -238,7 +235,6 @@ impl Service {
|
||||
|
||||
/// Perform the full UIAA authentication sequence for a route given its
|
||||
/// authentication data.
|
||||
#[inline]
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
auth: &Option<AuthData>,
|
||||
@@ -252,9 +248,26 @@ impl Service {
|
||||
|
||||
Err(Error::Uiaa(info))
|
||||
},
|
||||
| Some(auth) => match self.continue_session(auth).await? {
|
||||
| UiaaStatus::Retry(info) => Err(Error::Uiaa(info)),
|
||||
| UiaaStatus::Success(identity) => Ok(identity),
|
||||
| Some(auth) => {
|
||||
let session: Cow<'_, str> = match auth.session() {
|
||||
| Some(session) => session.into(),
|
||||
| None => {
|
||||
// Clients are allowed to send UIAA requests with an auth dict and no
|
||||
// session if they want to start the UIAA exchange with existing
|
||||
// authentication data. If that happens, we create a new session
|
||||
// here.
|
||||
self.create_session(flows, params, identity)
|
||||
.await
|
||||
.session
|
||||
.unwrap()
|
||||
.into()
|
||||
},
|
||||
};
|
||||
|
||||
match self.continue_session(auth, &session).await? {
|
||||
| UiaaStatus::Retry(info) => Err(Error::Uiaa(info)),
|
||||
| UiaaStatus::Success(identity) => Ok(identity),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -301,7 +314,7 @@ impl Service {
|
||||
match self
|
||||
.services
|
||||
.threepid
|
||||
.consume_valid_session(sid.as_str(), client_secret)
|
||||
.consume_valid_session(sid, client_secret)
|
||||
.await
|
||||
{
|
||||
| Ok(email) => {
|
||||
|
||||
Reference in New Issue
Block a user