add send_id newtype

Stefan Melmuk 1 week ago
parent e361d2445b
commit eff2ea0d3f
No known key found for this signature in database
GPG Key ID: 817020C608FE9C09

@ -525,8 +525,8 @@ fn validate_keydata(
} }
// Check that we're correctly rotating all the user's sends // Check that we're correctly rotating all the user's sends
let existing_send_ids = existing_sends.iter().map(|s| s.uuid.as_str()).collect::<HashSet<_>>(); let existing_send_ids = existing_sends.iter().map(|s| &s.uuid).collect::<HashSet<&SendId>>();
let provided_send_ids = data.sends.iter().filter_map(|s| s.id.as_deref()).collect::<HashSet<_>>(); let provided_send_ids = data.sends.iter().filter_map(|s| s.id.as_ref()).collect::<HashSet<&SendId>>();
if !provided_send_ids.is_superset(&existing_send_ids) { if !provided_send_ids.is_superset(&existing_send_ids) {
err!("All existing sends must be included in the rotation") err!("All existing sends must be included in the rotation")
} }

@ -67,7 +67,7 @@ pub struct SendData {
file_length: Option<NumberOrString>, file_length: Option<NumberOrString>,
// Used for key rotations // Used for key rotations
pub id: Option<String>, pub id: Option<SendId>,
} }
/// Enforces the `Disable Send` policy. A non-owner/admin user belonging to /// Enforces the `Disable Send` policy. A non-owner/admin user belonging to
@ -158,8 +158,8 @@ async fn get_sends(headers: Headers, mut conn: DbConn) -> Json<Value> {
} }
#[get("/sends/<uuid>")] #[get("/sends/<uuid>")]
async fn get_send(uuid: &str, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_send(uuid: SendId, headers: Headers, mut conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await { match Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await {
Some(send) => Ok(Json(send.to_json())), Some(send) => Ok(Json(send.to_json())),
None => err!("Send not found", "Invalid uuid or does not belong to user"), None => err!("Send not found", "Invalid uuid or does not belong to user"),
} }
@ -249,7 +249,7 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
err!("Send content is not a file"); err!("Send content is not a file");
} }
let file_id = crate::crypto::generate_send_id(); let file_id = crate::crypto::generate_send_file_id();
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid); let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
let file_path = folder_path.join(&file_id); let file_path = folder_path.join(&file_id);
tokio::fs::create_dir_all(&folder_path).await?; tokio::fs::create_dir_all(&folder_path).await?;
@ -324,7 +324,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
let mut send = create_send(data, headers.user.uuid)?; let mut send = create_send(data, headers.user.uuid)?;
let file_id = crate::crypto::generate_send_id(); let file_id = crate::crypto::generate_send_file_id();
let mut data_value: Value = serde_json::from_str(&send.data)?; let mut data_value: Value = serde_json::from_str(&send.data)?;
if let Some(o) = data_value.as_object_mut() { if let Some(o) = data_value.as_object_mut() {
@ -352,9 +352,9 @@ pub struct SendFileData {
} }
// https://github.com/bitwarden/server/blob/66f95d1c443490b653e5a15d32977e2f5a3f9e32/src/Api/Tools/Controllers/SendsController.cs#L250 // https://github.com/bitwarden/server/blob/66f95d1c443490b653e5a15d32977e2f5a3f9e32/src/Api/Tools/Controllers/SendsController.cs#L250
#[post("/sends/<send_uuid>/file/<file_id>", format = "multipart/form-data", data = "<data>")] #[post("/sends/<uuid>/file/<file_id>", format = "multipart/form-data", data = "<data>")]
async fn post_send_file_v2_data( async fn post_send_file_v2_data(
send_uuid: &str, uuid: SendId,
file_id: &str, file_id: &str,
data: Form<UploadDataV2<'_>>, data: Form<UploadDataV2<'_>>,
headers: Headers, headers: Headers,
@ -365,7 +365,7 @@ async fn post_send_file_v2_data(
let mut data = data.into_inner(); let mut data = data.into_inner();
let Some(send) = Send::find_by_uuid_and_user(send_uuid, &headers.user.uuid, &mut conn).await else { let Some(send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found. Unable to save the file.", "Invalid uuid or does not belong to user.") err!("Send not found. Unable to save the file.", "Invalid uuid or does not belong to user.")
}; };
@ -402,7 +402,7 @@ async fn post_send_file_v2_data(
err!("Send file size does not match.", format!("Expected a file size of {} got {size}", send_data.size)); err!("Send file size does not match.", format!("Expected a file size of {} got {size}", send_data.size));
} }
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(send_uuid); let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(uuid);
let file_path = folder_path.join(file_id); let file_path = folder_path.join(file_id);
// Check if the file already exists, if that is the case do not overwrite it // Check if the file already exists, if that is the case do not overwrite it
@ -493,16 +493,16 @@ async fn post_access(
Ok(Json(send.to_json_access(&mut conn).await)) Ok(Json(send.to_json_access(&mut conn).await))
} }
#[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")] #[post("/sends/<uuid>/access/file/<file_id>", data = "<data>")]
async fn post_access_file( async fn post_access_file(
send_id: &str, uuid: SendId,
file_id: &str, file_id: &str,
data: Json<SendAccessData>, data: Json<SendAccessData>,
host: Host, host: Host,
mut conn: DbConn, mut conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let Some(mut send) = Send::find_by_uuid(send_id, &mut conn).await else { let Some(mut send) = Send::find_by_uuid(&uuid, &mut conn).await else {
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
}; };
@ -547,33 +547,39 @@ async fn post_access_file(
) )
.await; .await;
let token_claims = crate::auth::generate_send_claims(send_id, file_id); let token_claims = crate::auth::generate_send_claims(&uuid, file_id);
let token = crate::auth::encode_jwt(&token_claims); let token = crate::auth::encode_jwt(&token_claims);
Ok(Json(json!({ Ok(Json(json!({
"object": "send-fileDownload", "object": "send-fileDownload",
"id": file_id, "id": file_id,
"url": format!("{}/api/sends/{}/{}?t={}", &host.host, send_id, file_id, token) "url": format!("{}/api/sends/{}/{}?t={}", &host.host, uuid, file_id, token)
}))) })))
} }
#[get("/sends/<send_id>/<file_id>?<t>")] #[get("/sends/<uuid>/<file_id>?<t>")]
async fn download_send(send_id: SafeString, file_id: SafeString, t: &str) -> Option<NamedFile> { async fn download_send(uuid: SendId, file_id: SafeString, t: &str) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(t) { if let Ok(claims) = crate::auth::decode_send(t) {
if claims.sub == format!("{send_id}/{file_id}") { if claims.sub == format!("{uuid}/{file_id}") {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok(); return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(uuid).join(file_id)).await.ok();
} }
} }
None None
} }
#[put("/sends/<uuid>", data = "<data>")] #[put("/sends/<uuid>", data = "<data>")]
async fn put_send(uuid: &str, data: Json<SendData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn put_send(
uuid: SendId,
data: Json<SendData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &mut conn).await?;
let data: SendData = data.into_inner(); let data: SendData = data.into_inner();
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?; enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?;
let Some(mut send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { let Some(mut send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found", "Send uuid is invalid or does not belong to user") err!("Send not found", "Send uuid is invalid or does not belong to user")
}; };
@ -641,8 +647,8 @@ pub async fn update_send_from_data(
} }
#[delete("/sends/<uuid>")] #[delete("/sends/<uuid>")]
async fn delete_send(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_send(uuid: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { let Some(send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user") err!("Send not found", "Invalid send uuid, or does not belong to user")
}; };
@ -660,10 +666,10 @@ async fn delete_send(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<
} }
#[put("/sends/<uuid>/remove-password")] #[put("/sends/<uuid>/remove-password")]
async fn put_remove_password(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn put_remove_password(uuid: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &mut conn).await?;
let Some(mut send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else { let Some(mut send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user") err!("Send not found", "Invalid send uuid, or does not belong to user")
}; };

@ -474,7 +474,7 @@ impl WebSocketUsers {
let data = create_update( let data = create_update(
vec![ vec![
("Id".into(), send.uuid.clone().into()), ("Id".into(), send.uuid.to_string().into()),
("UserId".into(), user_uuid), ("UserId".into(), user_uuid),
("RevisionDate".into(), serialize_date(send.revision_date)), ("RevisionDate".into(), serialize_date(send.revision_date)),
], ],

@ -15,7 +15,7 @@ use std::{
}; };
use crate::db::models::{ use crate::db::models::{
AttachmentId, CipherId, CollectionId, DeviceId, MembershipId, OrgApiKeyId, OrganizationId, UserId, AttachmentId, CipherId, CollectionId, DeviceId, MembershipId, OrgApiKeyId, OrganizationId, SendId, UserId,
}; };
use crate::{error::Error, CONFIG}; use crate::{error::Error, CONFIG};
@ -358,13 +358,13 @@ pub fn generate_admin_claims() -> BasicJwtClaims {
} }
} }
pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims { pub fn generate_send_claims(uuid: &SendId, file_id: &str) -> BasicJwtClaims {
let time_now = Utc::now(); let time_now = Utc::now();
BasicJwtClaims { BasicJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + TimeDelta::try_minutes(2).unwrap()).timestamp(), exp: (time_now + TimeDelta::try_minutes(2).unwrap()).timestamp(),
iss: JWT_SEND_ISSUER.to_string(), iss: JWT_SEND_ISSUER.to_string(),
sub: format!("{send_id}/{file_id}"), sub: format!("{uuid}/{file_id}"),
} }
} }

@ -84,8 +84,8 @@ pub fn generate_id<const N: usize>() -> String {
encode_random_bytes::<N>(HEXLOWER) encode_random_bytes::<N>(HEXLOWER)
} }
pub fn generate_send_id() -> String { pub fn generate_send_file_id() -> String {
// Send IDs are globally scoped, so make them longer to avoid collisions. // Send File IDs are globally scoped, so make them longer to avoid collisions.
generate_id::<32>() // 256 bits generate_id::<32>() // 256 bits
} }

@ -31,7 +31,7 @@ pub use self::organization::{
Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, Organization, OrganizationApiKey, Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, Organization, OrganizationApiKey,
OrganizationId, OrganizationId,
}; };
pub use self::send::{Send, SendType}; pub use self::send::{id::SendId, Send, SendType};
pub use self::two_factor::{TwoFactor, TwoFactorType}; pub use self::two_factor::{TwoFactor, TwoFactorType};
pub use self::two_factor_duo_context::TwoFactorDuoContext; pub use self::two_factor_duo_context::TwoFactorDuoContext;
pub use self::two_factor_incomplete::TwoFactorIncomplete; pub use self::two_factor_incomplete::TwoFactorIncomplete;

@ -4,6 +4,7 @@ use serde_json::Value;
use crate::util::LowerCase; use crate::util::LowerCase;
use super::{OrganizationId, User, UserId}; use super::{OrganizationId, User, UserId};
use id::SendId;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
@ -11,7 +12,7 @@ db_object! {
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
#[diesel(primary_key(uuid))] #[diesel(primary_key(uuid))]
pub struct Send { pub struct Send {
pub uuid: String, pub uuid: SendId,
pub user_uuid: Option<UserId>, pub user_uuid: Option<UserId>,
pub organization_uuid: Option<OrganizationId>, pub organization_uuid: Option<OrganizationId>,
@ -50,7 +51,7 @@ impl Send {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
Self { Self {
uuid: crate::util::get_uuid(), uuid: SendId::from(crate::util::get_uuid()),
user_uuid: None, user_uuid: None,
organization_uuid: None, organization_uuid: None,
@ -272,14 +273,14 @@ impl Send {
}; };
let uuid = match Uuid::from_slice(&uuid_vec) { let uuid = match Uuid::from_slice(&uuid_vec) {
Ok(u) => u.to_string(), Ok(u) => SendId::from(u.to_string()),
Err(_) => return None, Err(_) => return None,
}; };
Self::find_by_uuid(&uuid, conn).await Self::find_by_uuid(&uuid, conn).await
} }
pub async fn find_by_uuid(uuid: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &SendId, conn: &mut DbConn) -> Option<Self> {
db_run! {conn: { db_run! {conn: {
sends::table sends::table
.filter(sends::uuid.eq(uuid)) .filter(sends::uuid.eq(uuid))
@ -289,7 +290,7 @@ impl Send {
}} }}
} }
pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
db_run! {conn: { db_run! {conn: {
sends::table sends::table
.filter(sends::uuid.eq(uuid)) .filter(sends::uuid.eq(uuid))
@ -348,3 +349,35 @@ impl Send {
}} }}
} }
} }
// separate namespace to avoid name collision with std::marker::Send
pub mod id {
use derive_more::{AsRef, Deref, Display, From};
use rocket::request::FromParam;
use std::marker::Send;
use std::path::Path;
#[derive(
Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize,
)]
pub struct SendId(String);
impl AsRef<Path> for SendId {
#[inline]
fn as_ref(&self) -> &Path {
Path::new(&self.0)
}
}
impl<'r> FromParam<'r> for SendId {
type Error = ();
#[inline(always)]
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
Ok(Self(param.to_string()))
} else {
Err(())
}
}
}
}

Loading…
Cancel
Save