mirror of
https://forgejo.ellis.link/continuwuation/continuwuity.git
synced 2026-05-26 20:49:55 +00:00
fix: Return the correct error code for expired access tokens
This commit is contained in:
+23
-4
@@ -1,6 +1,7 @@
|
||||
use std::any::{Any, TypeId};
|
||||
|
||||
use conduwuit::{Err, Result, err};
|
||||
use conduwuit::{Err, Error, Result, err};
|
||||
use http::StatusCode;
|
||||
use ruma::{
|
||||
DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
|
||||
api::{
|
||||
@@ -10,12 +11,15 @@ use ruma::{
|
||||
AuthScheme, NoAccessToken, NoAuthentication,
|
||||
},
|
||||
client,
|
||||
error::{ErrorKind, UnknownTokenErrorData},
|
||||
federation::authentication::ServerSignatures,
|
||||
},
|
||||
assign,
|
||||
};
|
||||
use service::{
|
||||
Services,
|
||||
server_keys::{PubKeyMap, PubKeys},
|
||||
users::AccessTokenStatus,
|
||||
};
|
||||
|
||||
use crate::{router::args::AuthQueryParams, service::appservice::RegistrationInfo};
|
||||
@@ -153,7 +157,18 @@ impl CheckAuth for AccessToken {
|
||||
query: AuthQueryParams,
|
||||
route: TypeId,
|
||||
) -> Result<Self::Identity> {
|
||||
if let Some((sender_user, sender_device)) = services.users.find_from_token(&output).await {
|
||||
if let Some((sender_user, sender_device, status)) = services.users.find_from_token(&output).await {
|
||||
// If the token is expired we return a soft logout
|
||||
if matches!(status, AccessTokenStatus::Expired) {
|
||||
return Err(Error::Request(
|
||||
ErrorKind::UnknownToken(
|
||||
assign!(UnknownTokenErrorData::new(), { soft_logout: true }),
|
||||
),
|
||||
"This token has expired".into(),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
));
|
||||
}
|
||||
|
||||
// Locked users can only use /logout and /logout/all
|
||||
if services
|
||||
.users
|
||||
@@ -164,7 +179,7 @@ impl CheckAuth for AccessToken {
|
||||
if !(route == TypeId::of::<client::session::logout::v3::Request>()
|
||||
|| route == TypeId::of::<client::session::logout_all::v3::Request>())
|
||||
{
|
||||
return Err!(Request(Unauthorized("Your account is locked.")));
|
||||
return Err!(Request(UserLocked("Your account is locked.")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +230,11 @@ impl CheckAuth for AccessToken {
|
||||
appservice_info: Box::new(appservice_info),
|
||||
})
|
||||
} else {
|
||||
Err!(Request(Unauthorized("Invalid access token.")))
|
||||
Err(Error::Request(
|
||||
ErrorKind::UnknownToken(UnknownTokenErrorData::new()),
|
||||
"Invalid token".into(),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ impl crate::Service for Service {
|
||||
for (id, registration) in appservices {
|
||||
// During startup, resolve any token collisions in favour of appservices
|
||||
// by logging out conflicting user devices
|
||||
if let Some((user_id, device_id)) = self
|
||||
if let Some((user_id, device_id, _)) = self
|
||||
.services
|
||||
.users
|
||||
.find_from_token(®istration.as_token)
|
||||
|
||||
@@ -57,6 +57,12 @@ impl HashedPassword {
|
||||
}
|
||||
}
|
||||
|
||||
/// The status of an access token.
|
||||
pub enum AccessTokenStatus {
|
||||
Valid,
|
||||
Expired,
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
services: Services,
|
||||
db: Data,
|
||||
@@ -347,7 +353,10 @@ impl Service {
|
||||
pub async fn count(&self) -> usize { self.db.userid_password.count().await }
|
||||
|
||||
/// Find out which user an access token belongs to.
|
||||
pub async fn find_from_token(&self, token: &str) -> Option<(OwnedUserId, OwnedDeviceId)> {
|
||||
pub async fn find_from_token(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> Option<(OwnedUserId, OwnedDeviceId, AccessTokenStatus)> {
|
||||
let user = self
|
||||
.db
|
||||
.token_userdeviceid
|
||||
@@ -357,11 +366,11 @@ impl Service {
|
||||
.ok();
|
||||
|
||||
// Check if the token has expired
|
||||
if let Some(user) = &user {
|
||||
if let Some((user_id, device_id)) = user {
|
||||
if let Some(expires) = self
|
||||
.db
|
||||
.userdeviceid_tokenexpires
|
||||
.qry(user)
|
||||
.qry(&(&user_id, &device_id))
|
||||
.await
|
||||
.deserialized::<u64>()
|
||||
.ok()
|
||||
@@ -372,12 +381,14 @@ impl Service {
|
||||
.expect("expiry time should not overflow SystemTime");
|
||||
|
||||
if SystemTime::now() > expires_at {
|
||||
return None;
|
||||
return Some((user_id, device_id, AccessTokenStatus::Expired));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user
|
||||
Some((user_id, device_id, AccessTokenStatus::Valid))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all users on this homeserver.
|
||||
|
||||
Reference in New Issue
Block a user