From 72b51e00822cb00ff01401269635675655e8a4dd Mon Sep 17 00:00:00 2001 From: Stefan Melmuk Date: Wed, 25 Dec 2024 05:08:42 +0100 Subject: [PATCH] add auth_request_id newtype --- src/api/core/accounts.rs | 27 +++++++++++++++------------ src/api/identity.rs | 7 +++---- src/db/models/auth_request.rs | 28 ++++++++++++++++++++++++---- src/db/models/mod.rs | 2 +- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs index 436bbf2d..1f114497 100644 --- a/src/api/core/accounts.rs +++ b/src/api/core/accounts.rs @@ -1189,16 +1189,17 @@ async fn post_auth_request( }))) } -#[get("/auth-requests/")] -async fn get_auth_request(uuid: &str, headers: Headers, mut conn: DbConn) -> JsonResult { - let Some(auth_request) = AuthRequest::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { +#[get("/auth-requests/")] +async fn get_auth_request(auth_request_id: AuthRequestId, headers: Headers, mut conn: DbConn) -> JsonResult { + let Some(auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &mut conn).await + else { err!("AuthRequest doesn't exist", "Record not found or user uuid does not match") }; let response_date_utc = auth_request.response_date.map(|response_date| format_date(&response_date)); Ok(Json(json!({ - "id": uuid, + "id": &auth_request_id, "publicKey": auth_request.public_key, "requestDeviceType": DeviceType::from_i32(auth_request.device_type).to_string(), "requestIpAddress": auth_request.request_ip, @@ -1221,9 +1222,9 @@ struct AuthResponseRequest { request_approved: bool, } -#[put("/auth-requests/", data = "")] +#[put("/auth-requests/", data = "")] async fn put_auth_request( - uuid: &str, + auth_request_id: AuthRequestId, data: Json, headers: Headers, mut conn: DbConn, @@ -1231,7 +1232,9 @@ async fn put_auth_request( nt: Notify<'_>, ) -> JsonResult { let data = data.into_inner(); - let Some(mut auth_request) = AuthRequest::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { + let Some(mut auth_request) = + AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &mut conn).await + else { err!("AuthRequest doesn't exist", "Record not found or user uuid does not match") }; @@ -1258,7 +1261,7 @@ async fn put_auth_request( } Ok(Json(json!({ - "id": uuid, + "id": &auth_request_id, "publicKey": auth_request.public_key, "requestDeviceType": DeviceType::from_i32(auth_request.device_type).to_string(), "requestIpAddress": auth_request.request_ip, @@ -1272,14 +1275,14 @@ async fn put_auth_request( }))) } -#[get("/auth-requests//response?")] +#[get("/auth-requests//response?")] async fn get_auth_request_response( - uuid: &str, + auth_request_id: AuthRequestId, code: &str, client_headers: ClientHeaders, mut conn: DbConn, ) -> JsonResult { - let Some(auth_request) = AuthRequest::find_by_uuid(uuid, &mut conn).await else { + let Some(auth_request) = AuthRequest::find_by_uuid(&auth_request_id, &mut conn).await else { err!("AuthRequest doesn't exist", "User not found") }; @@ -1293,7 +1296,7 @@ async fn get_auth_request_response( let response_date_utc = auth_request.response_date.map(|response_date| format_date(&response_date)); Ok(Json(json!({ - "id": uuid, + "id": &auth_request_id, "publicKey": auth_request.public_key, "requestDeviceType": DeviceType::from_i32(auth_request.device_type).to_string(), "requestIpAddress": auth_request.request_ip, diff --git a/src/api/identity.rs b/src/api/identity.rs index 2f02f481..38cdfce5 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -178,9 +178,8 @@ async fn _password_login( let password = data.password.as_ref().unwrap(); // If we get an auth request, we don't check the user's password, but the access code of the auth request - if let Some(ref auth_request_uuid) = data.auth_request { - let Some(auth_request) = AuthRequest::find_by_uuid_and_user(auth_request_uuid.as_str(), &user.uuid, conn).await - else { + if let Some(ref auth_request_id) = data.auth_request { + let Some(auth_request) = AuthRequest::find_by_uuid_and_user(auth_request_id, &user.uuid, conn).await else { err!( "Auth request not found. Try again.", format!("IP: {}. Username: {}.", ip.ip, username), @@ -770,7 +769,7 @@ struct ConnectData { #[field(name = uncased("twofactorremember"))] two_factor_remember: Option, #[field(name = uncased("authrequest"))] - auth_request: Option, + auth_request: Option, } fn _check_is_some(value: &Option, msg: &str) -> EmptyResult { diff --git a/src/db/models/auth_request.rs b/src/db/models/auth_request.rs index eab91d87..3417d07e 100644 --- a/src/db/models/auth_request.rs +++ b/src/db/models/auth_request.rs @@ -1,6 +1,8 @@ use super::{DeviceId, OrganizationId, UserId}; use crate::crypto::ct_eq; use chrono::{NaiveDateTime, Utc}; +use derive_more::{AsRef, Deref, Display, From}; +use rocket::request::FromParam; db_object! { #[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)] @@ -8,7 +10,7 @@ db_object! { #[diesel(treat_none_as_null = true)] #[diesel(primary_key(uuid))] pub struct AuthRequest { - pub uuid: String, + pub uuid: AuthRequestId, pub user_uuid: UserId, pub organization_uuid: Option, @@ -44,7 +46,7 @@ impl AuthRequest { let now = Utc::now().naive_utc(); Self { - uuid: crate::util::get_uuid(), + uuid: AuthRequestId(crate::util::get_uuid()), user_uuid, organization_uuid: None, @@ -102,7 +104,7 @@ impl AuthRequest { } } - pub async fn find_by_uuid(uuid: &str, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &mut DbConn) -> Option { db_run! {conn: { auth_requests::table .filter(auth_requests::uuid.eq(uuid)) @@ -112,7 +114,7 @@ impl AuthRequest { }} } - pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: &UserId, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &mut DbConn) -> Option { db_run! {conn: { auth_requests::table .filter(auth_requests::uuid.eq(uuid)) @@ -158,3 +160,21 @@ impl AuthRequest { } } } + +#[derive( + Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, +)] +pub struct AuthRequestId(String); + +impl<'r> FromParam<'r> for AuthRequestId { + type Error = (); + + #[inline(always)] + fn from_param(param: &'r str) -> Result { + if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { + Ok(Self(param.to_string())) + } else { + Err(()) + } + } +} diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index b6691e7f..d503354c 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -17,7 +17,7 @@ mod two_factor_incomplete; mod user; pub use self::attachment::{Attachment, AttachmentId}; -pub use self::auth_request::AuthRequest; +pub use self::auth_request::{AuthRequest, AuthRequestId}; pub use self::cipher::{Cipher, CipherId, RepromptType}; pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser}; pub use self::device::{Device, DeviceId, DeviceType};