diff --git a/Cargo.lock b/Cargo.lock index b008c5d3..cde3ac90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,7 @@ dependencies = [ "oath", "once_cell", "openssl", + "paste", "percent-encoding 2.1.0", "rand 0.7.3", "regex", @@ -274,9 +275,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.13" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c74d84029116787153e02106bf53e66828452a4b325cc8652b788b5967c0a0b6" +checksum = "942f72db697d8767c22d46a598e01f2d3b475501ea43d0db4f16d90259182d0b" dependencies = [ "num-integer", "num-traits", @@ -295,9 +296,9 @@ dependencies = [ [[package]] name = "clap" -version = "2.33.2" +version = "2.33.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10040cdf04294b565d9e0319955430099ec3813a64c952b86a41200ad714ae48" +checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002" dependencies = [ "ansi_term", "atty", @@ -781,14 +782,14 @@ dependencies = [ [[package]] name = "handlebars" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86dbc8a0746b08f363d2e00da48e6c9ceb75c198ac692d2715fcbb5bee74c87d" +checksum = "5deefd4816fb852b1ff3cb48f6c41da67be2d0e1d20b26a7a3b076da11f064b1" dependencies = [ "log 0.4.11", "pest", "pest_derive", - "quick-error", + "quick-error 2.0.0", "serde", "serde_json", "walkdir", @@ -1360,7 +1361,7 @@ dependencies = [ "log 0.4.11", "mime 0.3.16", "mime_guess", - "quick-error", + "quick-error 1.2.3", "rand 0.6.5", "safemem", "tempfile", @@ -1519,9 +1520,9 @@ checksum = "1ab52be62400ca80aa00285d25253d7f7c437b7375c4de678f5405d3afe82ca5" [[package]] name = "once_cell" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b631f7e854af39a1739f401cf34a8a013dfe09eac4fa4dba91e9768bd28168d" +checksum = "260e51e7efe62b592207e9e13a68e43692a7a279171d6ba57abd208bf23645ad" [[package]] name = "opaque-debug" @@ -1628,6 +1629,12 @@ dependencies = [ "regex", ] +[[package]] +name = "paste" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ddc8e145de01d9180ac7b78b9676f95a9c2447f6a88b2c2a04702211bc5d71" + [[package]] name = "pear" version = "0.1.4" @@ -1873,6 +1880,12 @@ version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +[[package]] +name = "quick-error" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ac73b1112776fc109b2e61909bc46c7e1bf0d7f690ffb1676553acce16d5cda" + [[package]] name = "quote" version = "0.6.13" diff --git a/Cargo.toml b/Cargo.toml index 10876570..51d969e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -123,6 +123,9 @@ structopt = "0.3.16" # Logging panics to logfile instead stderr only backtrace = "0.3.50" +# Macro ident concatenation +paste = "1.0" + [patch.crates-io] # Use newest ring rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '1010f6a2a88fac899dec0cd2f642156908038a53' } diff --git a/build.rs b/build.rs index 0eeb4767..0277d21e 100644 --- a/build.rs +++ b/build.rs @@ -1,13 +1,14 @@ use std::process::Command; use std::env; -fn main() { - #[cfg(all(feature = "sqlite", feature = "mysql"))] - compile_error!("Can't enable both sqlite and mysql at the same time"); - #[cfg(all(feature = "sqlite", feature = "postgresql"))] - compile_error!("Can't enable both sqlite and postgresql at the same time"); - #[cfg(all(feature = "mysql", feature = "postgresql"))] - compile_error!("Can't enable both mysql and postgresql at the same time"); +fn main() { + // This allow using #[cfg(sqlite)] instead of #[cfg(feature = "sqlite")], which helps when trying to add them through macros + #[cfg(feature = "sqlite")] + println!("cargo:rustc-cfg=sqlite"); + #[cfg(feature = "mysql")] + println!("cargo:rustc-cfg=mysql"); + #[cfg(feature = "postgresql")] + println!("cargo:rustc-cfg=postgresql"); #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgresql")))] compile_error!("You need to enable one DB backend. To build with previous defaults do: cargo build --features sqlite"); diff --git a/src/api/admin.rs b/src/api/admin.rs index 925eb81d..415311e7 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -15,7 +15,7 @@ use crate::{ api::{ApiResult, EmptyResult, JsonResult}, auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp}, config::ConfigBuilder, - db::{backup_database, models::*, DbConn}, + db::{backup_database, models::*, DbConn, DbConnType}, error::{Error, MapResult}, mail, util::get_display_size, @@ -48,8 +48,12 @@ pub fn routes() -> Vec { ] } -static CAN_BACKUP: Lazy = - Lazy::new(|| cfg!(feature = "sqlite") && Command::new("sqlite3").arg("-version").status().is_ok()); +static CAN_BACKUP: Lazy = Lazy::new(|| { + DbConnType::from_url(&CONFIG.database_url()) + .map(|t| t == DbConnType::sqlite) + .unwrap_or(false) + && Command::new("sqlite3").arg("-version").status().is_ok() +}); #[get("/")] fn admin_disabled() -> &'static str { diff --git a/src/config.rs b/src/config.rs index 9f2ed8f5..9491edf2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,6 +5,7 @@ use once_cell::sync::Lazy; use reqwest::Url; use crate::{ + db::DbConnType, error::Error, util::{get_env, get_env_bool}, }; @@ -421,20 +422,9 @@ make_config! { } fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { - let db_url = cfg.database_url.to_lowercase(); - if cfg!(feature = "sqlite") - && (db_url.starts_with("mysql:") || db_url.starts_with("postgresql:") || db_url.starts_with("postgres:")) - { - err!("`DATABASE_URL` is meant for MySQL or Postgres, while this server is meant for SQLite") - } - if cfg!(feature = "mysql") && !db_url.starts_with("mysql:") { - err!("`DATABASE_URL` should start with mysql: when using the MySQL server") - } - - if cfg!(feature = "postgresql") && !(db_url.starts_with("postgresql:") || db_url.starts_with("postgres:")) { - err!("`DATABASE_URL` should start with postgresql: when using the PostgreSQL server") - } + // Validate connection URL is valid and DB feature is enabled + DbConnType::from_url(&cfg.database_url)?; let dom = cfg.domain.to_lowercase(); if !dom.starts_with("http://") && !dom.starts_with("https://") { diff --git a/src/db/mod.rs b/src/db/mod.rs index 4d3ba235..ec86efaf 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,51 +1,203 @@ use std::process::Command; use chrono::prelude::*; -use diesel::{r2d2, r2d2::ConnectionManager, Connection as DieselConnection, ConnectionError}; +use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; use rocket::{ http::Status, request::{FromRequest, Outcome}, Request, State, }; -use crate::{error::Error, CONFIG}; - -/// An alias to the database connection used -#[cfg(feature = "sqlite")] -type Connection = diesel::sqlite::SqliteConnection; -#[cfg(feature = "mysql")] -type Connection = diesel::mysql::MysqlConnection; -#[cfg(feature = "postgresql")] -type Connection = diesel::pg::PgConnection; - -/// An alias to the type for a pool of Diesel connections. -type Pool = r2d2::Pool>; - -/// Connection request guard type: a wrapper around an r2d2 pooled connection. -pub struct DbConn(pub r2d2::PooledConnection>); +use crate::{ + error::{Error, MapResult}, + CONFIG, +}; -pub mod models; -#[cfg(feature = "sqlite")] +#[cfg(sqlite)] #[path = "schemas/sqlite/schema.rs"] -pub mod schema; -#[cfg(feature = "mysql")] +pub mod __sqlite_schema; + +#[cfg(mysql)] #[path = "schemas/mysql/schema.rs"] -pub mod schema; -#[cfg(feature = "postgresql")] +pub mod __mysql_schema; + +#[cfg(postgresql)] #[path = "schemas/postgresql/schema.rs"] -pub mod schema; +pub mod __postgresql_schema; + + +// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported +macro_rules! generate_connections { + ( $( $name:ident: $ty:ty ),+ ) => { + #[allow(non_camel_case_types, dead_code)] + #[derive(Eq, PartialEq)] + pub enum DbConnType { $( $name, )+ } + + #[allow(non_camel_case_types)] + pub enum DbConn { $( #[cfg($name)] $name(PooledConnection>), )+ } + + #[allow(non_camel_case_types)] + pub enum DbPool { $( #[cfg($name)] $name(Pool>), )+ } + + impl DbPool { + // For the given database URL, guess it's type, run migrations create pool and return it + pub fn from_config() -> Result { + let url = CONFIG.database_url(); + let conn_type = DbConnType::from_url(&url)?; + + match conn_type { $( + DbConnType::$name => { + #[cfg($name)] + { + paste::paste!{ [< $name _migrations >]::run_migrations(); } + let manager = ConnectionManager::new(&url); + let pool = Pool::builder().build(manager).map_res("Failed to create pool")?; + return Ok(Self::$name(pool)); + } + #[cfg(not($name))] + #[allow(unreachable_code)] + return unreachable!("Trying to use a DB backend when it's feature is disabled"); + }, + )+ } + } + // Get a connection from the pool + pub fn get(&self) -> Result { + match self { $( + #[cfg($name)] + Self::$name(p) => Ok(DbConn::$name(p.get().map_res("Error retrieving connection from pool")?)), + )+ } + } + } + }; +} -/// Initializes a database pool. -pub fn init_pool() -> Pool { - let manager = ConnectionManager::new(CONFIG.database_url()); +generate_connections! { + sqlite: diesel::sqlite::SqliteConnection, + mysql: diesel::mysql::MysqlConnection, + postgresql: diesel::pg::PgConnection +} + +impl DbConnType { + pub fn from_url(url: &str) -> Result { + // Mysql + if url.starts_with("mysql:") { + #[cfg(mysql)] + return Ok(DbConnType::mysql); + + #[cfg(not(mysql))] + err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled") + + // Postgres + } else if url.starts_with("postgresql:") || url.starts_with("postgres:") { + #[cfg(postgresql)] + return Ok(DbConnType::postgresql); - r2d2::Pool::builder().build(manager).expect("Failed to create pool") + #[cfg(not(postgresql))] + err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled") + + //Sqlite + } else { + #[cfg(sqlite)] + return Ok(DbConnType::sqlite); + + #[cfg(not(sqlite))] + err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled") + } + } } -pub fn get_connection() -> Result { - Connection::establish(&CONFIG.database_url()) + +#[macro_export] +macro_rules! db_run { + // Same for all dbs + ( $conn:ident: $body:block ) => { + db_run! { $conn: sqlite, mysql, postgresql $body } + }; + + // Different code for each db + ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => { + #[allow(unused)] use diesel::prelude::*; + match $conn { + $($( + #[cfg($db)] + crate::db::DbConn::$db(ref $conn) => { + paste::paste! { + #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *}; + #[allow(unused)] use [<__ $db _model>]::*; + #[allow(unused)] use crate::db::FromDb; + } + $body + }, + )+)+ + } + }; } + +pub trait FromDb { + type Output; + fn from_db(self) -> Self::Output; +} + +// For each struct eg. Cipher, we create a CipherDb inside a module named __$db_model (where $db is sqlite, mysql or postgresql), +// to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run! +#[macro_export] +macro_rules! db_object { + ( $( + $( #[$attr:meta] )* + pub struct $name:ident { + $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+ + $(,)? + } + )+ ) => { + // Create the normal struct, without attributes + $( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+ + + #[cfg(sqlite)] + pub mod __sqlite_model { $( db_object! { @db sqlite | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ } + #[cfg(mysql)] + pub mod __mysql_model { $( db_object! { @db mysql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ } + #[cfg(postgresql)] + pub mod __postgresql_model { $( db_object! { @db postgresql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ } + }; + + ( @db $db:ident | $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => { + paste::paste! { + #[allow(unused)] use super::*; + #[allow(unused)] use diesel::prelude::*; + #[allow(unused)] use crate::db::[<__ $db _schema>]::*; + + $( #[$attr] )* + pub struct [<$name Db>] { $( + $( #[$field_attr] )* $vis $field : $typ, + )+ } + + impl [<$name Db>] { + #[inline(always)] pub fn from_db(self) -> super::$name { super::$name { $( $field: self.$field, )+ } } + #[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } } + } + + impl crate::db::FromDb for [<$name Db>] { + type Output = super::$name; + #[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } } + } + + impl crate::db::FromDb for Vec<[<$name Db>]> { + type Output = Vec; + #[inline(always)] fn from_db(self) -> Self::Output { self.into_iter().map(crate::db::FromDb::from_db).collect() } + } + + impl crate::db::FromDb for Option<[<$name Db>]> { + type Output = Option; + #[inline(always)] fn from_db(self) -> Self::Output { self.map(crate::db::FromDb::from_db) } + } + } + }; +} + +// Reexport the models, needs to be after the macros are defined so it can access them +pub mod models; + /// Creates a back-up of the database using sqlite3 pub fn backup_database() -> Result<(), Error> { use std::path::Path; @@ -73,18 +225,99 @@ impl<'a, 'r> FromRequest<'a, 'r> for DbConn { fn from_request(request: &'a Request<'r>) -> Outcome { // https://github.com/SergioBenitez/Rocket/commit/e3c1a4ad3ab9b840482ec6de4200d30df43e357c - let pool = try_outcome!(request.guard::>()); + let pool = try_outcome!(request.guard::>()); match pool.get() { - Ok(conn) => Outcome::Success(DbConn(conn)), + Ok(conn) => Outcome::Success(conn), Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())), } } } -// For the convenience of using an &DbConn as a &Database. -impl std::ops::Deref for DbConn { - type Target = Connection; - fn deref(&self) -> &Self::Target { - &self.0 +// Embed the migrations from the migrations folder into the application +// This way, the program automatically migrates the database to the latest version +// https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html +#[cfg(sqlite)] +mod sqlite_migrations { + #[allow(unused_imports)] + embed_migrations!("migrations/sqlite"); + + pub fn run_migrations() { + // Make sure the directory exists + let url = crate::CONFIG.database_url(); + let path = std::path::Path::new(&url); + + if let Some(parent) = path.parent() { + if std::fs::create_dir_all(parent).is_err() { + error!("Error creating database directory"); + std::process::exit(1); + } + } + + use diesel::{Connection, RunQueryDsl}; + // Make sure the database is up to date (create if it doesn't exist, or run the migrations) + let connection = + diesel::sqlite::SqliteConnection::establish(&crate::CONFIG.database_url()).expect("Can't connect to DB"); + // Disable Foreign Key Checks during migration + + // Scoped to a connection. + diesel::sql_query("PRAGMA foreign_keys = OFF") + .execute(&connection) + .expect("Failed to disable Foreign Key Checks during migrations"); + + // Turn on WAL in SQLite + if crate::CONFIG.enable_db_wal() { + diesel::sql_query("PRAGMA journal_mode=wal") + .execute(&connection) + .expect("Failed to turn on WAL"); + } + + embedded_migrations::run_with_output(&connection, &mut std::io::stdout()).expect("Can't run migrations"); + } +} + +#[cfg(mysql)] +mod mysql_migrations { + #[allow(unused_imports)] + embed_migrations!("migrations/mysql"); + + pub fn run_migrations() { + use diesel::{Connection, RunQueryDsl}; + // Make sure the database is up to date (create if it doesn't exist, or run the migrations) + let connection = + diesel::mysql::MysqlConnection::establish(&crate::CONFIG.database_url()).expect("Can't connect to DB"); + // Disable Foreign Key Checks during migration + + // Scoped to a connection/session. + diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0") + .execute(&connection) + .expect("Failed to disable Foreign Key Checks during migrations"); + + embedded_migrations::run_with_output(&connection, &mut std::io::stdout()).expect("Can't run migrations"); + } +} + +#[cfg(postgresql)] +mod postgresql_migrations { + #[allow(unused_imports)] + embed_migrations!("migrations/postgresql"); + + pub fn run_migrations() { + use diesel::{Connection, RunQueryDsl}; + // Make sure the database is up to date (create if it doesn't exist, or run the migrations) + let connection = + diesel::pg::PgConnection::establish(&crate::CONFIG.database_url()).expect("Can't connect to DB"); + // Disable Foreign Key Checks during migration + + // FIXME: Per https://www.postgresql.org/docs/12/sql-set-constraints.html, + // "SET CONSTRAINTS sets the behavior of constraint checking within the + // current transaction", so this setting probably won't take effect for + // any of the migrations since it's being run outside of a transaction. + // Migrations that need to disable foreign key checks should run this + // from within the migration script itself. + diesel::sql_query("SET CONSTRAINTS ALL DEFERRED") + .execute(&connection) + .expect("Failed to disable Foreign Key Checks during migrations"); + + embedded_migrations::run_with_output(&connection, &mut std::io::stdout()).expect("Can't run migrations"); } } diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs index b0e5003e..5cff557b 100644 --- a/src/db/models/attachment.rs +++ b/src/db/models/attachment.rs @@ -3,17 +3,19 @@ use serde_json::Value; use super::Cipher; use crate::CONFIG; -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "attachments"] -#[changeset_options(treat_none_as_null="true")] -#[belongs_to(Cipher, foreign_key = "cipher_uuid")] -#[primary_key(id)] -pub struct Attachment { - pub id: String, - pub cipher_uuid: String, - pub file_name: String, - pub file_size: i32, - pub akey: Option, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "attachments"] + #[changeset_options(treat_none_as_null="true")] + #[belongs_to(super::Cipher, foreign_key = "cipher_uuid")] + #[primary_key(id)] + pub struct Attachment { + pub id: String, + pub cipher_uuid: String, + pub file_name: String, + pub file_size: i32, + pub akey: Option, + } } /// Local methods @@ -50,43 +52,46 @@ impl Attachment { } } -use crate::db::schema::{attachments, ciphers}; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; /// Database methods impl Attachment { - #[cfg(feature = "postgresql")] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - diesel::insert_into(attachments::table) - .values(self) - .on_conflict(attachments::id) - .do_update() - .set(self) - .execute(&**conn) - .map_res("Error saving attachment") - } - #[cfg(not(feature = "postgresql"))] pub fn save(&self, conn: &DbConn) -> EmptyResult { - diesel::replace_into(attachments::table) - .values(self) - .execute(&**conn) - .map_res("Error saving attachment") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(attachments::table) + .values(AttachmentDb::to_db(self)) + .execute(conn) + .map_res("Error saving attachment") + } + postgresql { + let value = AttachmentDb::to_db(self); + diesel::insert_into(attachments::table) + .values(&value) + .on_conflict(attachments::id) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving attachment") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { - crate::util::retry( - || diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(&**conn), - 10, - ) - .map_res("Error deleting attachment")?; - - crate::util::delete_file(&self.get_file_path())?; - Ok(()) + db_run! { conn: { + crate::util::retry( + || diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn), + 10, + ) + .map_res("Error deleting attachment")?; + + crate::util::delete_file(&self.get_file_path())?; + Ok(()) + }} } pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -97,67 +102,78 @@ impl Attachment { } pub fn find_by_id(id: &str, conn: &DbConn) -> Option { - let id = id.to_lowercase(); - - attachments::table - .filter(attachments::id.eq(id)) - .first::(&**conn) - .ok() + db_run! { conn: { + attachments::table + .filter(attachments::id.eq(id.to_lowercase())) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_cipher(cipher_uuid: &str, conn: &DbConn) -> Vec { - attachments::table - .filter(attachments::cipher_uuid.eq(cipher_uuid)) - .load::(&**conn) - .expect("Error loading attachments") + db_run! { conn: { + attachments::table + .filter(attachments::cipher_uuid.eq(cipher_uuid)) + .load::(conn) + .expect("Error loading attachments") + .from_db() + }} } pub fn find_by_ciphers(cipher_uuids: Vec, conn: &DbConn) -> Vec { - attachments::table - .filter(attachments::cipher_uuid.eq_any(cipher_uuids)) - .load::(&**conn) - .expect("Error loading attachments") + db_run! { conn: { + attachments::table + .filter(attachments::cipher_uuid.eq_any(cipher_uuids)) + .load::(conn) + .expect("Error loading attachments") + .from_db() + }} } pub fn size_by_user(user_uuid: &str, conn: &DbConn) -> i64 { - let result: Option = attachments::table - .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) - .filter(ciphers::user_uuid.eq(user_uuid)) - .select(diesel::dsl::sum(attachments::file_size)) - .first(&**conn) - .expect("Error loading user attachment total size"); - - result.unwrap_or(0) + db_run! { conn: { + let result: Option = attachments::table + .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) + .filter(ciphers::user_uuid.eq(user_uuid)) + .select(diesel::dsl::sum(attachments::file_size)) + .first(conn) + .expect("Error loading user attachment total size"); + result.unwrap_or(0) + }} } pub fn count_by_user(user_uuid: &str, conn: &DbConn) -> i64 { - attachments::table - .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) - .filter(ciphers::user_uuid.eq(user_uuid)) - .count() - .first::(&**conn) - .ok() - .unwrap_or(0) + db_run! { conn: { + attachments::table + .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) + .filter(ciphers::user_uuid.eq(user_uuid)) + .count() + .first(conn) + .unwrap_or(0) + }} } pub fn size_by_org(org_uuid: &str, conn: &DbConn) -> i64 { - let result: Option = attachments::table - .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) - .filter(ciphers::organization_uuid.eq(org_uuid)) - .select(diesel::dsl::sum(attachments::file_size)) - .first(&**conn) - .expect("Error loading user attachment total size"); - - result.unwrap_or(0) + db_run! { conn: { + let result: Option = attachments::table + .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) + .filter(ciphers::organization_uuid.eq(org_uuid)) + .select(diesel::dsl::sum(attachments::file_size)) + .first(conn) + .expect("Error loading user attachment total size"); + result.unwrap_or(0) + }} } pub fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 { - attachments::table - .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) - .filter(ciphers::organization_uuid.eq(org_uuid)) - .count() - .first(&**conn) - .ok() - .unwrap_or(0) + db_run! { conn: { + attachments::table + .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) + .filter(ciphers::organization_uuid.eq(org_uuid)) + .count() + .first(conn) + .unwrap_or(0) + }} } } diff --git a/src/db/models/cipher.rs b/src/db/models/cipher.rs index 5328d9d6..4e223a6b 100644 --- a/src/db/models/cipher.rs +++ b/src/db/models/cipher.rs @@ -5,35 +5,37 @@ use super::{ Attachment, CollectionCipher, FolderCipher, Organization, User, UserOrgStatus, UserOrgType, UserOrganization, }; -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "ciphers"] -#[changeset_options(treat_none_as_null="true")] -#[belongs_to(User, foreign_key = "user_uuid")] -#[belongs_to(Organization, foreign_key = "organization_uuid")] -#[primary_key(uuid)] -pub struct Cipher { - pub uuid: String, - pub created_at: NaiveDateTime, - pub updated_at: NaiveDateTime, - - pub user_uuid: Option, - pub organization_uuid: Option, - - /* - Login = 1, - SecureNote = 2, - Card = 3, - Identity = 4 - */ - pub atype: i32, - pub name: String, - pub notes: Option, - pub fields: Option, - - pub data: String, - - pub password_history: Option, - pub deleted_at: Option, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "ciphers"] + #[changeset_options(treat_none_as_null="true")] + #[belongs_to(User, foreign_key = "user_uuid")] + #[belongs_to(Organization, foreign_key = "organization_uuid")] + #[primary_key(uuid)] + pub struct Cipher { + pub uuid: String, + pub created_at: NaiveDateTime, + pub updated_at: NaiveDateTime, + + pub user_uuid: Option, + pub organization_uuid: Option, + + /* + Login = 1, + SecureNote = 2, + Card = 3, + Identity = 4 + */ + pub atype: i32, + pub name: String, + pub notes: Option, + pub fields: Option, + + pub data: String, + + pub password_history: Option, + pub deleted_at: Option, + } } /// Local methods @@ -62,9 +64,7 @@ impl Cipher { } } -use crate::db::schema::*; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; @@ -81,7 +81,7 @@ impl Cipher { let password_history_json = self.password_history.as_ref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or(Value::Null); let (read_only, hide_passwords) = - match self.get_access_restrictions(&user_uuid, &conn) { + match self.get_access_restrictions(&user_uuid, conn) { Some((ro, hp)) => (ro, hp), None => { error!("Cipher ownership assertion failure"); @@ -125,14 +125,14 @@ impl Cipher { "Type": self.atype, "RevisionDate": format_date(&self.updated_at), "DeletedDate": self.deleted_at.map_or(Value::Null, |d| Value::String(format_date(&d))), - "FolderId": self.get_folder_uuid(&user_uuid, &conn), - "Favorite": self.is_favorite(&user_uuid, &conn), + "FolderId": self.get_folder_uuid(&user_uuid, conn), + "Favorite": self.is_favorite(&user_uuid, conn), "OrganizationId": self.organization_uuid, "Attachments": attachments_json, "OrganizationUseTotp": true, // This field is specific to the cipherDetails type. - "CollectionIds": self.get_collections(user_uuid, &conn), + "CollectionIds": self.get_collections(user_uuid, conn), "Name": self.name, "Notes": self.notes, @@ -183,41 +183,42 @@ impl Cipher { user_uuids } - #[cfg(feature = "postgresql")] - pub fn save(&mut self, conn: &DbConn) -> EmptyResult { - self.update_users_revision(conn); - self.updated_at = Utc::now().naive_utc(); - - diesel::insert_into(ciphers::table) - .values(&*self) - .on_conflict(ciphers::uuid) - .do_update() - .set(&*self) - .execute(&**conn) - .map_res("Error saving cipher") - } - - #[cfg(not(feature = "postgresql"))] pub fn save(&mut self, conn: &DbConn) -> EmptyResult { self.update_users_revision(conn); self.updated_at = Utc::now().naive_utc(); - - diesel::replace_into(ciphers::table) - .values(&*self) - .execute(&**conn) - .map_res("Error saving cipher") + + db_run! { conn: + sqlite, mysql { + diesel::replace_into(ciphers::table) + .values(CipherDb::to_db(self)) + .execute(conn) + .map_res("Error saving cipher") + } + postgresql { + let value = CipherDb::to_db(self); + diesel::insert_into(ciphers::table) + .values(&value) + .on_conflict(ciphers::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving cipher") + } + } } pub fn delete(&self, conn: &DbConn) -> EmptyResult { self.update_users_revision(conn); - FolderCipher::delete_all_by_cipher(&self.uuid, &conn)?; - CollectionCipher::delete_all_by_cipher(&self.uuid, &conn)?; - Attachment::delete_all_by_cipher(&self.uuid, &conn)?; + FolderCipher::delete_all_by_cipher(&self.uuid, conn)?; + CollectionCipher::delete_all_by_cipher(&self.uuid, conn)?; + Attachment::delete_all_by_cipher(&self.uuid, conn)?; - diesel::delete(ciphers::table.filter(ciphers::uuid.eq(&self.uuid))) - .execute(&**conn) - .map_res("Error deleting cipher") + db_run! { conn: { + diesel::delete(ciphers::table.filter(ciphers::uuid.eq(&self.uuid))) + .execute(conn) + .map_res("Error deleting cipher") + }} } pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -235,28 +236,28 @@ impl Cipher { } pub fn move_to_folder(&self, folder_uuid: Option, user_uuid: &str, conn: &DbConn) -> EmptyResult { - User::update_uuid_revision(user_uuid, &conn); + User::update_uuid_revision(user_uuid, conn); - match (self.get_folder_uuid(&user_uuid, &conn), folder_uuid) { + match (self.get_folder_uuid(&user_uuid, conn), folder_uuid) { // No changes (None, None) => Ok(()), (Some(ref old), Some(ref new)) if old == new => Ok(()), // Add to folder - (None, Some(new)) => FolderCipher::new(&new, &self.uuid).save(&conn), + (None, Some(new)) => FolderCipher::new(&new, &self.uuid).save(conn), // Remove from folder - (Some(old), None) => match FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, &conn) { - Some(old) => old.delete(&conn), + (Some(old), None) => match FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, conn) { + Some(old) => old.delete(conn), None => err!("Couldn't move from previous folder"), }, // Move to another folder (Some(old), Some(new)) => { - if let Some(old) = FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, &conn) { - old.delete(&conn)?; + if let Some(old) = FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, conn) { + old.delete(conn)?; } - FolderCipher::new(&new, &self.uuid).save(&conn) + FolderCipher::new(&new, &self.uuid).save(conn) } } } @@ -269,7 +270,7 @@ impl Cipher { /// Returns whether this cipher is owned by an org in which the user has full access. pub fn is_in_full_access_org(&self, user_uuid: &str, conn: &DbConn) -> bool { if let Some(ref org_uuid) = self.organization_uuid { - if let Some(user_org) = UserOrganization::find_by_user_and_org(&user_uuid, &org_uuid, &conn) { + if let Some(user_org) = UserOrganization::find_by_user_and_org(&user_uuid, &org_uuid, conn) { return user_org.has_full_access(); } } @@ -290,38 +291,40 @@ impl Cipher { return Some((false, false)); } - // Check whether this cipher is in any collections accessible to the - // user. If so, retrieve the access flags for each collection. - let query = ciphers::table - .filter(ciphers::uuid.eq(&self.uuid)) - .inner_join(ciphers_collections::table.on( - ciphers::uuid.eq(ciphers_collections::cipher_uuid))) - .inner_join(users_collections::table.on( - ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) - .and(users_collections::user_uuid.eq(user_uuid)))) - .select((users_collections::read_only, users_collections::hide_passwords)); - - // There's an edge case where a cipher can be in multiple collections - // with inconsistent access flags. For example, a cipher could be in - // one collection where the user has read-only access, but also in - // another collection where the user has read/write access. To handle - // this, we do a boolean OR of all values in each of the `read_only` - // and `hide_passwords` columns. This could ideally be done as part - // of the query, but Diesel doesn't support a max() or bool_or() - // function on booleans and this behavior isn't portable anyway. - if let Some(vec) = query.load::<(bool, bool)>(&**conn).ok() { - let mut read_only = false; - let mut hide_passwords = false; - for (ro, hp) in vec.iter() { - read_only |= ro; - hide_passwords |= hp; - } + db_run! {conn: { + // Check whether this cipher is in any collections accessible to the + // user. If so, retrieve the access flags for each collection. + let query = ciphers::table + .filter(ciphers::uuid.eq(&self.uuid)) + .inner_join(ciphers_collections::table.on( + ciphers::uuid.eq(ciphers_collections::cipher_uuid))) + .inner_join(users_collections::table.on( + ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) + .and(users_collections::user_uuid.eq(user_uuid)))) + .select((users_collections::read_only, users_collections::hide_passwords)); + + // There's an edge case where a cipher can be in multiple collections + // with inconsistent access flags. For example, a cipher could be in + // one collection where the user has read-only access, but also in + // another collection where the user has read/write access. To handle + // this, we do a boolean OR of all values in each of the `read_only` + // and `hide_passwords` columns. This could ideally be done as part + // of the query, but Diesel doesn't support a max() or bool_or() + // function on booleans and this behavior isn't portable anyway. + if let Some(vec) = query.load::<(bool, bool)>(conn).ok() { + let mut read_only = false; + let mut hide_passwords = false; + for (ro, hp) in vec.iter() { + read_only |= ro; + hide_passwords |= hp; + } - Some((read_only, hide_passwords)) - } else { - // This cipher isn't in any collections accessible to the user. - None - } + Some((read_only, hide_passwords)) + } else { + // This cipher isn't in any collections accessible to the user. + None + } + }} } pub fn is_write_accessible_to_user(&self, user_uuid: &str, conn: &DbConn) -> bool { @@ -337,12 +340,14 @@ impl Cipher { // Returns whether this cipher is a favorite of the specified user. pub fn is_favorite(&self, user_uuid: &str, conn: &DbConn) -> bool { - let query = favorites::table - .filter(favorites::user_uuid.eq(user_uuid)) - .filter(favorites::cipher_uuid.eq(&self.uuid)) - .count(); - - query.first::(&**conn).ok().unwrap_or(0) != 0 + db_run!{ conn: { + let query = favorites::table + .filter(favorites::user_uuid.eq(user_uuid)) + .filter(favorites::cipher_uuid.eq(&self.uuid)) + .count(); + + query.first::(conn).ok().unwrap_or(0) != 0 + }} } // Updates whether this cipher is a favorite of the specified user. @@ -356,23 +361,27 @@ impl Cipher { match (old, new) { (false, true) => { User::update_uuid_revision(user_uuid, &conn); - diesel::insert_into(favorites::table) - .values(( - favorites::user_uuid.eq(user_uuid), - favorites::cipher_uuid.eq(&self.uuid), - )) - .execute(&**conn) - .map_res("Error adding favorite") + db_run!{ conn: { + diesel::insert_into(favorites::table) + .values(( + favorites::user_uuid.eq(user_uuid), + favorites::cipher_uuid.eq(&self.uuid), + )) + .execute(conn) + .map_res("Error adding favorite") + }} } (true, false) => { User::update_uuid_revision(user_uuid, &conn); - diesel::delete( - favorites::table - .filter(favorites::user_uuid.eq(user_uuid)) - .filter(favorites::cipher_uuid.eq(&self.uuid)) - ) - .execute(&**conn) - .map_res("Error removing favorite") + db_run!{ conn: { + diesel::delete( + favorites::table + .filter(favorites::user_uuid.eq(user_uuid)) + .filter(favorites::cipher_uuid.eq(&self.uuid)) + ) + .execute(conn) + .map_res("Error removing favorite") + }} } // Otherwise, the favorite status is already what it should be. _ => Ok(()) @@ -380,112 +389,131 @@ impl Cipher { } pub fn get_folder_uuid(&self, user_uuid: &str, conn: &DbConn) -> Option { - folders_ciphers::table - .inner_join(folders::table) - .filter(folders::user_uuid.eq(&user_uuid)) - .filter(folders_ciphers::cipher_uuid.eq(&self.uuid)) - .select(folders_ciphers::folder_uuid) - .first::(&**conn) - .ok() + db_run! {conn: { + folders_ciphers::table + .inner_join(folders::table) + .filter(folders::user_uuid.eq(&user_uuid)) + .filter(folders_ciphers::cipher_uuid.eq(&self.uuid)) + .select(folders_ciphers::folder_uuid) + .first::(conn) + .ok() + }} } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - ciphers::table - .filter(ciphers::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! {conn: { + ciphers::table + .filter(ciphers::uuid.eq(uuid)) + .first::(conn) + .ok() + .from_db() + }} } // Find all ciphers accessible to user pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - ciphers::table - .left_join(users_organizations::table.on( - ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()).and( - users_organizations::user_uuid.eq(user_uuid).and( - users_organizations::status.eq(UserOrgStatus::Confirmed as i32) - ) - ) - )) - .left_join(ciphers_collections::table.on( - ciphers::uuid.eq(ciphers_collections::cipher_uuid) - )) - .left_join(users_collections::table.on( - ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) - )) - .filter(ciphers::user_uuid.eq(user_uuid).or( // Cipher owner - users_organizations::access_all.eq(true).or( // access_all in Organization - users_organizations::atype.le(UserOrgType::Admin as i32).or( // Org admin or owner - users_collections::user_uuid.eq(user_uuid).and( // Access to Collection - users_organizations::status.eq(UserOrgStatus::Confirmed as i32) + db_run! {conn: { + ciphers::table + .left_join(users_organizations::table.on( + ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()).and( + users_organizations::user_uuid.eq(user_uuid).and( + users_organizations::status.eq(UserOrgStatus::Confirmed as i32) + ) ) - ) - ) - )) - .select(ciphers::all_columns) - .distinct() - .load::(&**conn).expect("Error loading ciphers") + )) + .left_join(ciphers_collections::table.on( + ciphers::uuid.eq(ciphers_collections::cipher_uuid) + )) + .left_join(users_collections::table.on( + ciphers_collections::collection_uuid.eq(users_collections::collection_uuid) + )) + .filter(ciphers::user_uuid.eq(user_uuid).or( // Cipher owner + users_organizations::access_all.eq(true).or( // access_all in Organization + users_organizations::atype.le(UserOrgType::Admin as i32).or( // Org admin or owner + users_collections::user_uuid.eq(user_uuid).and( // Access to Collection + users_organizations::status.eq(UserOrgStatus::Confirmed as i32) + ) + ) + ) + )) + .select(ciphers::all_columns) + .distinct() + .load::(conn).expect("Error loading ciphers").from_db() + }} } // Find all ciphers directly owned by user pub fn find_owned_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - ciphers::table - .filter(ciphers::user_uuid.eq(user_uuid)) - .load::(&**conn).expect("Error loading ciphers") + db_run! {conn: { + ciphers::table + .filter(ciphers::user_uuid.eq(user_uuid)) + .load::(conn).expect("Error loading ciphers").from_db() + }} } pub fn count_owned_by_user(user_uuid: &str, conn: &DbConn) -> i64 { - ciphers::table - .filter(ciphers::user_uuid.eq(user_uuid)) - .count() - .first::(&**conn) - .ok() - .unwrap_or(0) + db_run! {conn: { + ciphers::table + .filter(ciphers::user_uuid.eq(user_uuid)) + .count() + .first::(conn) + .ok() + .unwrap_or(0) + }} } pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec { - ciphers::table - .filter(ciphers::organization_uuid.eq(org_uuid)) - .load::(&**conn).expect("Error loading ciphers") + db_run! {conn: { + ciphers::table + .filter(ciphers::organization_uuid.eq(org_uuid)) + .load::(conn).expect("Error loading ciphers").from_db() + }} } pub fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 { - ciphers::table - .filter(ciphers::organization_uuid.eq(org_uuid)) - .count() - .first::(&**conn) - .ok() - .unwrap_or(0) + db_run! {conn: { + ciphers::table + .filter(ciphers::organization_uuid.eq(org_uuid)) + .count() + .first::(conn) + .ok() + .unwrap_or(0) + }} } pub fn find_by_folder(folder_uuid: &str, conn: &DbConn) -> Vec { - folders_ciphers::table.inner_join(ciphers::table) - .filter(folders_ciphers::folder_uuid.eq(folder_uuid)) - .select(ciphers::all_columns) - .load::(&**conn).expect("Error loading ciphers") + db_run! {conn: { + folders_ciphers::table.inner_join(ciphers::table) + .filter(folders_ciphers::folder_uuid.eq(folder_uuid)) + .select(ciphers::all_columns) + .load::(conn).expect("Error loading ciphers").from_db() + }} } pub fn get_collections(&self, user_id: &str, conn: &DbConn) -> Vec { - ciphers_collections::table - .inner_join(collections::table.on( - collections::uuid.eq(ciphers_collections::collection_uuid) - )) - .inner_join(users_organizations::table.on( - users_organizations::org_uuid.eq(collections::org_uuid).and( - users_organizations::user_uuid.eq(user_id) - ) - )) - .left_join(users_collections::table.on( - users_collections::collection_uuid.eq(ciphers_collections::collection_uuid).and( - users_collections::user_uuid.eq(user_id) - ) - )) - .filter(ciphers_collections::cipher_uuid.eq(&self.uuid)) - .filter(users_collections::user_uuid.eq(user_id).or( // User has access to collection - users_organizations::access_all.eq(true).or( // User has access all - users_organizations::atype.le(UserOrgType::Admin as i32) // User is admin or owner - ) - )) - .select(ciphers_collections::collection_uuid) - .load::(&**conn).unwrap_or_default() + db_run! {conn: { + ciphers_collections::table + .inner_join(collections::table.on( + collections::uuid.eq(ciphers_collections::collection_uuid) + )) + .inner_join(users_organizations::table.on( + users_organizations::org_uuid.eq(collections::org_uuid).and( + users_organizations::user_uuid.eq(user_id) + ) + )) + .left_join(users_collections::table.on( + users_collections::collection_uuid.eq(ciphers_collections::collection_uuid).and( + users_collections::user_uuid.eq(user_id) + ) + )) + .filter(ciphers_collections::cipher_uuid.eq(&self.uuid)) + .filter(users_collections::user_uuid.eq(user_id).or( // User has access to collection + users_organizations::access_all.eq(true).or( // User has access all + users_organizations::atype.le(UserOrgType::Admin as i32) // User is admin or owner + ) + )) + .select(ciphers_collections::collection_uuid) + .load::(conn).unwrap_or_default() + }} } } diff --git a/src/db/models/collection.rs b/src/db/models/collection.rs index 8e05fb27..18d1ff03 100644 --- a/src/db/models/collection.rs +++ b/src/db/models/collection.rs @@ -1,15 +1,39 @@ use serde_json::Value; -use super::{Organization, UserOrgStatus, UserOrgType, UserOrganization}; - -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "collections"] -#[belongs_to(Organization, foreign_key = "org_uuid")] -#[primary_key(uuid)] -pub struct Collection { - pub uuid: String, - pub org_uuid: String, - pub name: String, +use super::{Organization, UserOrgStatus, UserOrgType, UserOrganization, User, Cipher}; + +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "collections"] + #[belongs_to(Organization, foreign_key = "org_uuid")] + #[primary_key(uuid)] + pub struct Collection { + pub uuid: String, + pub org_uuid: String, + pub name: String, + } + + #[derive(Debug, Identifiable, Queryable, Insertable, Associations)] + #[table_name = "users_collections"] + #[belongs_to(User, foreign_key = "user_uuid")] + #[belongs_to(Collection, foreign_key = "collection_uuid")] + #[primary_key(user_uuid, collection_uuid)] + pub struct CollectionUser { + pub user_uuid: String, + pub collection_uuid: String, + pub read_only: bool, + pub hide_passwords: bool, + } + + #[derive(Debug, Identifiable, Queryable, Insertable, Associations)] + #[table_name = "ciphers_collections"] + #[belongs_to(Cipher, foreign_key = "cipher_uuid")] + #[belongs_to(Collection, foreign_key = "collection_uuid")] + #[primary_key(cipher_uuid, collection_uuid)] + pub struct CollectionCipher { + pub cipher_uuid: String, + pub collection_uuid: String, + } } /// Local methods @@ -33,36 +57,34 @@ impl Collection { } } -use crate::db::schema::*; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; /// Database methods impl Collection { - #[cfg(feature = "postgresql")] pub fn save(&self, conn: &DbConn) -> EmptyResult { self.update_users_revision(conn); - diesel::insert_into(collections::table) - .values(self) - .on_conflict(collections::uuid) - .do_update() - .set(self) - .execute(&**conn) - .map_res("Error saving collection") - } - - #[cfg(not(feature = "postgresql"))] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - self.update_users_revision(conn); - - diesel::replace_into(collections::table) - .values(self) - .execute(&**conn) - .map_res("Error saving collection") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(collections::table) + .values(CollectionDb::to_db(self)) + .execute(conn) + .map_res("Error saving collection") + } + postgresql { + let value = CollectionDb::to_db(self); + diesel::insert_into(collections::table) + .values(&value) + .on_conflict(collections::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving collection") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { @@ -70,9 +92,11 @@ impl Collection { CollectionCipher::delete_all_by_collection(&self.uuid, &conn)?; CollectionUser::delete_all_by_collection(&self.uuid, &conn)?; - diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error deleting collection") + db_run! { conn: { + diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error deleting collection") + }} } pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -91,33 +115,38 @@ impl Collection { } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - collections::table - .filter(collections::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + collections::table + .filter(collections::uuid.eq(uuid)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_user_uuid(user_uuid: &str, conn: &DbConn) -> Vec { - collections::table - .left_join(users_collections::table.on( - users_collections::collection_uuid.eq(collections::uuid).and( - users_collections::user_uuid.eq(user_uuid) - ) - )) - .left_join(users_organizations::table.on( - collections::org_uuid.eq(users_organizations::org_uuid).and( - users_organizations::user_uuid.eq(user_uuid) - ) - )) - .filter( - users_organizations::status.eq(UserOrgStatus::Confirmed as i32) - ) - .filter( - users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection - users_organizations::access_all.eq(true) // access_all in Organization + db_run! { conn: { + collections::table + .left_join(users_collections::table.on( + users_collections::collection_uuid.eq(collections::uuid).and( + users_collections::user_uuid.eq(user_uuid) + ) + )) + .left_join(users_organizations::table.on( + collections::org_uuid.eq(users_organizations::org_uuid).and( + users_organizations::user_uuid.eq(user_uuid) + ) + )) + .filter( + users_organizations::status.eq(UserOrgStatus::Confirmed as i32) ) - ).select(collections::all_columns) - .load::(&**conn).expect("Error loading collections") + .filter( + users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection + users_organizations::access_all.eq(true) // access_all in Organization + ) + ).select(collections::all_columns) + .load::(conn).expect("Error loading collections").from_db() + }} } pub fn find_by_organization_and_user_uuid(org_uuid: &str, user_uuid: &str, conn: &DbConn) -> Vec { @@ -128,42 +157,51 @@ impl Collection { } pub fn find_by_organization(org_uuid: &str, conn: &DbConn) -> Vec { - collections::table - .filter(collections::org_uuid.eq(org_uuid)) - .load::(&**conn) - .expect("Error loading collections") + db_run! { conn: { + collections::table + .filter(collections::org_uuid.eq(org_uuid)) + .load::(conn) + .expect("Error loading collections") + .from_db() + }} } pub fn find_by_uuid_and_org(uuid: &str, org_uuid: &str, conn: &DbConn) -> Option { - collections::table - .filter(collections::uuid.eq(uuid)) - .filter(collections::org_uuid.eq(org_uuid)) - .select(collections::all_columns) - .first::(&**conn) - .ok() + db_run! { conn: { + collections::table + .filter(collections::uuid.eq(uuid)) + .filter(collections::org_uuid.eq(org_uuid)) + .select(collections::all_columns) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_uuid_and_user(uuid: &str, user_uuid: &str, conn: &DbConn) -> Option { - collections::table - .left_join(users_collections::table.on( - users_collections::collection_uuid.eq(collections::uuid).and( - users_collections::user_uuid.eq(user_uuid) - ) - )) - .left_join(users_organizations::table.on( - collections::org_uuid.eq(users_organizations::org_uuid).and( - users_organizations::user_uuid.eq(user_uuid) - ) - )) - .filter(collections::uuid.eq(uuid)) - .filter( - users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection - users_organizations::access_all.eq(true).or( // access_all in Organization - users_organizations::atype.le(UserOrgType::Admin as i32) // Org admin or owner + db_run! { conn: { + collections::table + .left_join(users_collections::table.on( + users_collections::collection_uuid.eq(collections::uuid).and( + users_collections::user_uuid.eq(user_uuid) ) - ) - ).select(collections::all_columns) - .first::(&**conn).ok() + )) + .left_join(users_organizations::table.on( + collections::org_uuid.eq(users_organizations::org_uuid).and( + users_organizations::user_uuid.eq(user_uuid) + ) + )) + .filter(collections::uuid.eq(uuid)) + .filter( + users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection + users_organizations::access_all.eq(true).or( // access_all in Organization + users_organizations::atype.le(UserOrgType::Admin as i32) // Org admin or owner + ) + ) + ).select(collections::all_columns) + .first::(conn).ok() + .from_db() + }} } pub fn is_writable_by_user(&self, user_uuid: &str, conn: &DbConn) -> bool { @@ -173,110 +211,108 @@ impl Collection { if user_org.access_all { true } else { - users_collections::table - .inner_join(collections::table) - .filter(users_collections::collection_uuid.eq(&self.uuid)) - .filter(users_collections::user_uuid.eq(&user_uuid)) - .filter(users_collections::read_only.eq(false)) - .select(collections::all_columns) - .first::(&**conn) - .ok() - .is_some() // Read only or no access to collection + db_run! { conn: { + users_collections::table + .inner_join(collections::table) + .filter(users_collections::collection_uuid.eq(&self.uuid)) + .filter(users_collections::user_uuid.eq(&user_uuid)) + .filter(users_collections::read_only.eq(false)) + .select(collections::all_columns) + .first::(conn) + .ok() + .is_some() // Read only or no access to collection + }} } } } } } -use super::User; - -#[derive(Debug, Identifiable, Queryable, Insertable, Associations)] -#[table_name = "users_collections"] -#[belongs_to(User, foreign_key = "user_uuid")] -#[belongs_to(Collection, foreign_key = "collection_uuid")] -#[primary_key(user_uuid, collection_uuid)] -pub struct CollectionUser { - pub user_uuid: String, - pub collection_uuid: String, - pub read_only: bool, - pub hide_passwords: bool, -} - /// Database methods impl CollectionUser { pub fn find_by_organization_and_user_uuid(org_uuid: &str, user_uuid: &str, conn: &DbConn) -> Vec { - users_collections::table - .filter(users_collections::user_uuid.eq(user_uuid)) - .inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid))) - .filter(collections::org_uuid.eq(org_uuid)) - .select(users_collections::all_columns) - .load::(&**conn) - .expect("Error loading users_collections") - } - - #[cfg(feature = "postgresql")] - pub fn save(user_uuid: &str, collection_uuid: &str, read_only: bool, hide_passwords: bool, conn: &DbConn) -> EmptyResult { - User::update_uuid_revision(&user_uuid, conn); - - diesel::insert_into(users_collections::table) - .values(( - users_collections::user_uuid.eq(user_uuid), - users_collections::collection_uuid.eq(collection_uuid), - users_collections::read_only.eq(read_only), - users_collections::hide_passwords.eq(hide_passwords), - )) - .on_conflict((users_collections::user_uuid, users_collections::collection_uuid)) - .do_update() - .set(( - users_collections::read_only.eq(read_only), - users_collections::hide_passwords.eq(hide_passwords), - )) - .execute(&**conn) - .map_res("Error adding user to collection") + db_run! { conn: { + users_collections::table + .filter(users_collections::user_uuid.eq(user_uuid)) + .inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid))) + .filter(collections::org_uuid.eq(org_uuid)) + .select(users_collections::all_columns) + .load::(conn) + .expect("Error loading users_collections") + .from_db() + }} } - #[cfg(not(feature = "postgresql"))] pub fn save(user_uuid: &str, collection_uuid: &str, read_only: bool, hide_passwords: bool, conn: &DbConn) -> EmptyResult { User::update_uuid_revision(&user_uuid, conn); - diesel::replace_into(users_collections::table) - .values(( - users_collections::user_uuid.eq(user_uuid), - users_collections::collection_uuid.eq(collection_uuid), - users_collections::read_only.eq(read_only), - users_collections::hide_passwords.eq(hide_passwords), - )) - .execute(&**conn) - .map_res("Error adding user to collection") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(users_collections::table) + .values(( + users_collections::user_uuid.eq(user_uuid), + users_collections::collection_uuid.eq(collection_uuid), + users_collections::read_only.eq(read_only), + users_collections::hide_passwords.eq(hide_passwords), + )) + .execute(conn) + .map_res("Error adding user to collection") + } + postgresql { + diesel::insert_into(users_collections::table) + .values(( + users_collections::user_uuid.eq(user_uuid), + users_collections::collection_uuid.eq(collection_uuid), + users_collections::read_only.eq(read_only), + users_collections::hide_passwords.eq(hide_passwords), + )) + .on_conflict((users_collections::user_uuid, users_collections::collection_uuid)) + .do_update() + .set(( + users_collections::read_only.eq(read_only), + users_collections::hide_passwords.eq(hide_passwords), + )) + .execute(conn) + .map_res("Error adding user to collection") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { User::update_uuid_revision(&self.user_uuid, conn); - diesel::delete( - users_collections::table - .filter(users_collections::user_uuid.eq(&self.user_uuid)) - .filter(users_collections::collection_uuid.eq(&self.collection_uuid)), - ) - .execute(&**conn) - .map_res("Error removing user from collection") + db_run! { conn: { + diesel::delete( + users_collections::table + .filter(users_collections::user_uuid.eq(&self.user_uuid)) + .filter(users_collections::collection_uuid.eq(&self.collection_uuid)), + ) + .execute(conn) + .map_res("Error removing user from collection") + }} } pub fn find_by_collection(collection_uuid: &str, conn: &DbConn) -> Vec { - users_collections::table - .filter(users_collections::collection_uuid.eq(collection_uuid)) - .select(users_collections::all_columns) - .load::(&**conn) - .expect("Error loading users_collections") + db_run! { conn: { + users_collections::table + .filter(users_collections::collection_uuid.eq(collection_uuid)) + .select(users_collections::all_columns) + .load::(conn) + .expect("Error loading users_collections") + .from_db() + }} } pub fn find_by_collection_and_user(collection_uuid: &str, user_uuid: &str, conn: &DbConn) -> Option { - users_collections::table - .filter(users_collections::collection_uuid.eq(collection_uuid)) - .filter(users_collections::user_uuid.eq(user_uuid)) - .select(users_collections::all_columns) - .first::(&**conn) - .ok() + db_run! { conn: { + users_collections::table + .filter(users_collections::collection_uuid.eq(collection_uuid)) + .filter(users_collections::user_uuid.eq(user_uuid)) + .select(users_collections::all_columns) + .first::(conn) + .ok() + .from_db() + }} } pub fn delete_all_by_collection(collection_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -286,81 +322,81 @@ impl CollectionUser { User::update_uuid_revision(&collection.user_uuid, conn); }); - diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid))) - .execute(&**conn) - .map_res("Error deleting users from collection") + db_run! { conn: { + diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid))) + .execute(conn) + .map_res("Error deleting users from collection") + }} } pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { User::update_uuid_revision(&user_uuid, conn); - diesel::delete(users_collections::table.filter(users_collections::user_uuid.eq(user_uuid))) - .execute(&**conn) - .map_res("Error removing user from collections") + db_run! { conn: { + diesel::delete(users_collections::table.filter(users_collections::user_uuid.eq(user_uuid))) + .execute(conn) + .map_res("Error removing user from collections") + }} } } -use super::Cipher; - -#[derive(Debug, Identifiable, Queryable, Insertable, Associations)] -#[table_name = "ciphers_collections"] -#[belongs_to(Cipher, foreign_key = "cipher_uuid")] -#[belongs_to(Collection, foreign_key = "collection_uuid")] -#[primary_key(cipher_uuid, collection_uuid)] -pub struct CollectionCipher { - pub cipher_uuid: String, - pub collection_uuid: String, -} - /// Database methods impl CollectionCipher { - #[cfg(feature = "postgresql")] pub fn save(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult { Self::update_users_revision(&collection_uuid, conn); - diesel::insert_into(ciphers_collections::table) - .values(( - ciphers_collections::cipher_uuid.eq(cipher_uuid), - ciphers_collections::collection_uuid.eq(collection_uuid), - )) - .on_conflict((ciphers_collections::cipher_uuid, ciphers_collections::collection_uuid)) - .do_nothing() - .execute(&**conn) - .map_res("Error adding cipher to collection") - } - #[cfg(not(feature = "postgresql"))] - pub fn save(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult { - Self::update_users_revision(&collection_uuid, conn); - diesel::replace_into(ciphers_collections::table) - .values(( - ciphers_collections::cipher_uuid.eq(cipher_uuid), - ciphers_collections::collection_uuid.eq(collection_uuid), - )) - .execute(&**conn) - .map_res("Error adding cipher to collection") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(ciphers_collections::table) + .values(( + ciphers_collections::cipher_uuid.eq(cipher_uuid), + ciphers_collections::collection_uuid.eq(collection_uuid), + )) + .execute(conn) + .map_res("Error adding cipher to collection") + } + postgresql { + diesel::insert_into(ciphers_collections::table) + .values(( + ciphers_collections::cipher_uuid.eq(cipher_uuid), + ciphers_collections::collection_uuid.eq(collection_uuid), + )) + .on_conflict((ciphers_collections::cipher_uuid, ciphers_collections::collection_uuid)) + .do_nothing() + .execute(conn) + .map_res("Error adding cipher to collection") + } + } } pub fn delete(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult { Self::update_users_revision(&collection_uuid, conn); - diesel::delete( - ciphers_collections::table - .filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)) - .filter(ciphers_collections::collection_uuid.eq(collection_uuid)), - ) - .execute(&**conn) - .map_res("Error deleting cipher from collection") + + db_run! { conn: { + diesel::delete( + ciphers_collections::table + .filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)) + .filter(ciphers_collections::collection_uuid.eq(collection_uuid)), + ) + .execute(conn) + .map_res("Error deleting cipher from collection") + }} } pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))) - .execute(&**conn) - .map_res("Error removing cipher from collections") + db_run! { conn: { + diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))) + .execute(conn) + .map_res("Error removing cipher from collections") + }} } pub fn delete_all_by_collection(collection_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid))) - .execute(&**conn) - .map_res("Error removing ciphers from collection") + db_run! { conn: { + diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid))) + .execute(conn) + .map_res("Error removing ciphers from collection") + }} } pub fn update_users_revision(collection_uuid: &str, conn: &DbConn) { diff --git a/src/db/models/device.rs b/src/db/models/device.rs index 8b6bf2c3..6d6743ff 100644 --- a/src/db/models/device.rs +++ b/src/db/models/device.rs @@ -3,26 +3,28 @@ use chrono::{NaiveDateTime, Utc}; use super::User; use crate::CONFIG; -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "devices"] -#[changeset_options(treat_none_as_null="true")] -#[belongs_to(User, foreign_key = "user_uuid")] -#[primary_key(uuid)] -pub struct Device { - pub uuid: String, - pub created_at: NaiveDateTime, - pub updated_at: NaiveDateTime, - - pub user_uuid: String, - - pub name: String, - /// https://github.com/bitwarden/core/tree/master/src/Core/Enums - pub atype: i32, - pub push_token: Option, - - pub refresh_token: String, - - pub twofactor_remember: Option, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "devices"] + #[changeset_options(treat_none_as_null="true")] + #[belongs_to(User, foreign_key = "user_uuid")] + #[primary_key(uuid)] + pub struct Device { + pub uuid: String, + pub created_at: NaiveDateTime, + pub updated_at: NaiveDateTime, + + pub user_uuid: String, + + pub name: String, + // https://github.com/bitwarden/core/tree/master/src/Core/Enums + pub atype: i32, + pub push_token: Option, + + pub refresh_token: String, + + pub twofactor_remember: Option, + } } /// Local methods @@ -105,41 +107,39 @@ impl Device { } } -use crate::db::schema::devices; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; /// Database methods impl Device { - #[cfg(feature = "postgresql")] pub fn save(&mut self, conn: &DbConn) -> EmptyResult { self.updated_at = Utc::now().naive_utc(); - crate::util::retry( - || diesel::insert_into(devices::table).values(&*self).on_conflict(devices::uuid).do_update().set(&*self).execute(&**conn), - 10, - ) - .map_res("Error saving device") - } - - #[cfg(not(feature = "postgresql"))] - pub fn save(&mut self, conn: &DbConn) -> EmptyResult { - self.updated_at = Utc::now().naive_utc(); - - crate::util::retry( - || diesel::replace_into(devices::table).values(&*self).execute(&**conn), - 10, - ) - .map_res("Error saving device") + db_run! { conn: + sqlite, mysql { + crate::util::retry( + || diesel::replace_into(devices::table).values(DeviceDb::to_db(self)).execute(conn), + 10, + ).map_res("Error saving device") + } + postgresql { + let value = DeviceDb::to_db(self); + crate::util::retry( + || diesel::insert_into(devices::table).values(&value).on_conflict(devices::uuid).do_update().set(&value).execute(conn), + 10, + ).map_res("Error saving device") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { - diesel::delete(devices::table.filter(devices::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error removing device") + db_run! { conn: { + diesel::delete(devices::table.filter(devices::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error removing device") + }} } pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -150,23 +150,32 @@ impl Device { } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - devices::table - .filter(devices::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + devices::table + .filter(devices::uuid.eq(uuid)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option { - devices::table - .filter(devices::refresh_token.eq(refresh_token)) - .first::(&**conn) - .ok() + db_run! { conn: { + devices::table + .filter(devices::refresh_token.eq(refresh_token)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - devices::table - .filter(devices::user_uuid.eq(user_uuid)) - .load::(&**conn) - .expect("Error loading devices") + db_run! { conn: { + devices::table + .filter(devices::user_uuid.eq(user_uuid)) + .load::(conn) + .expect("Error loading devices") + .from_db() + }} } } diff --git a/src/db/models/folder.rs b/src/db/models/folder.rs index bea54473..5ff72b75 100644 --- a/src/db/models/folder.rs +++ b/src/db/models/folder.rs @@ -3,26 +3,28 @@ use serde_json::Value; use super::{Cipher, User}; -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "folders"] -#[belongs_to(User, foreign_key = "user_uuid")] -#[primary_key(uuid)] -pub struct Folder { - pub uuid: String, - pub created_at: NaiveDateTime, - pub updated_at: NaiveDateTime, - pub user_uuid: String, - pub name: String, -} - -#[derive(Debug, Identifiable, Queryable, Insertable, Associations)] -#[table_name = "folders_ciphers"] -#[belongs_to(Cipher, foreign_key = "cipher_uuid")] -#[belongs_to(Folder, foreign_key = "folder_uuid")] -#[primary_key(cipher_uuid, folder_uuid)] -pub struct FolderCipher { - pub cipher_uuid: String, - pub folder_uuid: String, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "folders"] + #[belongs_to(User, foreign_key = "user_uuid")] + #[primary_key(uuid)] + pub struct Folder { + pub uuid: String, + pub created_at: NaiveDateTime, + pub updated_at: NaiveDateTime, + pub user_uuid: String, + pub name: String, + } + + #[derive(Debug, Identifiable, Queryable, Insertable, Associations)] + #[table_name = "folders_ciphers"] + #[belongs_to(Cipher, foreign_key = "cipher_uuid")] + #[belongs_to(Folder, foreign_key = "folder_uuid")] + #[primary_key(cipher_uuid, folder_uuid)] + pub struct FolderCipher { + pub cipher_uuid: String, + pub folder_uuid: String, + } } /// Local methods @@ -61,47 +63,47 @@ impl FolderCipher { } } -use crate::db::schema::{folders, folders_ciphers}; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; /// Database methods impl Folder { - #[cfg(feature = "postgresql")] - pub fn save(&mut self, conn: &DbConn) -> EmptyResult { - User::update_uuid_revision(&self.user_uuid, conn); - self.updated_at = Utc::now().naive_utc(); - - diesel::insert_into(folders::table) - .values(&*self) - .on_conflict(folders::uuid) - .do_update() - .set(&*self) - .execute(&**conn) - .map_res("Error saving folder") - } - - #[cfg(not(feature = "postgresql"))] pub fn save(&mut self, conn: &DbConn) -> EmptyResult { User::update_uuid_revision(&self.user_uuid, conn); self.updated_at = Utc::now().naive_utc(); - diesel::replace_into(folders::table) - .values(&*self) - .execute(&**conn) - .map_res("Error saving folder") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(folders::table) + .values(FolderDb::to_db(self)) + .execute(conn) + .map_res("Error saving folder") + } + postgresql { + let value = FolderDb::to_db(self); + diesel::insert_into(folders::table) + .values(&value) + .on_conflict(folders::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving folder") + } + } } pub fn delete(&self, conn: &DbConn) -> EmptyResult { User::update_uuid_revision(&self.user_uuid, conn); FolderCipher::delete_all_by_folder(&self.uuid, &conn)?; - diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid))) - .execute(&**conn) - .map_res("Error deleting folder") + + db_run! { conn: { + diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid))) + .execute(conn) + .map_res("Error deleting folder") + }} } pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -112,73 +114,92 @@ impl Folder { } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - folders::table - .filter(folders::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + folders::table + .filter(folders::uuid.eq(uuid)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - folders::table - .filter(folders::user_uuid.eq(user_uuid)) - .load::(&**conn) - .expect("Error loading folders") + db_run! { conn: { + folders::table + .filter(folders::user_uuid.eq(user_uuid)) + .load::(conn) + .expect("Error loading folders") + .from_db() + }} } } impl FolderCipher { - #[cfg(feature = "postgresql")] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - diesel::insert_into(folders_ciphers::table) - .values(&*self) - .on_conflict((folders_ciphers::cipher_uuid, folders_ciphers::folder_uuid)) - .do_nothing() - .execute(&**conn) - .map_res("Error adding cipher to folder") - } - - #[cfg(not(feature = "postgresql"))] pub fn save(&self, conn: &DbConn) -> EmptyResult { - diesel::replace_into(folders_ciphers::table) - .values(&*self) - .execute(&**conn) - .map_res("Error adding cipher to folder") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(folders_ciphers::table) + .values(FolderCipherDb::to_db(self)) + .execute(conn) + .map_res("Error adding cipher to folder") + } + postgresql { + diesel::insert_into(folders_ciphers::table) + .values(FolderCipherDb::to_db(self)) + .on_conflict((folders_ciphers::cipher_uuid, folders_ciphers::folder_uuid)) + .do_nothing() + .execute(conn) + .map_res("Error adding cipher to folder") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { - diesel::delete( - folders_ciphers::table - .filter(folders_ciphers::cipher_uuid.eq(self.cipher_uuid)) - .filter(folders_ciphers::folder_uuid.eq(self.folder_uuid)), - ) - .execute(&**conn) - .map_res("Error removing cipher from folder") + db_run! { conn: { + diesel::delete( + folders_ciphers::table + .filter(folders_ciphers::cipher_uuid.eq(self.cipher_uuid)) + .filter(folders_ciphers::folder_uuid.eq(self.folder_uuid)), + ) + .execute(conn) + .map_res("Error removing cipher from folder") + }} } pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))) - .execute(&**conn) - .map_res("Error removing cipher from folders") + db_run! { conn: { + diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))) + .execute(conn) + .map_res("Error removing cipher from folders") + }} } pub fn delete_all_by_folder(folder_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid))) - .execute(&**conn) - .map_res("Error removing ciphers from folder") + db_run! { conn: { + diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid))) + .execute(conn) + .map_res("Error removing ciphers from folder") + }} } pub fn find_by_folder_and_cipher(folder_uuid: &str, cipher_uuid: &str, conn: &DbConn) -> Option { - folders_ciphers::table - .filter(folders_ciphers::folder_uuid.eq(folder_uuid)) - .filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + folders_ciphers::table + .filter(folders_ciphers::folder_uuid.eq(folder_uuid)) + .filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_folder(folder_uuid: &str, conn: &DbConn) -> Vec { - folders_ciphers::table - .filter(folders_ciphers::folder_uuid.eq(folder_uuid)) - .load::(&**conn) - .expect("Error loading folders") + db_run! { conn: { + folders_ciphers::table + .filter(folders_ciphers::folder_uuid.eq(folder_uuid)) + .load::(conn) + .expect("Error loading folders") + .from_db() + }} } } diff --git a/src/db/models/org_policy.rs b/src/db/models/org_policy.rs index 7963a0f8..a58ca53f 100644 --- a/src/db/models/org_policy.rs +++ b/src/db/models/org_policy.rs @@ -1,23 +1,23 @@ -use diesel::prelude::*; use serde_json::Value; use crate::api::EmptyResult; -use crate::db::schema::org_policies; use crate::db::DbConn; use crate::error::MapResult; use super::Organization; -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "org_policies"] -#[belongs_to(Organization, foreign_key = "org_uuid")] -#[primary_key(uuid)] -pub struct OrgPolicy { - pub uuid: String, - pub org_uuid: String, - pub atype: i32, - pub enabled: bool, - pub data: String, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "org_policies"] + #[belongs_to(Organization, foreign_key = "org_uuid")] + #[primary_key(uuid)] + pub struct OrgPolicy { + pub uuid: String, + pub org_uuid: String, + pub atype: i32, + pub enabled: bool, + pub data: String, + } } #[allow(dead_code)] @@ -55,87 +55,105 @@ impl OrgPolicy { /// Database methods impl OrgPolicy { - #[cfg(feature = "postgresql")] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - // We need to make sure we're not going to violate the unique constraint on org_uuid and atype. - // This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does - // not support multiple constraints on ON CONFLICT clauses. - diesel::delete( - org_policies::table - .filter(org_policies::org_uuid.eq(&self.org_uuid)) - .filter(org_policies::atype.eq(&self.atype)), - ) - .execute(&**conn) - .map_res("Error deleting org_policy for insert")?; - - diesel::insert_into(org_policies::table) - .values(self) - .on_conflict(org_policies::uuid) - .do_update() - .set(self) - .execute(&**conn) - .map_res("Error saving org_policy") - } - - #[cfg(not(feature = "postgresql"))] pub fn save(&self, conn: &DbConn) -> EmptyResult { - diesel::replace_into(org_policies::table) - .values(&*self) - .execute(&**conn) - .map_res("Error saving org_policy") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(org_policies::table) + .values(OrgPolicyDb::to_db(self)) + .execute(conn) + .map_res("Error saving org_policy") + } + postgresql { + let value = OrgPolicyDb::to_db(self); + // We need to make sure we're not going to violate the unique constraint on org_uuid and atype. + // This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does + // not support multiple constraints on ON CONFLICT clauses. + diesel::delete( + org_policies::table + .filter(org_policies::org_uuid.eq(&self.org_uuid)) + .filter(org_policies::atype.eq(&self.atype)), + ) + .execute(conn) + .map_res("Error deleting org_policy for insert")?; + + diesel::insert_into(org_policies::table) + .values(&value) + .on_conflict(org_policies::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving org_policy") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { - diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error deleting org_policy") + db_run! { conn: { + diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error deleting org_policy") + }} } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - org_policies::table - .filter(org_policies::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + org_policies::table + .filter(org_policies::uuid.eq(uuid)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec { - org_policies::table - .filter(org_policies::org_uuid.eq(org_uuid)) - .load::(&**conn) - .expect("Error loading org_policy") + db_run! { conn: { + org_policies::table + .filter(org_policies::org_uuid.eq(org_uuid)) + .load::(conn) + .expect("Error loading org_policy") + .from_db() + }} } pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - use crate::db::schema::users_organizations; - - org_policies::table - .left_join( - users_organizations::table.on( - users_organizations::org_uuid.eq(org_policies::org_uuid) - .and(users_organizations::user_uuid.eq(user_uuid))) - ) - .select(org_policies::all_columns) - .load::(&**conn) - .expect("Error loading org_policy") + db_run! { conn: { + org_policies::table + .left_join( + users_organizations::table.on( + users_organizations::org_uuid.eq(org_policies::org_uuid) + .and(users_organizations::user_uuid.eq(user_uuid))) + ) + .select(org_policies::all_columns) + .load::(conn) + .expect("Error loading org_policy") + .from_db() + }} } pub fn find_by_org_and_type(org_uuid: &str, atype: i32, conn: &DbConn) -> Option { - org_policies::table - .filter(org_policies::org_uuid.eq(org_uuid)) - .filter(org_policies::atype.eq(atype)) - .first::(&**conn) - .ok() + db_run! { conn: { + org_policies::table + .filter(org_policies::org_uuid.eq(org_uuid)) + .filter(org_policies::atype.eq(atype)) + .first::(conn) + .ok() + .from_db() + }} } pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid))) - .execute(&**conn) - .map_res("Error deleting org_policy") + db_run! { conn: { + diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid))) + .execute(conn) + .map_res("Error deleting org_policy") + }} } /*pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) - .execute(&**conn) - .map_res("Error deleting twofactors") + db_run! { conn: { + diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) + .execute(conn) + .map_res("Error deleting twofactors") + }} }*/ } diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs index 42c4cd9b..aa9d9494 100644 --- a/src/db/models/organization.rs +++ b/src/db/models/organization.rs @@ -4,27 +4,29 @@ use num_traits::FromPrimitive; use super::{CollectionUser, User, OrgPolicy}; -#[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)] -#[table_name = "organizations"] -#[primary_key(uuid)] -pub struct Organization { - pub uuid: String, - pub name: String, - pub billing_email: String, -} - -#[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)] -#[table_name = "users_organizations"] -#[primary_key(uuid)] -pub struct UserOrganization { - pub uuid: String, - pub user_uuid: String, - pub org_uuid: String, - - pub access_all: bool, - pub akey: String, - pub status: i32, - pub atype: i32, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)] + #[table_name = "organizations"] + #[primary_key(uuid)] + pub struct Organization { + pub uuid: String, + pub name: String, + pub billing_email: String, + } + + #[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)] + #[table_name = "users_organizations"] + #[primary_key(uuid)] + pub struct UserOrganization { + pub uuid: String, + pub user_uuid: String, + pub org_uuid: String, + + pub access_all: bool, + pub akey: String, + pub status: i32, + pub atype: i32, + } } pub enum UserOrgStatus { @@ -196,16 +198,13 @@ impl UserOrganization { } } -use crate::db::schema::{ciphers_collections, organizations, users_collections, users_organizations}; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; /// Database methods impl Organization { - #[cfg(feature = "postgresql")] pub fn save(&self, conn: &DbConn) -> EmptyResult { UserOrganization::find_by_org(&self.uuid, conn) .iter() @@ -213,27 +212,24 @@ impl Organization { User::update_uuid_revision(&user_org.user_uuid, conn); }); - diesel::insert_into(organizations::table) - .values(self) - .on_conflict(organizations::uuid) - .do_update() - .set(self) - .execute(&**conn) - .map_res("Error saving organization") - } - - #[cfg(not(feature = "postgresql"))] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - UserOrganization::find_by_org(&self.uuid, conn) - .iter() - .for_each(|user_org| { - User::update_uuid_revision(&user_org.user_uuid, conn); - }); - - diesel::replace_into(organizations::table) - .values(self) - .execute(&**conn) - .map_res("Error saving organization") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(organizations::table) + .values(OrganizationDb::to_db(self)) + .execute(conn) + .map_res("Error saving organization") + } + postgresql { + let value = OrganizationDb::to_db(self); + diesel::insert_into(organizations::table) + .values(&value) + .on_conflict(organizations::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving organization") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { @@ -244,20 +240,27 @@ impl Organization { UserOrganization::delete_all_by_organization(&self.uuid, &conn)?; OrgPolicy::delete_all_by_organization(&self.uuid, &conn)?; - diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error saving organization") + + db_run! { conn: { + diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error saving organization") + }} } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - organizations::table - .filter(organizations::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + organizations::table + .filter(organizations::uuid.eq(uuid)) + .first::(conn) + .ok().from_db() + }} } pub fn get_all(conn: &DbConn) -> Vec { - organizations::table.load::(&**conn).expect("Error loading organizations") + db_run! { conn: { + organizations::table.load::(conn).expect("Error loading organizations").from_db() + }} } } @@ -345,28 +348,27 @@ impl UserOrganization { "Object": "organizationUserDetails", }) } - - #[cfg(feature = "postgresql")] pub fn save(&self, conn: &DbConn) -> EmptyResult { User::update_uuid_revision(&self.user_uuid, conn); - diesel::insert_into(users_organizations::table) - .values(self) - .on_conflict(users_organizations::uuid) - .do_update() - .set(self) - .execute(&**conn) - .map_res("Error adding user to organization") - } - - #[cfg(not(feature = "postgresql"))] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - User::update_uuid_revision(&self.user_uuid, conn); - - diesel::replace_into(users_organizations::table) - .values(self) - .execute(&**conn) - .map_res("Error adding user to organization") + db_run! { conn: + sqlite, mysql { + diesel::replace_into(users_organizations::table) + .values(UserOrganizationDb::to_db(self)) + .execute(conn) + .map_res("Error adding user to organization") + } + postgresql { + let value = UserOrganizationDb::to_db(self); + diesel::insert_into(users_organizations::table) + .values(&value) + .on_conflict(users_organizations::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error adding user to organization") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { @@ -374,9 +376,11 @@ impl UserOrganization { CollectionUser::delete_all_by_user(&self.user_uuid, &conn)?; - diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error removing user from organization") + db_run! { conn: { + diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error removing user from organization") + }} } pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { @@ -403,107 +407,129 @@ impl UserOrganization { } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - users_organizations::table - .filter(users_organizations::uuid.eq(uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + users_organizations::table + .filter(users_organizations::uuid.eq(uuid)) + .first::(conn) + .ok().from_db() + }} } pub fn find_by_uuid_and_org(uuid: &str, org_uuid: &str, conn: &DbConn) -> Option { - users_organizations::table - .filter(users_organizations::uuid.eq(uuid)) - .filter(users_organizations::org_uuid.eq(org_uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + users_organizations::table + .filter(users_organizations::uuid.eq(uuid)) + .filter(users_organizations::org_uuid.eq(org_uuid)) + .first::(conn) + .ok().from_db() + }} } pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::user_uuid.eq(user_uuid)) - .filter(users_organizations::status.eq(UserOrgStatus::Confirmed as i32)) - .load::(&**conn) - .unwrap_or_default() + db_run! { conn: { + users_organizations::table + .filter(users_organizations::user_uuid.eq(user_uuid)) + .filter(users_organizations::status.eq(UserOrgStatus::Confirmed as i32)) + .load::(conn) + .unwrap_or_default().from_db() + }} } pub fn find_invited_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::user_uuid.eq(user_uuid)) - .filter(users_organizations::status.eq(UserOrgStatus::Invited as i32)) - .load::(&**conn) - .unwrap_or_default() + db_run! { conn: { + users_organizations::table + .filter(users_organizations::user_uuid.eq(user_uuid)) + .filter(users_organizations::status.eq(UserOrgStatus::Invited as i32)) + .load::(conn) + .unwrap_or_default().from_db() + }} } pub fn find_any_state_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::user_uuid.eq(user_uuid)) - .load::(&**conn) - .unwrap_or_default() + db_run! { conn: { + users_organizations::table + .filter(users_organizations::user_uuid.eq(user_uuid)) + .load::(conn) + .unwrap_or_default().from_db() + }} } pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::org_uuid.eq(org_uuid)) - .load::(&**conn) - .expect("Error loading user organizations") + db_run! { conn: { + users_organizations::table + .filter(users_organizations::org_uuid.eq(org_uuid)) + .load::(conn) + .expect("Error loading user organizations").from_db() + }} } pub fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 { - users_organizations::table - .filter(users_organizations::org_uuid.eq(org_uuid)) - .count() - .first::(&**conn) - .ok() - .unwrap_or(0) + db_run! { conn: { + users_organizations::table + .filter(users_organizations::org_uuid.eq(org_uuid)) + .count() + .first::(conn) + .ok() + .unwrap_or(0) + }} } pub fn find_by_org_and_type(org_uuid: &str, atype: i32, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::org_uuid.eq(org_uuid)) - .filter(users_organizations::atype.eq(atype)) - .load::(&**conn) - .expect("Error loading user organizations") + db_run! { conn: { + users_organizations::table + .filter(users_organizations::org_uuid.eq(org_uuid)) + .filter(users_organizations::atype.eq(atype)) + .load::(conn) + .expect("Error loading user organizations").from_db() + }} } pub fn find_by_user_and_org(user_uuid: &str, org_uuid: &str, conn: &DbConn) -> Option { - users_organizations::table - .filter(users_organizations::user_uuid.eq(user_uuid)) - .filter(users_organizations::org_uuid.eq(org_uuid)) - .first::(&**conn) - .ok() + db_run! { conn: { + users_organizations::table + .filter(users_organizations::user_uuid.eq(user_uuid)) + .filter(users_organizations::org_uuid.eq(org_uuid)) + .first::(conn) + .ok().from_db() + }} } pub fn find_by_cipher_and_org(cipher_uuid: &str, org_uuid: &str, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::org_uuid.eq(org_uuid)) - .left_join(users_collections::table.on( - users_collections::user_uuid.eq(users_organizations::user_uuid) - )) - .left_join(ciphers_collections::table.on( - ciphers_collections::collection_uuid.eq(users_collections::collection_uuid).and( - ciphers_collections::cipher_uuid.eq(&cipher_uuid) - ) - )) - .filter( - users_organizations::access_all.eq(true).or( // AccessAll.. - ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection with cipher + db_run! { conn: { + users_organizations::table + .filter(users_organizations::org_uuid.eq(org_uuid)) + .left_join(users_collections::table.on( + users_collections::user_uuid.eq(users_organizations::user_uuid) + )) + .left_join(ciphers_collections::table.on( + ciphers_collections::collection_uuid.eq(users_collections::collection_uuid).and( + ciphers_collections::cipher_uuid.eq(&cipher_uuid) + ) + )) + .filter( + users_organizations::access_all.eq(true).or( // AccessAll.. + ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection with cipher + ) ) - ) - .select(users_organizations::all_columns) - .load::(&**conn).expect("Error loading user organizations") + .select(users_organizations::all_columns) + .load::(conn).expect("Error loading user organizations").from_db() + }} } pub fn find_by_collection_and_org(collection_uuid: &str, org_uuid: &str, conn: &DbConn) -> Vec { - users_organizations::table - .filter(users_organizations::org_uuid.eq(org_uuid)) - .left_join(users_collections::table.on( - users_collections::user_uuid.eq(users_organizations::user_uuid) - )) - .filter( - users_organizations::access_all.eq(true).or( // AccessAll.. - users_collections::collection_uuid.eq(&collection_uuid) // ..or access to collection with cipher + db_run! { conn: { + users_organizations::table + .filter(users_organizations::org_uuid.eq(org_uuid)) + .left_join(users_collections::table.on( + users_collections::user_uuid.eq(users_organizations::user_uuid) + )) + .filter( + users_organizations::access_all.eq(true).or( // AccessAll.. + users_collections::collection_uuid.eq(&collection_uuid) // ..or access to collection with cipher + ) ) - ) - .select(users_organizations::all_columns) - .load::(&**conn).expect("Error loading user organizations") + .select(users_organizations::all_columns) + .load::(conn).expect("Error loading user organizations").from_db() + }} } } diff --git a/src/db/models/two_factor.rs b/src/db/models/two_factor.rs index 2504baab..a1b925cd 100644 --- a/src/db/models/two_factor.rs +++ b/src/db/models/two_factor.rs @@ -1,24 +1,24 @@ -use diesel::prelude::*; use serde_json::Value; use crate::api::EmptyResult; -use crate::db::schema::twofactor; use crate::db::DbConn; use crate::error::MapResult; use super::User; -#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] -#[table_name = "twofactor"] -#[belongs_to(User, foreign_key = "user_uuid")] -#[primary_key(uuid)] -pub struct TwoFactor { - pub uuid: String, - pub user_uuid: String, - pub atype: i32, - pub enabled: bool, - pub data: String, - pub last_used: i32, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] + #[table_name = "twofactor"] + #[belongs_to(User, foreign_key = "user_uuid")] + #[primary_key(uuid)] + pub struct TwoFactor { + pub uuid: String, + pub user_uuid: String, + pub atype: i32, + pub enabled: bool, + pub data: String, + pub last_used: i32, + } } #[allow(dead_code)] @@ -70,57 +70,69 @@ impl TwoFactor { /// Database methods impl TwoFactor { - #[cfg(feature = "postgresql")] pub fn save(&self, conn: &DbConn) -> EmptyResult { - // We need to make sure we're not going to violate the unique constraint on user_uuid and atype. - // This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does - // not support multiple constraints on ON CONFLICT clauses. - diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(&self.user_uuid)).filter(twofactor::atype.eq(&self.atype))) - .execute(&**conn) - .map_res("Error deleting twofactor for insert")?; - - diesel::insert_into(twofactor::table) - .values(self) - .on_conflict(twofactor::uuid) - .do_update() - .set(self) - .execute(&**conn) - .map_res("Error saving twofactor") - } + db_run! { conn: + sqlite, mysql { + diesel::replace_into(twofactor::table) + .values(TwoFactorDb::to_db(self)) + .execute(conn) + .map_res("Error saving twofactor") + } + postgresql { + let value = TwoFactorDb::to_db(self); + // We need to make sure we're not going to violate the unique constraint on user_uuid and atype. + // This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does + // not support multiple constraints on ON CONFLICT clauses. + diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(&self.user_uuid)).filter(twofactor::atype.eq(&self.atype))) + .execute(conn) + .map_res("Error deleting twofactor for insert")?; - #[cfg(not(feature = "postgresql"))] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - diesel::replace_into(twofactor::table) - .values(self) - .execute(&**conn) - .map_res("Error saving twofactor") + diesel::insert_into(twofactor::table) + .values(&value) + .on_conflict(twofactor::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving twofactor") + } + } } pub fn delete(self, conn: &DbConn) -> EmptyResult { - diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error deleting twofactor") + db_run! { conn: { + diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error deleting twofactor") + }} } pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec { - twofactor::table - .filter(twofactor::user_uuid.eq(user_uuid)) - .filter(twofactor::atype.lt(1000)) // Filter implementation types - .load::(&**conn) - .expect("Error loading twofactor") + db_run! { conn: { + twofactor::table + .filter(twofactor::user_uuid.eq(user_uuid)) + .filter(twofactor::atype.lt(1000)) // Filter implementation types + .load::(conn) + .expect("Error loading twofactor") + .from_db() + }} } pub fn find_by_user_and_type(user_uuid: &str, atype: i32, conn: &DbConn) -> Option { - twofactor::table - .filter(twofactor::user_uuid.eq(user_uuid)) - .filter(twofactor::atype.eq(atype)) - .first::(&**conn) - .ok() + db_run! { conn: { + twofactor::table + .filter(twofactor::user_uuid.eq(user_uuid)) + .filter(twofactor::atype.eq(atype)) + .first::(conn) + .ok() + .from_db() + }} } pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { - diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) - .execute(&**conn) - .map_res("Error deleting twofactors") + db_run! { conn: { + diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) + .execute(conn) + .map_res("Error deleting twofactors") + }} } } diff --git a/src/db/models/user.rs b/src/db/models/user.rs index 556a3fdb..1965bf2b 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -4,43 +4,53 @@ use serde_json::Value; use crate::crypto; use crate::CONFIG; -#[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)] -#[table_name = "users"] -#[changeset_options(treat_none_as_null="true")] -#[primary_key(uuid)] -pub struct User { - pub uuid: String, - pub created_at: NaiveDateTime, - pub updated_at: NaiveDateTime, - pub verified_at: Option, - pub last_verifying_at: Option, - pub login_verify_count: i32, - - pub email: String, - pub email_new: Option, - pub email_new_token: Option, - pub name: String, - - pub password_hash: Vec, - pub salt: Vec, - pub password_iterations: i32, - pub password_hint: Option, - - pub akey: String, - pub private_key: Option, - pub public_key: Option, - - #[column_name = "totp_secret"] - _totp_secret: Option, - pub totp_recover: Option, - - pub security_stamp: String, - - pub equivalent_domains: String, - pub excluded_globals: String, - - pub client_kdf_type: i32, - pub client_kdf_iter: i32, +db_object! { + #[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)] + #[table_name = "users"] + #[changeset_options(treat_none_as_null="true")] + #[primary_key(uuid)] + pub struct User { + pub uuid: String, + pub created_at: NaiveDateTime, + pub updated_at: NaiveDateTime, + pub verified_at: Option, + pub last_verifying_at: Option, + pub login_verify_count: i32, + + pub email: String, + pub email_new: Option, + pub email_new_token: Option, + pub name: String, + + pub password_hash: Vec, + pub salt: Vec, + pub password_iterations: i32, + pub password_hint: Option, + + pub akey: String, + pub private_key: Option, + pub public_key: Option, + + #[column_name = "totp_secret"] // Note, this is only added to the UserDb structs, not to User + _totp_secret: Option, + pub totp_recover: Option, + + pub security_stamp: String, + + pub equivalent_domains: String, + pub excluded_globals: String, + + pub client_kdf_type: i32, + pub client_kdf_iter: i32, + } + + + #[derive(Debug, Identifiable, Queryable, Insertable)] + #[table_name = "invitations"] + #[primary_key(email)] + pub struct Invitation { + pub email: String, + } } enum UserStatus { @@ -119,9 +129,7 @@ impl User { } use super::{Cipher, Device, Folder, TwoFactor, UserOrgType, UserOrganization}; -use crate::db::schema::{invitations, users}; use crate::db::DbConn; -use diesel::prelude::*; use crate::api::EmptyResult; use crate::error::MapResult; @@ -158,7 +166,6 @@ impl User { }) } - #[cfg(feature = "postgresql")] pub fn save(&mut self, conn: &DbConn) -> EmptyResult { if self.email.trim().is_empty() { err!("User email can't be empty") @@ -166,49 +173,48 @@ impl User { self.updated_at = Utc::now().naive_utc(); - diesel::insert_into(users::table) // Insert or update - .values(&*self) - .on_conflict(users::uuid) - .do_update() - .set(&*self) - .execute(&**conn) - .map_res("Error saving user") - } - - #[cfg(not(feature = "postgresql"))] - pub fn save(&mut self, conn: &DbConn) -> EmptyResult { - if self.email.trim().is_empty() { - err!("User email can't be empty") + db_run! {conn: + sqlite, mysql { + diesel::replace_into(users::table) // Insert or update + .values(&UserDb::to_db(self)) + .execute(conn) + .map_res("Error saving user") + } + postgresql { + let value = UserDb::to_db(self); + diesel::insert_into(users::table) // Insert or update + .values(&value) + .on_conflict(users::uuid) + .do_update() + .set(&value) + .execute(conn) + .map_res("Error saving user") + } } - - self.updated_at = Utc::now().naive_utc(); - - diesel::replace_into(users::table) // Insert or update - .values(&*self) - .execute(&**conn) - .map_res("Error saving user") } pub fn delete(self, conn: &DbConn) -> EmptyResult { - for user_org in UserOrganization::find_by_user(&self.uuid, &*conn) { + for user_org in UserOrganization::find_by_user(&self.uuid, conn) { if user_org.atype == UserOrgType::Owner { let owner_type = UserOrgType::Owner as i32; - if UserOrganization::find_by_org_and_type(&user_org.org_uuid, owner_type, &conn).len() <= 1 { + if UserOrganization::find_by_org_and_type(&user_org.org_uuid, owner_type, conn).len() <= 1 { err!("Can't delete last owner") } } } - UserOrganization::delete_all_by_user(&self.uuid, &*conn)?; - Cipher::delete_all_by_user(&self.uuid, &*conn)?; - Folder::delete_all_by_user(&self.uuid, &*conn)?; - Device::delete_all_by_user(&self.uuid, &*conn)?; - TwoFactor::delete_all_by_user(&self.uuid, &*conn)?; - Invitation::take(&self.email, &*conn); // Delete invitation if any - - diesel::delete(users::table.filter(users::uuid.eq(self.uuid))) - .execute(&**conn) - .map_res("Error deleting user") + UserOrganization::delete_all_by_user(&self.uuid, conn)?; + Cipher::delete_all_by_user(&self.uuid, conn)?; + Folder::delete_all_by_user(&self.uuid, conn)?; + Device::delete_all_by_user(&self.uuid, conn)?; + TwoFactor::delete_all_by_user(&self.uuid, conn)?; + Invitation::take(&self.email, conn); // Delete invitation if any + + db_run! {conn: { + diesel::delete(users::table.filter(users::uuid.eq(self.uuid))) + .execute(conn) + .map_res("Error deleting user") + }} } pub fn update_uuid_revision(uuid: &str, conn: &DbConn) { @@ -220,15 +226,14 @@ impl User { pub fn update_all_revisions(conn: &DbConn) -> EmptyResult { let updated_at = Utc::now().naive_utc(); - crate::util::retry( - || { + db_run! {conn: { + crate::util::retry(|| { diesel::update(users::table) .set(users::updated_at.eq(updated_at)) - .execute(&**conn) - }, - 10, - ) - .map_res("Error updating revision date for all users") + .execute(conn) + }, 10) + .map_res("Error updating revision date for all users") + }} } pub fn update_revision(&mut self, conn: &DbConn) -> EmptyResult { @@ -238,84 +243,85 @@ impl User { } fn _update_revision(uuid: &str, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { - crate::util::retry( - || { + db_run! {conn: { + crate::util::retry(|| { diesel::update(users::table.filter(users::uuid.eq(uuid))) .set(users::updated_at.eq(date)) - .execute(&**conn) - }, - 10, - ) - .map_res("Error updating user revision") + .execute(conn) + }, 10) + .map_res("Error updating user revision") + }} } pub fn find_by_mail(mail: &str, conn: &DbConn) -> Option { let lower_mail = mail.to_lowercase(); - users::table - .filter(users::email.eq(lower_mail)) - .first::(&**conn) - .ok() + db_run! {conn: { + users::table + .filter(users::email.eq(lower_mail)) + .first::(conn) + .ok() + .from_db() + }} } pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option { - users::table.filter(users::uuid.eq(uuid)).first::(&**conn).ok() + db_run! {conn: { + users::table.filter(users::uuid.eq(uuid)).first::(conn).ok().from_db() + }} } pub fn get_all(conn: &DbConn) -> Vec { - users::table.load::(&**conn).expect("Error loading users") + db_run! {conn: { + users::table.load::(conn).expect("Error loading users").from_db() + }} } } -#[derive(Debug, Identifiable, Queryable, Insertable)] -#[table_name = "invitations"] -#[primary_key(email)] -pub struct Invitation { - pub email: String, -} - impl Invitation { pub const fn new(email: String) -> Self { Self { email } } - #[cfg(feature = "postgresql")] pub fn save(&self, conn: &DbConn) -> EmptyResult { if self.email.trim().is_empty() { err!("Invitation email can't be empty") } - diesel::insert_into(invitations::table) - .values(self) - .on_conflict(invitations::email) - .do_nothing() - .execute(&**conn) - .map_res("Error saving invitation") - } - - #[cfg(not(feature = "postgresql"))] - pub fn save(&self, conn: &DbConn) -> EmptyResult { - if self.email.trim().is_empty() { - err!("Invitation email can't be empty") + db_run! {conn: + sqlite, mysql { + diesel::replace_into(invitations::table) + .values(InvitationDb::to_db(self)) + .execute(conn) + .map_res("Error saving invitation") + } + postgresql { + diesel::insert_into(invitations::table) + .values(InvitationDb::to_db(self)) + .on_conflict(invitations::email) + .do_nothing() + .execute(conn) + .map_res("Error saving invitation") + } } - - diesel::replace_into(invitations::table) - .values(self) - .execute(&**conn) - .map_res("Error saving invitation") } pub fn delete(self, conn: &DbConn) -> EmptyResult { - diesel::delete(invitations::table.filter(invitations::email.eq(self.email))) - .execute(&**conn) - .map_res("Error deleting invitation") + db_run! {conn: { + diesel::delete(invitations::table.filter(invitations::email.eq(self.email))) + .execute(conn) + .map_res("Error deleting invitation") + }} } pub fn find_by_mail(mail: &str, conn: &DbConn) -> Option { let lower_mail = mail.to_lowercase(); - invitations::table - .filter(invitations::email.eq(lower_mail)) - .first::(&**conn) - .ok() + db_run! {conn: { + invitations::table + .filter(invitations::email.eq(lower_mail)) + .first::(conn) + .ok() + .from_db() + }} } pub fn take(mail: &str, conn: &DbConn) -> bool { diff --git a/src/error.rs b/src/error.rs index 226b6843..23c5ae70 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,7 @@ macro_rules! make_error { } use diesel::result::Error as DieselErr; +use diesel::r2d2::PoolError as R2d2Err; use handlebars::RenderError as HbErr; use jsonwebtoken::errors::Error as JWTErr; use regex::Error as RegexErr; @@ -66,6 +67,7 @@ make_error! { // Used for special return values, like 2FA errors JsonError(Value): _no_source, _serialize, DbError(DieselErr): _has_source, _api_error, + R2d2Error(R2d2Err): _has_source, _api_error, U2fError(U2fErr): _has_source, _api_error, SerdeError(SerdeErr): _has_source, _api_error, JWTError(JWTErr): _has_source, _api_error, diff --git a/src/main.rs b/src/main.rs index b9c8bcf3..a269c622 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,7 @@ mod api; mod auth; mod config; mod crypto; +#[macro_use] mod db; mod mail; mod util; @@ -61,10 +62,8 @@ fn main() { _ => false, }; - check_db(); check_rsa_keys(); check_web_vault(); - migrations::run_migrations(); create_icon_cache_folder(); @@ -200,30 +199,6 @@ fn chain_syslog(logger: fern::Dispatch) -> fern::Dispatch { } } -fn check_db() { - if cfg!(feature = "sqlite") { - let url = CONFIG.database_url(); - let path = Path::new(&url); - - if let Some(parent) = path.parent() { - if create_dir_all(parent).is_err() { - error!("Error creating database directory"); - exit(1); - } - } - - // Turn on WAL in SQLite - if CONFIG.enable_db_wal() { - use diesel::RunQueryDsl; - let connection = db::get_connection().expect("Can't connect to DB"); - diesel::sql_query("PRAGMA journal_mode=wal") - .execute(&connection) - .expect("Failed to turn on WAL"); - } - } - db::get_connection().expect("Can't connect to DB"); -} - fn create_icon_cache_folder() { // Try to create the icon cache folder, and generate an error if it could not. create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache directory"); @@ -285,57 +260,22 @@ fn check_web_vault() { let index_path = Path::new(&CONFIG.web_vault_folder()).join("index.html"); if !index_path.exists() { - error!("Web vault is not found. To install it, please follow the steps in: "); + error!("Web vault is not found at '{}'. To install it, please follow the steps in: ", CONFIG.web_vault_folder()); error!("https://github.com/dani-garcia/bitwarden_rs/wiki/Building-binary#install-the-web-vault"); error!("You can also set the environment variable 'WEB_VAULT_ENABLED=false' to disable it"); exit(1); } } -// Embed the migrations from the migrations folder into the application -// This way, the program automatically migrates the database to the latest version -// https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html -#[allow(unused_imports)] -mod migrations { - - #[cfg(feature = "sqlite")] - embed_migrations!("migrations/sqlite"); - #[cfg(feature = "mysql")] - embed_migrations!("migrations/mysql"); - #[cfg(feature = "postgresql")] - embed_migrations!("migrations/postgresql"); - - pub fn run_migrations() { - // Make sure the database is up to date (create if it doesn't exist, or run the migrations) - let connection = crate::db::get_connection().expect("Can't connect to DB"); - - use std::io::stdout; - - // Disable Foreign Key Checks during migration - use diesel::RunQueryDsl; - - // FIXME: Per https://www.postgresql.org/docs/12/sql-set-constraints.html, - // "SET CONSTRAINTS sets the behavior of constraint checking within the - // current transaction", so this setting probably won't take effect for - // any of the migrations since it's being run outside of a transaction. - // Migrations that need to disable foreign key checks should run this - // from within the migration script itself. - #[cfg(feature = "postgres")] - diesel::sql_query("SET CONSTRAINTS ALL DEFERRED").execute(&connection).expect("Failed to disable Foreign Key Checks during migrations"); - - // Scoped to a connection/session. - #[cfg(feature = "mysql")] - diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0").execute(&connection).expect("Failed to disable Foreign Key Checks during migrations"); - - // Scoped to a connection. - #[cfg(feature = "sqlite")] - diesel::sql_query("PRAGMA foreign_keys = OFF").execute(&connection).expect("Failed to disable Foreign Key Checks during migrations"); - - embedded_migrations::run_with_output(&connection, &mut stdout()).expect("Can't run migrations"); - } -} - fn launch_rocket(extra_debug: bool) { + let pool = match db::DbPool::from_config() { + Ok(p) => p, + Err(e) => { + error!("Error creating database pool: {:?}", e); + exit(1); + } + }; + let basepath = &CONFIG.domain_path(); // If adding more paths here, consider also adding them to @@ -347,7 +287,7 @@ fn launch_rocket(extra_debug: bool) { .mount(&[basepath, "/identity"].concat(), api::identity_routes()) .mount(&[basepath, "/icons"].concat(), api::icons_routes()) .mount(&[basepath, "/notifications"].concat(), api::notifications_routes()) - .manage(db::init_pool()) + .manage(pool) .manage(api::start_notification_server()) .attach(util::AppHeaders()) .attach(util::CORS())