From 7f2db91f2332c0a393e98b7da1ef6a9b930c509d Mon Sep 17 00:00:00 2001 From: Stefan Melmuk Date: Sat, 21 Dec 2024 12:58:57 +0100 Subject: [PATCH] introduce collection_id newtype --- src/api/core/ciphers.rs | 31 ++++----- src/api/core/organizations.rs | 122 +++++++++++++++++++--------------- src/api/notifications.rs | 6 +- src/auth.rs | 12 ++-- src/db/models/cipher.rs | 23 ++++--- src/db/models/collection.rs | 98 ++++++++++++++++++++++----- src/db/models/group.rs | 12 ++-- src/db/models/mod.rs | 2 +- src/db/models/organization.rs | 12 ++-- 9 files changed, 201 insertions(+), 117 deletions(-) diff --git a/src/api/core/ciphers.rs b/src/api/core/ciphers.rs index 21dbf08c..3fbddc45 100644 --- a/src/api/core/ciphers.rs +++ b/src/api/core/ciphers.rs @@ -368,7 +368,7 @@ pub async fn update_cipher_from_data( cipher: &mut Cipher, data: CipherData, headers: &Headers, - shared_to_collections: Option>, + shared_to_collections: Option>, conn: &mut DbConn, nt: &Notify<'_>, ut: UpdateType, @@ -710,7 +710,7 @@ async fn put_cipher_partial( #[serde(rename_all = "camelCase")] struct CollectionsAdminData { #[serde(alias = "CollectionIds")] - collection_ids: Vec, + collection_ids: Vec, } #[put("/ciphers//collections_v2", data = "")] @@ -769,9 +769,9 @@ async fn post_collections_update( err!("Cipher is not write accessible") } - let posted_collections = HashSet::::from_iter(data.collection_ids); + let posted_collections = HashSet::::from_iter(data.collection_ids); let current_collections = - HashSet::::from_iter(cipher.get_collections(headers.user.uuid.to_string(), &mut conn).await); + HashSet::::from_iter(cipher.get_collections(headers.user.uuid.to_string(), &mut conn).await); for collection in posted_collections.symmetric_difference(¤t_collections) { match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &mut conn).await @@ -846,9 +846,10 @@ async fn post_collections_admin( err!("Cipher is not write accessible") } - let posted_collections = HashSet::::from_iter(data.collection_ids); - let current_collections = - HashSet::::from_iter(cipher.get_admin_collections(headers.user.uuid.to_string(), &mut conn).await); + let posted_collections = HashSet::::from_iter(data.collection_ids); + let current_collections = HashSet::::from_iter( + cipher.get_admin_collections(headers.user.uuid.to_string(), &mut conn).await, + ); for collection in posted_collections.symmetric_difference(¤t_collections) { match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &mut conn).await @@ -900,7 +901,7 @@ struct ShareCipherData { #[serde(alias = "Cipher")] cipher: CipherData, #[serde(alias = "CollectionIds")] - collection_ids: Vec, + collection_ids: Vec, } #[post("/ciphers//share", data = "")] @@ -933,7 +934,7 @@ async fn put_cipher_share( #[serde(rename_all = "camelCase")] struct ShareSelectedCipherData { ciphers: Vec, - collection_ids: Vec, + collection_ids: Vec, } #[put("/ciphers/share", data = "")] @@ -1834,10 +1835,10 @@ pub struct CipherSyncData { pub cipher_attachments: HashMap>, pub cipher_folders: HashMap, pub cipher_favorites: HashSet, - pub cipher_collections: HashMap>, + pub cipher_collections: HashMap>, pub members: HashMap, - pub user_collections: HashMap, - pub user_collections_groups: HashMap, + pub user_collections: HashMap, + pub user_collections_groups: HashMap, pub user_group_full_access_for_organizations: HashSet, } @@ -1878,7 +1879,7 @@ impl CipherSyncData { // Generate a HashMap with the Cipher UUID as key and one or more Collection UUID's let user_cipher_collections = Cipher::get_collections_with_cipher_by_user(user_uuid.to_string(), conn).await; - let mut cipher_collections: HashMap> = + let mut cipher_collections: HashMap> = HashMap::with_capacity(user_cipher_collections.len()); for (cipher, collection) in user_cipher_collections { cipher_collections.entry(cipher).or_default().push(collection); @@ -1889,14 +1890,14 @@ impl CipherSyncData { Membership::find_by_user(user_uuid, conn).await.into_iter().map(|m| (m.org_uuid.clone(), m)).collect(); // Generate a HashMap with the User_Collections UUID as key and the CollectionUser record - let user_collections: HashMap = CollectionUser::find_by_user(user_uuid, conn) + let user_collections: HashMap = CollectionUser::find_by_user(user_uuid, conn) .await .into_iter() .map(|uc| (uc.collection_uuid.clone(), uc)) .collect(); // Generate a HashMap with the collections_uuid as key and the CollectionGroup record - let user_collections_groups: HashMap = if CONFIG.org_groups_enabled() { + let user_collections_groups: HashMap = if CONFIG.org_groups_enabled() { CollectionGroup::find_by_user(user_uuid, conn) .await .into_iter() diff --git a/src/api/core/organizations.rs b/src/api/core/organizations.rs index d9a8e644..b7ac7dcb 100644 --- a/src/api/core/organizations.rs +++ b/src/api/core/organizations.rs @@ -126,7 +126,7 @@ struct NewCollectionData { name: String, groups: Vec, users: Vec, - id: Option, + id: Option, external_id: Option, } @@ -340,7 +340,7 @@ async fn get_org_collections_details( }; // get all collection memberships for the current organization - let coll_users = CollectionUser::find_by_organization(&org_id, &mut conn).await; + let col_users = CollectionUser::find_by_organization(&org_id, &mut conn).await; // check if current user has full access to the organization (either directly or via any group) let has_full_access_to_org = member.access_all @@ -355,7 +355,7 @@ async fn get_org_collections_details( && GroupUser::has_access_to_collection_by_member(&col.uuid, &member.uuid, &mut conn).await); // get the users assigned directly to the given collection - let users: Vec = coll_users + let users: Vec = col_users .iter() .filter(|collection_user| collection_user.collection_uuid == col.uuid) .map(|collection_user| UserSelection::to_collection_user_details_read_only(collection_user).to_json()) @@ -450,7 +450,7 @@ async fn post_organization_collections( #[put("/organizations//collections/", data = "")] async fn put_organization_collection_update( org_id: OrganizationId, - col_id: &str, + col_id: CollectionId, headers: ManagerHeaders, data: Json, conn: DbConn, @@ -461,7 +461,7 @@ async fn put_organization_collection_update( #[post("/organizations//collections/", data = "")] async fn post_organization_collection_update( org_id: OrganizationId, - col_id: &str, + col_id: CollectionId, headers: ManagerHeaders, data: Json, mut conn: DbConn, @@ -472,7 +472,7 @@ async fn post_organization_collection_update( err!("Can't find organization details") }; - let Some(mut collection) = Collection::find_by_uuid_and_org(col_id, &org_id, &mut conn).await else { + let Some(mut collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &mut conn).await else { err!("Collection not found") }; @@ -495,15 +495,13 @@ async fn post_organization_collection_update( ) .await; - CollectionGroup::delete_all_by_collection(col_id, &mut conn).await?; + CollectionGroup::delete_all_by_collection(&col_id, &mut conn).await?; for group in data.groups { - CollectionGroup::new(String::from(col_id), group.id, group.read_only, group.hide_passwords) - .save(&mut conn) - .await?; + CollectionGroup::new(col_id.clone(), group.id, group.read_only, group.hide_passwords).save(&mut conn).await?; } - CollectionUser::delete_all_by_collection(col_id, &mut conn).await?; + CollectionUser::delete_all_by_collection(&col_id, &mut conn).await?; for user in data.users { let Some(member) = Membership::find_by_uuid_and_org(&user.id, &org_id, &mut conn).await else { @@ -514,7 +512,7 @@ async fn post_organization_collection_update( continue; } - CollectionUser::save(&member.user_uuid, col_id, user.read_only, user.hide_passwords, &mut conn).await?; + CollectionUser::save(&member.user_uuid, &col_id, user.read_only, user.hide_passwords, &mut conn).await?; } Ok(Json(collection.to_json_details(&headers.user.uuid, None, &mut conn).await)) @@ -523,12 +521,12 @@ async fn post_organization_collection_update( #[delete("/organizations//collections//user/")] async fn delete_organization_collection_user( org_id: OrganizationId, - col_id: &str, + col_id: CollectionId, member_id: MembershipId, _headers: AdminHeaders, mut conn: DbConn, ) -> EmptyResult { - let Some(collection) = Collection::find_by_uuid_and_org(col_id, &org_id, &mut conn).await else { + let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &mut conn).await else { err!("Collection not found", "Collection does not exist or does not belong to this organization") }; @@ -546,7 +544,7 @@ async fn delete_organization_collection_user( #[post("/organizations//collections//delete-user/")] async fn post_organization_collection_delete_user( org_id: OrganizationId, - col_id: &str, + col_id: CollectionId, member_id: MembershipId, headers: AdminHeaders, conn: DbConn, @@ -556,7 +554,7 @@ async fn post_organization_collection_delete_user( async fn _delete_organization_collection( org_id: &OrganizationId, - col_id: &str, + col_id: &CollectionId, headers: &ManagerHeaders, conn: &mut DbConn, ) -> EmptyResult { @@ -579,11 +577,11 @@ async fn _delete_organization_collection( #[delete("/organizations//collections/")] async fn delete_organization_collection( org_id: OrganizationId, - col_id: &str, + col_id: CollectionId, headers: ManagerHeaders, mut conn: DbConn, ) -> EmptyResult { - _delete_organization_collection(&org_id, col_id, &headers, &mut conn).await + _delete_organization_collection(&org_id, &col_id, &headers, &mut conn).await } #[derive(Deserialize, Debug)] @@ -598,17 +596,17 @@ struct DeleteCollectionData { #[post("/organizations//collections//delete")] async fn post_organization_collection_delete( org_id: OrganizationId, - col_id: &str, + col_id: CollectionId, headers: ManagerHeaders, mut conn: DbConn, ) -> EmptyResult { - _delete_organization_collection(&org_id, col_id, &headers, &mut conn).await + _delete_organization_collection(&org_id, &col_id, &headers, &mut conn).await } #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] struct BulkCollectionIds { - ids: Vec, + ids: Vec, } #[delete("/organizations//collections", data = "")] @@ -630,14 +628,14 @@ async fn bulk_delete_organization_collections( Ok(()) } -#[get("/organizations//collections//details")] +#[get("/organizations//collections//details")] async fn get_org_collection_detail( org_id: OrganizationId, - coll_id: &str, + col_id: CollectionId, headers: ManagerHeaders, mut conn: DbConn, ) -> JsonResult { - match Collection::find_by_uuid_and_user(coll_id, headers.user.uuid.clone(), &mut conn).await { + match Collection::find_by_uuid_and_user(&col_id, headers.user.uuid.clone(), &mut conn).await { None => err!("Collection not found"), Some(collection) => { if collection.org_uuid != org_id { @@ -684,15 +682,15 @@ async fn get_org_collection_detail( } } -#[get("/organizations//collections//users")] +#[get("/organizations//collections//users")] async fn get_collection_users( org_id: OrganizationId, - coll_id: &str, + col_id: CollectionId, _headers: ManagerHeaders, mut conn: DbConn, ) -> JsonResult { // Get org and collection, check that collection is from org - let Some(collection) = Collection::find_by_uuid_and_org(coll_id, &org_id, &mut conn).await else { + let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &mut conn).await else { err!("Collection not found in Organization") }; @@ -709,21 +707,21 @@ async fn get_collection_users( Ok(Json(json!(user_list))) } -#[put("/organizations//collections//users", data = "")] +#[put("/organizations//collections//users", data = "")] async fn put_collection_users( org_id: OrganizationId, - coll_id: &str, - data: Json>, + col_id: CollectionId, + data: Json>, _headers: ManagerHeaders, mut conn: DbConn, ) -> EmptyResult { // Get org and collection, check that collection is from org - if Collection::find_by_uuid_and_org(coll_id, &org_id, &mut conn).await.is_none() { + if Collection::find_by_uuid_and_org(&col_id, &org_id, &mut conn).await.is_none() { err!("Collection not found in Organization") } // Delete all the user-collections - CollectionUser::delete_all_by_collection(coll_id, &mut conn).await?; + CollectionUser::delete_all_by_collection(&col_id, &mut conn).await?; // And then add all the received ones (except if the user has access_all) for d in data.iter() { @@ -735,7 +733,7 @@ async fn put_collection_users( continue; } - CollectionUser::save(&user.user_uuid, coll_id, d.read_only, d.hide_passwords, &mut conn).await?; + CollectionUser::save(&user.user_uuid, &col_id, d.read_only, d.hide_passwords, &mut conn).await?; } Ok(()) @@ -841,6 +839,14 @@ async fn post_org_keys( #[derive(Deserialize)] #[serde(rename_all = "camelCase")] struct CollectionData { + id: CollectionId, + read_only: bool, + hide_passwords: bool, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct MembershipData { id: MembershipId, read_only: bool, hide_passwords: bool, @@ -1615,14 +1621,14 @@ async fn post_org_import( // TODO: See if we can optimize the whole cipher adding/importing and prevent duplicate code and checks. Cipher::validate_cipher_data(&data.ciphers)?; - let existing_collections: HashSet> = - Collection::find_by_organization(&org_id, &mut conn).await.into_iter().map(|c| (Some(c.uuid))).collect(); - let mut collections: Vec = Vec::with_capacity(data.collections.len()); - for coll in data.collections { - let collection_uuid = if existing_collections.contains(&coll.id) { - coll.id.unwrap() + let existing_collections: HashSet> = + Collection::find_by_organization(&org_id, &mut conn).await.into_iter().map(|c| Some(c.uuid)).collect(); + let mut collections: Vec = Vec::with_capacity(data.collections.len()); + for col in data.collections { + let collection_uuid = if existing_collections.contains(&col.id) { + col.id.unwrap() } else { - let new_collection = Collection::new(org_id.clone(), coll.name, coll.external_id); + let new_collection = Collection::new(org_id.clone(), col.name, col.external_id); new_collection.save(&mut conn).await?; new_collection.uuid }; @@ -1649,10 +1655,10 @@ async fn post_org_import( } // Assign the collections - for (cipher_index, coll_index) in relations { + for (cipher_index, col_index) in relations { let cipher_id = &ciphers[cipher_index]; - let coll_id = &collections[coll_index]; - CollectionCipher::save(cipher_id, coll_id, &mut conn).await?; + let col_id = &collections[col_index]; + CollectionCipher::save(cipher_id, col_id, &mut conn).await?; } let mut user = headers.user; @@ -1665,7 +1671,7 @@ async fn post_org_import( struct BulkCollectionsData { organization_id: OrganizationId, cipher_ids: Vec, - collection_ids: HashSet, + collection_ids: HashSet, remove_collections: bool, } @@ -1683,7 +1689,7 @@ async fn post_bulk_collections(data: Json, headers: Headers // Get all the collection available to the user in one query // Also filter based upon the provided collections - let user_collections: HashMap = + let user_collections: HashMap = Collection::find_by_organization_and_user_uuid(&data.organization_id, &headers.user.uuid, &mut conn) .await .into_iter() @@ -2352,7 +2358,7 @@ struct GroupRequest { #[serde(default)] access_all: bool, external_id: Option, - collections: Vec, + collections: Vec, users: Vec, } @@ -2380,10 +2386,6 @@ struct SelectionReadOnly { } impl SelectionReadOnly { - pub fn to_collection_group(&self, groups_uuid: String) -> CollectionGroup { - CollectionGroup::new(self.id.clone(), groups_uuid, self.read_only, self.hide_passwords) - } - pub fn to_collection_group_details_read_only(collection_group: &CollectionGroup) -> Self { Self { id: collection_group.groups_uuid.clone(), @@ -2397,6 +2399,20 @@ impl SelectionReadOnly { } } +#[derive(Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +struct CollectionSelection { + id: CollectionId, + read_only: bool, + hide_passwords: bool, +} + +impl CollectionSelection { + pub fn to_collection_group(&self, groups_uuid: String) -> CollectionGroup { + CollectionGroup::new(self.id.clone(), groups_uuid, self.read_only, self.hide_passwords) + } +} + #[derive(Deserialize, Serialize)] #[serde(rename_all = "camelCase")] struct UserSelection { @@ -2496,7 +2512,7 @@ async fn put_group( async fn add_update_group( mut group: Group, - collections: Vec, + collections: Vec, members: Vec, org_id: OrganizationId, headers: &AdminHeaders, @@ -2504,8 +2520,8 @@ async fn add_update_group( ) -> JsonResult { group.save(conn).await?; - for selection_read_only_request in collections { - let mut collection_group = selection_read_only_request.to_collection_group(group.uuid.clone()); + for col_selection in collections { + let mut collection_group = col_selection.to_collection_group(group.uuid.clone()); collection_group.save(conn).await?; } diff --git a/src/api/notifications.rs b/src/api/notifications.rs index 16749521..832e5668 100644 --- a/src/api/notifications.rs +++ b/src/api/notifications.rs @@ -10,7 +10,7 @@ use rocket_ws::{Message, WebSocket}; use crate::{ auth::{ClientIp, WsAccessTokenHeader}, db::{ - models::{Cipher, Folder, Send as DbSend, User, UserId}, + models::{Cipher, CollectionId, Folder, Send as DbSend, User, UserId}, DbConn, }, Error, CONFIG, @@ -415,7 +415,7 @@ impl WebSocketUsers { cipher: &Cipher, user_uuids: &[UserId], acting_device_uuid: &String, - collection_uuids: Option>, + collection_uuids: Option>, conn: &mut DbConn, ) { // Skip any processing if both WebSockets and Push are not active @@ -428,7 +428,7 @@ impl WebSocketUsers { let (user_uuid, collection_uuids, revision_date) = if let Some(collection_uuids) = collection_uuids { ( Value::Nil, - Value::Array(collection_uuids.into_iter().map(|v| v.into()).collect::>()), + Value::Array(collection_uuids.into_iter().map(|v| v.to_string().into()).collect::>()), serialize_date(Utc::now().naive_utc()), ) } else { diff --git a/src/auth.rs b/src/auth.rs index 3f92b127..2a7c33d6 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -14,7 +14,7 @@ use std::{ net::IpAddr, }; -use crate::db::models::{MembershipId, OrganizationId, UserId}; +use crate::db::models::{CollectionId, MembershipId, OrganizationId, UserId}; use crate::{error::Error, CONFIG}; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; @@ -649,16 +649,16 @@ impl From for Headers { // col_id is usually the fourth path param ("/organizations//collections/"), // but there could be cases where it is a query value. // First check the path, if this is not a valid uuid, try the query values. -fn get_col_id(request: &Request<'_>) -> Option { +fn get_col_id(request: &Request<'_>) -> Option { if let Some(Ok(col_id)) = request.param::(3) { if uuid::Uuid::parse_str(&col_id).is_ok() { - return Some(col_id); + return Some(col_id.into()); } } if let Some(Ok(col_id)) = request.query_value::("collectionId") { if uuid::Uuid::parse_str(&col_id).is_ok() { - return Some(col_id); + return Some(col_id.into()); } } @@ -763,11 +763,11 @@ impl From for Headers { impl ManagerHeaders { pub async fn from_loose( h: ManagerHeadersLoose, - collections: &Vec, + collections: &Vec, conn: &mut DbConn, ) -> Result { for col_id in collections { - if uuid::Uuid::parse_str(col_id).is_err() { + if uuid::Uuid::parse_str(col_id.as_ref()).is_err() { err!("Collection Id is malformed!"); } if !Collection::can_access_collection(&h.membership, col_id, conn).await { diff --git a/src/db/models/cipher.rs b/src/db/models/cipher.rs index db75d2bf..56e272bf 100644 --- a/src/db/models/cipher.rs +++ b/src/db/models/cipher.rs @@ -4,8 +4,8 @@ use chrono::{NaiveDateTime, TimeDelta, Utc}; use serde_json::Value; use super::{ - Attachment, CollectionCipher, Favorite, FolderCipher, Group, Membership, MembershipStatus, MembershipType, - OrganizationId, User, UserId, + Attachment, CollectionCipher, CollectionId, Favorite, FolderCipher, Group, Membership, MembershipStatus, + MembershipType, OrganizationId, User, UserId, }; use crate::api::core::{CipherData, CipherSyncData, CipherSyncType}; @@ -862,7 +862,7 @@ impl Cipher { }} } - pub async fn get_collections(&self, user_id: String, conn: &mut DbConn) -> Vec { + pub async fn get_collections(&self, user_id: String, conn: &mut DbConn) -> Vec { if CONFIG.org_groups_enabled() { db_run! {conn: { ciphers_collections::table @@ -894,7 +894,7 @@ impl Cipher { .and(collections_groups::read_only.eq(false))) ) .select(ciphers_collections::collection_uuid) - .load::(conn).unwrap_or_default() + .load::(conn).unwrap_or_default() }} } else { db_run! {conn: { @@ -916,12 +916,12 @@ impl Cipher { .and(users_collections::read_only.eq(false))) ) .select(ciphers_collections::collection_uuid) - .load::(conn).unwrap_or_default() + .load::(conn).unwrap_or_default() }} } } - pub async fn get_admin_collections(&self, user_id: String, conn: &mut DbConn) -> Vec { + pub async fn get_admin_collections(&self, user_id: String, conn: &mut DbConn) -> Vec { if CONFIG.org_groups_enabled() { db_run! {conn: { ciphers_collections::table @@ -954,7 +954,7 @@ impl Cipher { .or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner ) .select(ciphers_collections::collection_uuid) - .load::(conn).unwrap_or_default() + .load::(conn).unwrap_or_default() }} } else { db_run! {conn: { @@ -977,14 +977,17 @@ impl Cipher { .or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner ) .select(ciphers_collections::collection_uuid) - .load::(conn).unwrap_or_default() + .load::(conn).unwrap_or_default() }} } } /// Return a Vec with (cipher_uuid, collection_uuid) /// This is used during a full sync so we only need one query for all collections accessible. - pub async fn get_collections_with_cipher_by_user(user_id: String, conn: &mut DbConn) -> Vec<(String, String)> { + pub async fn get_collections_with_cipher_by_user( + user_id: String, + conn: &mut DbConn, + ) -> Vec<(String, CollectionId)> { db_run! {conn: { ciphers_collections::table .inner_join(collections::table.on( @@ -1018,7 +1021,7 @@ impl Cipher { .or_filter(collections_groups::collections_uuid.is_not_null()) //Access via group .select(ciphers_collections::all_columns) .distinct() - .load::<(String, String)>(conn).unwrap_or_default() + .load::<(String, CollectionId)>(conn).unwrap_or_default() }} } } diff --git a/src/db/models/collection.rs b/src/db/models/collection.rs index 9313585d..ed3b762b 100644 --- a/src/db/models/collection.rs +++ b/src/db/models/collection.rs @@ -1,4 +1,10 @@ +use rocket::request::FromParam; use serde_json::Value; +use std::{ + borrow::Borrow, + fmt::{Display, Formatter}, + ops::Deref, +}; use super::{CollectionGroup, GroupUser, Membership, MembershipStatus, MembershipType, OrganizationId, User, UserId}; use crate::CONFIG; @@ -8,7 +14,7 @@ db_object! { #[diesel(table_name = collections)] #[diesel(primary_key(uuid))] pub struct Collection { - pub uuid: String, + pub uuid: CollectionId, pub org_uuid: OrganizationId, pub name: String, pub external_id: Option, @@ -19,7 +25,7 @@ db_object! { #[diesel(primary_key(user_uuid, collection_uuid))] pub struct CollectionUser { pub user_uuid: UserId, - pub collection_uuid: String, + pub collection_uuid: CollectionId, pub read_only: bool, pub hide_passwords: bool, } @@ -29,7 +35,7 @@ db_object! { #[diesel(primary_key(cipher_uuid, collection_uuid))] pub struct CollectionCipher { pub cipher_uuid: String, - pub collection_uuid: String, + pub collection_uuid: CollectionId, } } @@ -37,7 +43,7 @@ db_object! { impl Collection { pub fn new(org_uuid: OrganizationId, name: String, external_id: Option) -> Self { let mut new_model = Self { - uuid: crate::util::get_uuid(), + uuid: CollectionId(crate::util::get_uuid()), org_uuid, name, external_id: None, @@ -121,7 +127,7 @@ impl Collection { json_object } - pub async fn can_access_collection(member: &Membership, col_id: &str, conn: &mut DbConn) -> bool { + pub async fn can_access_collection(member: &Membership, col_id: &CollectionId, conn: &mut DbConn) -> bool { member.has_status(MembershipStatus::Confirmed) && (member.has_full_access() || CollectionUser::has_access_to_collection_by_user(col_id, &member.user_uuid, conn).await @@ -198,7 +204,7 @@ impl Collection { } } - pub async fn find_by_uuid(uuid: &str, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid(uuid: &CollectionId, conn: &mut DbConn) -> Option { db_run! { conn: { collections::table .filter(collections::uuid.eq(uuid)) @@ -312,7 +318,11 @@ impl Collection { }} } - pub async fn find_by_uuid_and_org(uuid: &str, org_uuid: &OrganizationId, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid_and_org( + uuid: &CollectionId, + org_uuid: &OrganizationId, + conn: &mut DbConn, + ) -> Option { db_run! { conn: { collections::table .filter(collections::uuid.eq(uuid)) @@ -324,7 +334,7 @@ impl Collection { }} } - pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: UserId, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &mut DbConn) -> Option { if CONFIG.org_groups_enabled() { db_run! { conn: { collections::table @@ -534,7 +544,7 @@ impl CollectionUser { pub async fn save( user_uuid: &UserId, - collection_uuid: &str, + collection_uuid: &CollectionId, read_only: bool, hide_passwords: bool, conn: &mut DbConn, @@ -604,7 +614,7 @@ impl CollectionUser { }} } - pub async fn find_by_collection(collection_uuid: &str, conn: &mut DbConn) -> Vec { + pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> Vec { db_run! { conn: { users_collections::table .filter(users_collections::collection_uuid.eq(collection_uuid)) @@ -616,7 +626,7 @@ impl CollectionUser { } pub async fn find_by_collection_swap_user_uuid_with_member_uuid( - collection_uuid: &str, + collection_uuid: &CollectionId, conn: &mut DbConn, ) -> Vec { db_run! { conn: { @@ -631,7 +641,7 @@ impl CollectionUser { } pub async fn find_by_collection_and_user( - collection_uuid: &str, + collection_uuid: &CollectionId, user_uuid: &UserId, conn: &mut DbConn, ) -> Option { @@ -657,7 +667,7 @@ impl CollectionUser { }} } - pub async fn delete_all_by_collection(collection_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() { User::update_uuid_revision(&collection.user_uuid, conn).await; } @@ -689,14 +699,18 @@ impl CollectionUser { }} } - pub async fn has_access_to_collection_by_user(col_id: &str, user_uuid: &UserId, conn: &mut DbConn) -> bool { + pub async fn has_access_to_collection_by_user( + col_id: &CollectionId, + user_uuid: &UserId, + conn: &mut DbConn, + ) -> bool { Self::find_by_collection_and_user(col_id, user_uuid, conn).await.is_some() } } /// Database methods impl CollectionCipher { - pub async fn save(cipher_uuid: &str, collection_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn save(cipher_uuid: &str, collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { Self::update_users_revision(collection_uuid, conn).await; db_run! { conn: @@ -726,7 +740,7 @@ impl CollectionCipher { } } - pub async fn delete(cipher_uuid: &str, collection_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn delete(cipher_uuid: &str, collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { Self::update_users_revision(collection_uuid, conn).await; db_run! { conn: { @@ -748,7 +762,7 @@ impl CollectionCipher { }} } - pub async fn delete_all_by_collection(collection_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { db_run! { conn: { diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid))) .execute(conn) @@ -756,9 +770,57 @@ impl CollectionCipher { }} } - pub async fn update_users_revision(collection_uuid: &str, conn: &mut DbConn) { + pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &mut DbConn) { if let Some(collection) = Collection::find_by_uuid(collection_uuid, conn).await { collection.update_users_revision(conn).await; } } } + +#[derive(DieselNewType, FromForm, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct CollectionId(String); + +impl AsRef for CollectionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Deref for CollectionId { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Borrow for CollectionId { + fn borrow(&self) -> &str { + &self.0 + } +} + +impl Display for CollectionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for CollectionId { + fn from(raw: String) -> Self { + Self(raw) + } +} + +impl<'r> FromParam<'r> for CollectionId { + type Error = (); + + #[inline(always)] + fn from_param(param: &'r str) -> Result { + if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { + Ok(Self(param.to_string())) + } else { + Err(()) + } + } +} diff --git a/src/db/models/group.rs b/src/db/models/group.rs index 0a6cca10..f742b64b 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -1,4 +1,4 @@ -use super::{Membership, MembershipId, OrganizationId, User, UserId}; +use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId}; use crate::api::EmptyResult; use crate::db::DbConn; use crate::error::MapResult; @@ -23,7 +23,7 @@ db_object! { #[diesel(table_name = collections_groups)] #[diesel(primary_key(collections_uuid, groups_uuid))] pub struct CollectionGroup { - pub collections_uuid: String, + pub collections_uuid: CollectionId, pub groups_uuid: String, pub read_only: bool, pub hide_passwords: bool, @@ -113,7 +113,7 @@ impl Group { } impl CollectionGroup { - pub fn new(collections_uuid: String, groups_uuid: String, read_only: bool, hide_passwords: bool) -> Self { + pub fn new(collections_uuid: CollectionId, groups_uuid: String, read_only: bool, hide_passwords: bool) -> Self { Self { collections_uuid, groups_uuid, @@ -370,7 +370,7 @@ impl CollectionGroup { }} } - pub async fn find_by_collection(collection_uuid: &str, conn: &mut DbConn) -> Vec { + pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> Vec { db_run! { conn: { collections_groups::table .filter(collections_groups::collections_uuid.eq(collection_uuid)) @@ -410,7 +410,7 @@ impl CollectionGroup { }} } - pub async fn delete_all_by_collection(collection_uuid: &str, conn: &mut DbConn) -> EmptyResult { + pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { let collection_assigned_to_groups = CollectionGroup::find_by_collection(collection_uuid, conn).await; for collection_assigned_to_group in collection_assigned_to_groups { let group_users = GroupUser::find_by_group(&collection_assigned_to_group.groups_uuid, conn).await; @@ -496,7 +496,7 @@ impl GroupUser { } pub async fn has_access_to_collection_by_member( - collection_uuid: &str, + collection_uuid: &CollectionId, member_uuid: &MembershipId, conn: &mut DbConn, ) -> bool { diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index 7042b294..a96c5bb9 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -19,7 +19,7 @@ mod user; pub use self::attachment::Attachment; pub use self::auth_request::AuthRequest; pub use self::cipher::{Cipher, RepromptType}; -pub use self::collection::{Collection, CollectionCipher, CollectionUser}; +pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser}; pub use self::device::{Device, DeviceType}; pub use self::emergency_access::{EmergencyAccess, EmergencyAccessStatus, EmergencyAccessType}; pub use self::event::{Event, EventType}; diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs index db991fe1..9e39161a 100644 --- a/src/db/models/organization.rs +++ b/src/db/models/organization.rs @@ -10,8 +10,10 @@ use std::{ ops::Deref, }; -use super::{CollectionUser, Group, GroupUser, OrgPolicy, OrgPolicyType, TwoFactor, User, UserId}; -use crate::db::models::{Collection, CollectionGroup}; +use super::{ + Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupUser, OrgPolicy, OrgPolicyType, TwoFactor, + User, UserId, +}; use crate::CONFIG; db_object! { @@ -474,7 +476,7 @@ impl Membership { // If collections are to be included, only include them if the user does not have full access via a group or defined to the user it self let collections: Vec = if include_collections && !(full_access_group || self.has_full_access()) { // Get all collections for the user here already to prevent more queries - let cu: HashMap = + let cu: HashMap = CollectionUser::find_by_organization_and_user_uuid(&self.org_uuid, &self.user_uuid, conn) .await .into_iter() @@ -482,7 +484,7 @@ impl Membership { .collect(); // Get all collection groups for this user to prevent there inclusion - let cg: HashSet = CollectionGroup::find_by_user(&self.user_uuid, conn) + let cg: HashSet = CollectionGroup::find_by_user(&self.user_uuid, conn) .await .into_iter() .map(|cg| cg.collections_uuid) @@ -961,7 +963,7 @@ impl Membership { } pub async fn find_by_collection_and_org( - collection_uuid: &str, + collection_uuid: &CollectionId, org_uuid: &OrganizationId, conn: &mut DbConn, ) -> Vec {