diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index f60c46382..7d4f73c64 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -1,6 +1,7 @@ use std::any::{Any, TypeId}; -use conduwuit::{Err, Result, err}; +use conduwuit::{Err, Error, Result, err}; +use http::StatusCode; use ruma::{ DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, api::{ @@ -10,12 +11,15 @@ use ruma::{ AuthScheme, NoAccessToken, NoAuthentication, }, client, + error::{ErrorKind, UnknownTokenErrorData}, federation::authentication::ServerSignatures, }, + assign, }; use service::{ Services, server_keys::{PubKeyMap, PubKeys}, + users::AccessTokenStatus, }; use crate::{router::args::AuthQueryParams, service::appservice::RegistrationInfo}; @@ -153,7 +157,18 @@ impl CheckAuth for AccessToken { query: AuthQueryParams, route: TypeId, ) -> Result { - if let Some((sender_user, sender_device)) = services.users.find_from_token(&output).await { + if let Some((sender_user, sender_device, status)) = services.users.find_from_token(&output).await { + // If the token is expired we return a soft logout + if matches!(status, AccessTokenStatus::Expired) { + return Err(Error::Request( + ErrorKind::UnknownToken( + assign!(UnknownTokenErrorData::new(), { soft_logout: true }), + ), + "This token has expired".into(), + StatusCode::UNAUTHORIZED, + )); + } + // Locked users can only use /logout and /logout/all if services .users @@ -164,7 +179,7 @@ impl CheckAuth for AccessToken { if !(route == TypeId::of::() || route == TypeId::of::()) { - return Err!(Request(Unauthorized("Your account is locked."))); + return Err!(Request(UserLocked("Your account is locked."))); } } @@ -215,7 +230,11 @@ impl CheckAuth for AccessToken { appservice_info: Box::new(appservice_info), }) } else { - Err!(Request(Unauthorized("Invalid access token."))) + Err(Error::Request( + ErrorKind::UnknownToken(UnknownTokenErrorData::new()), + "Invalid token".into(), + StatusCode::UNAUTHORIZED, + )) } } } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 0d79d4505..d2be5821f 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -67,7 +67,7 @@ impl crate::Service for Service { for (id, registration) in appservices { // During startup, resolve any token collisions in favour of appservices // by logging out conflicting user devices - if let Some((user_id, device_id)) = self + if let Some((user_id, device_id, _)) = self .services .users .find_from_token(®istration.as_token) diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 609c7e02c..0e98c8644 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -57,6 +57,12 @@ impl HashedPassword { } } +/// The status of an access token. +pub enum AccessTokenStatus { + Valid, + Expired, +} + pub struct Service { services: Services, db: Data, @@ -347,7 +353,10 @@ impl Service { pub async fn count(&self) -> usize { self.db.userid_password.count().await } /// Find out which user an access token belongs to. - pub async fn find_from_token(&self, token: &str) -> Option<(OwnedUserId, OwnedDeviceId)> { + pub async fn find_from_token( + &self, + token: &str, + ) -> Option<(OwnedUserId, OwnedDeviceId, AccessTokenStatus)> { let user = self .db .token_userdeviceid @@ -357,11 +366,11 @@ impl Service { .ok(); // Check if the token has expired - if let Some(user) = &user { + if let Some((user_id, device_id)) = user { if let Some(expires) = self .db .userdeviceid_tokenexpires - .qry(user) + .qry(&(&user_id, &device_id)) .await .deserialized::() .ok() @@ -372,12 +381,14 @@ impl Service { .expect("expiry time should not overflow SystemTime"); if SystemTime::now() > expires_at { - return None; + return Some((user_id, device_id, AccessTokenStatus::Expired)); } } - } - user + Some((user_id, device_id, AccessTokenStatus::Valid)) + } else { + None + } } /// Returns an iterator over all users on this homeserver.