@ -1,6 +1,10 @@
use std ::{ sync ::Arc , time ::Duration } ;
use diesel ::r2d2 ::{ ConnectionManager , Pool , PooledConnection } ;
use diesel ::{
connection ::SimpleConnection ,
r2d2 ::{ ConnectionManager , CustomizeConnection , Pool , PooledConnection } ,
} ;
use rocket ::{
http ::Status ,
outcome ::IntoOutcome ,
@ -62,6 +66,23 @@ macro_rules! generate_connections {
#[ allow(non_camel_case_types) ]
pub enum DbConnInner { $( #[ cfg($name) ] $name ( PooledConnection < ConnectionManager < $ty > > ) , ) + }
#[ derive(Debug) ]
pub struct DbConnOptions {
pub init_stmts : String ,
}
$( // Based on <https://stackoverflow.com/a/57717533>.
#[ cfg($name) ]
impl CustomizeConnection < $ty , diesel ::r2d2 ::Error > for DbConnOptions {
fn on_acquire ( & self , conn : & mut $ty ) -> Result < ( ) , diesel ::r2d2 ::Error > {
( | | {
if ! self . init_stmts . is_empty ( ) {
conn . batch_execute ( & self . init_stmts ) ? ;
}
Ok ( ( ) )
} ) ( ) . map_err ( diesel ::r2d2 ::Error ::QueryError )
}
} ) +
#[ derive(Clone) ]
pub struct DbPool {
@ -103,7 +124,8 @@ macro_rules! generate_connections {
}
impl DbPool {
// For the given database URL, guess it's type, run migrations create pool and return it
// For the given database URL, guess its type, run migrations, create pool, and return it
#[ allow(clippy::diverging_sub_expression) ]
pub fn from_config ( ) -> Result < Self , Error > {
let url = CONFIG . database_url ( ) ;
let conn_type = DbConnType ::from_url ( & url ) ? ;
@ -117,6 +139,9 @@ macro_rules! generate_connections {
let pool = Pool ::builder ( )
. max_size ( CONFIG . database_max_conns ( ) )
. connection_timeout ( Duration ::from_secs ( CONFIG . database_timeout ( ) ) )
. connection_customizer ( Box ::new ( DbConnOptions {
init_stmts : conn_type . get_init_stmts ( )
} ) )
. build ( manager )
. map_res ( "Failed to create pool" ) ? ;
return Ok ( DbPool {
@ -190,6 +215,23 @@ impl DbConnType {
err ! ( "`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled" )
}
}
pub fn get_init_stmts ( & self ) -> String {
let init_stmts = CONFIG . database_conn_init ( ) ;
if ! init_stmts . is_empty ( ) {
init_stmts
} else {
self . default_init_stmts ( )
}
}
pub fn default_init_stmts ( & self ) -> String {
match self {
Self ::sqlite = > "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;" . to_string ( ) ,
Self ::mysql = > "" . to_string ( ) ,
Self ::postgresql = > "" . to_string ( ) ,
}
}
}
#[ macro_export ]