From eff2ea0d3f34928f9c7bf86ced64c4b4c620f51d Mon Sep 17 00:00:00 2001 From: Stefan Melmuk Date: Mon, 23 Dec 2024 08:22:15 +0100 Subject: [PATCH] add send_id newtype --- src/api/core/accounts.rs | 4 +-- src/api/core/sends.rs | 54 ++++++++++++++++++++++------------------ src/api/notifications.rs | 2 +- src/auth.rs | 6 ++--- src/crypto.rs | 4 +-- src/db/models/mod.rs | 2 +- src/db/models/send.rs | 43 ++++++++++++++++++++++++++++---- 7 files changed, 77 insertions(+), 38 deletions(-) diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs index ae199abe..2c97c6dc 100644 --- a/src/api/core/accounts.rs +++ b/src/api/core/accounts.rs @@ -525,8 +525,8 @@ fn validate_keydata( } // Check that we're correctly rotating all the user's sends - let existing_send_ids = existing_sends.iter().map(|s| s.uuid.as_str()).collect::>(); - let provided_send_ids = data.sends.iter().filter_map(|s| s.id.as_deref()).collect::>(); + let existing_send_ids = existing_sends.iter().map(|s| &s.uuid).collect::>(); + let provided_send_ids = data.sends.iter().filter_map(|s| s.id.as_ref()).collect::>(); if !provided_send_ids.is_superset(&existing_send_ids) { err!("All existing sends must be included in the rotation") } diff --git a/src/api/core/sends.rs b/src/api/core/sends.rs index ee3c4d9c..9a7700f0 100644 --- a/src/api/core/sends.rs +++ b/src/api/core/sends.rs @@ -67,7 +67,7 @@ pub struct SendData { file_length: Option, // Used for key rotations - pub id: Option, + pub id: Option, } /// Enforces the `Disable Send` policy. A non-owner/admin user belonging to @@ -158,8 +158,8 @@ async fn get_sends(headers: Headers, mut conn: DbConn) -> Json { } #[get("/sends/")] -async fn get_send(uuid: &str, headers: Headers, mut conn: DbConn) -> JsonResult { - match Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await { +async fn get_send(uuid: SendId, headers: Headers, mut conn: DbConn) -> JsonResult { + match Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await { Some(send) => Ok(Json(send.to_json())), None => err!("Send not found", "Invalid uuid or does not belong to user"), } @@ -249,7 +249,7 @@ async fn post_send_file(data: Form>, headers: Headers, mut conn: err!("Send content is not a file"); } - let file_id = crate::crypto::generate_send_id(); + let file_id = crate::crypto::generate_send_file_id(); let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid); let file_path = folder_path.join(&file_id); tokio::fs::create_dir_all(&folder_path).await?; @@ -324,7 +324,7 @@ async fn post_send_file_v2(data: Json, headers: Headers, mut conn: DbC let mut send = create_send(data, headers.user.uuid)?; - let file_id = crate::crypto::generate_send_id(); + let file_id = crate::crypto::generate_send_file_id(); let mut data_value: Value = serde_json::from_str(&send.data)?; if let Some(o) = data_value.as_object_mut() { @@ -352,9 +352,9 @@ pub struct SendFileData { } // https://github.com/bitwarden/server/blob/66f95d1c443490b653e5a15d32977e2f5a3f9e32/src/Api/Tools/Controllers/SendsController.cs#L250 -#[post("/sends//file/", format = "multipart/form-data", data = "")] +#[post("/sends//file/", format = "multipart/form-data", data = "")] async fn post_send_file_v2_data( - send_uuid: &str, + uuid: SendId, file_id: &str, data: Form>, headers: Headers, @@ -365,7 +365,7 @@ async fn post_send_file_v2_data( let mut data = data.into_inner(); - let Some(send) = Send::find_by_uuid_and_user(send_uuid, &headers.user.uuid, &mut conn).await else { + let Some(send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else { err!("Send not found. Unable to save the file.", "Invalid uuid or does not belong to user.") }; @@ -402,7 +402,7 @@ async fn post_send_file_v2_data( err!("Send file size does not match.", format!("Expected a file size of {} got {size}", send_data.size)); } - let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(send_uuid); + let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(uuid); let file_path = folder_path.join(file_id); // Check if the file already exists, if that is the case do not overwrite it @@ -493,16 +493,16 @@ async fn post_access( Ok(Json(send.to_json_access(&mut conn).await)) } -#[post("/sends//access/file/", data = "")] +#[post("/sends//access/file/", data = "")] async fn post_access_file( - send_id: &str, + uuid: SendId, file_id: &str, data: Json, host: Host, mut conn: DbConn, nt: Notify<'_>, ) -> JsonResult { - let Some(mut send) = Send::find_by_uuid(send_id, &mut conn).await else { + let Some(mut send) = Send::find_by_uuid(&uuid, &mut conn).await else { err_code!(SEND_INACCESSIBLE_MSG, 404) }; @@ -547,33 +547,39 @@ async fn post_access_file( ) .await; - let token_claims = crate::auth::generate_send_claims(send_id, file_id); + let token_claims = crate::auth::generate_send_claims(&uuid, file_id); let token = crate::auth::encode_jwt(&token_claims); Ok(Json(json!({ "object": "send-fileDownload", "id": file_id, - "url": format!("{}/api/sends/{}/{}?t={}", &host.host, send_id, file_id, token) + "url": format!("{}/api/sends/{}/{}?t={}", &host.host, uuid, file_id, token) }))) } -#[get("/sends//?")] -async fn download_send(send_id: SafeString, file_id: SafeString, t: &str) -> Option { +#[get("/sends//?")] +async fn download_send(uuid: SendId, file_id: SafeString, t: &str) -> Option { if let Ok(claims) = crate::auth::decode_send(t) { - if claims.sub == format!("{send_id}/{file_id}") { - return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok(); + if claims.sub == format!("{uuid}/{file_id}") { + return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(uuid).join(file_id)).await.ok(); } } None } #[put("/sends/", data = "")] -async fn put_send(uuid: &str, data: Json, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { +async fn put_send( + uuid: SendId, + data: Json, + headers: Headers, + mut conn: DbConn, + nt: Notify<'_>, +) -> JsonResult { enforce_disable_send_policy(&headers, &mut conn).await?; let data: SendData = data.into_inner(); enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?; - let Some(mut send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { + let Some(mut send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else { err!("Send not found", "Send uuid is invalid or does not belong to user") }; @@ -641,8 +647,8 @@ pub async fn update_send_from_data( } #[delete("/sends/")] -async fn delete_send(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { - let Some(send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { +async fn delete_send(uuid: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { + let Some(send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else { err!("Send not found", "Invalid send uuid, or does not belong to user") }; @@ -660,10 +666,10 @@ async fn delete_send(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify< } #[put("/sends//remove-password")] -async fn put_remove_password(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { +async fn put_remove_password(uuid: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { enforce_disable_send_policy(&headers, &mut conn).await?; - let Some(mut send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { + let Some(mut send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else { err!("Send not found", "Invalid send uuid, or does not belong to user") }; diff --git a/src/api/notifications.rs b/src/api/notifications.rs index 476f5341..f9b45f00 100644 --- a/src/api/notifications.rs +++ b/src/api/notifications.rs @@ -474,7 +474,7 @@ impl WebSocketUsers { let data = create_update( vec![ - ("Id".into(), send.uuid.clone().into()), + ("Id".into(), send.uuid.to_string().into()), ("UserId".into(), user_uuid), ("RevisionDate".into(), serialize_date(send.revision_date)), ], diff --git a/src/auth.rs b/src/auth.rs index 2fcd9740..74c87a54 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -15,7 +15,7 @@ use std::{ }; use crate::db::models::{ - AttachmentId, CipherId, CollectionId, DeviceId, MembershipId, OrgApiKeyId, OrganizationId, UserId, + AttachmentId, CipherId, CollectionId, DeviceId, MembershipId, OrgApiKeyId, OrganizationId, SendId, UserId, }; use crate::{error::Error, CONFIG}; @@ -358,13 +358,13 @@ pub fn generate_admin_claims() -> BasicJwtClaims { } } -pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims { +pub fn generate_send_claims(uuid: &SendId, file_id: &str) -> BasicJwtClaims { let time_now = Utc::now(); BasicJwtClaims { nbf: time_now.timestamp(), exp: (time_now + TimeDelta::try_minutes(2).unwrap()).timestamp(), iss: JWT_SEND_ISSUER.to_string(), - sub: format!("{send_id}/{file_id}"), + sub: format!("{uuid}/{file_id}"), } } diff --git a/src/crypto.rs b/src/crypto.rs index c9db1a4b..eff1785f 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -84,8 +84,8 @@ pub fn generate_id() -> String { encode_random_bytes::(HEXLOWER) } -pub fn generate_send_id() -> String { - // Send IDs are globally scoped, so make them longer to avoid collisions. +pub fn generate_send_file_id() -> String { + // Send File IDs are globally scoped, so make them longer to avoid collisions. generate_id::<32>() // 256 bits } diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index 2817fb56..b03b24d6 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -31,7 +31,7 @@ pub use self::organization::{ Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, Organization, OrganizationApiKey, OrganizationId, }; -pub use self::send::{Send, SendType}; +pub use self::send::{id::SendId, Send, SendType}; pub use self::two_factor::{TwoFactor, TwoFactorType}; pub use self::two_factor_duo_context::TwoFactorDuoContext; pub use self::two_factor_incomplete::TwoFactorIncomplete; diff --git a/src/db/models/send.rs b/src/db/models/send.rs index 8cb27367..3ea8b660 100644 --- a/src/db/models/send.rs +++ b/src/db/models/send.rs @@ -4,6 +4,7 @@ use serde_json::Value; use crate::util::LowerCase; use super::{OrganizationId, User, UserId}; +use id::SendId; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -11,7 +12,7 @@ db_object! { #[diesel(treat_none_as_null = true)] #[diesel(primary_key(uuid))] pub struct Send { - pub uuid: String, + pub uuid: SendId, pub user_uuid: Option, pub organization_uuid: Option, @@ -50,7 +51,7 @@ impl Send { let now = Utc::now().naive_utc(); Self { - uuid: crate::util::get_uuid(), + uuid: SendId::from(crate::util::get_uuid()), user_uuid: None, organization_uuid: None, @@ -272,14 +273,14 @@ impl Send { }; let uuid = match Uuid::from_slice(&uuid_vec) { - Ok(u) => u.to_string(), + Ok(u) => SendId::from(u.to_string()), Err(_) => return None, }; Self::find_by_uuid(&uuid, conn).await } - pub async fn find_by_uuid(uuid: &str, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid(uuid: &SendId, conn: &mut DbConn) -> Option { db_run! {conn: { sends::table .filter(sends::uuid.eq(uuid)) @@ -289,7 +290,7 @@ impl Send { }} } - pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: &UserId, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &mut DbConn) -> Option { db_run! {conn: { sends::table .filter(sends::uuid.eq(uuid)) @@ -348,3 +349,35 @@ impl Send { }} } } + +// separate namespace to avoid name collision with std::marker::Send +pub mod id { + use derive_more::{AsRef, Deref, Display, From}; + use rocket::request::FromParam; + use std::marker::Send; + use std::path::Path; + #[derive( + Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + )] + pub struct SendId(String); + + impl AsRef for SendId { + #[inline] + fn as_ref(&self) -> &Path { + Path::new(&self.0) + } + } + + impl<'r> FromParam<'r> for SendId { + type Error = (); + + #[inline(always)] + fn from_param(param: &'r str) -> Result { + if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { + Ok(Self(param.to_string())) + } else { + Err(()) + } + } + } +}