@ -1,23 +1,11 @@
use std ::{
net ::{ IpAddr , SocketAddr } ,
sync ::Arc ,
time ::Duration ,
} ;
use std ::{ net ::IpAddr , sync ::Arc , time ::Duration } ;
use chrono ::{ NaiveDateTime , Utc } ;
use rmpv ::Value ;
use rocket ::{
futures ::{ SinkExt , StreamExt } ,
Route ,
} ;
use tokio ::{
net ::{ TcpListener , TcpStream } ,
sync ::mpsc ::Sender ,
} ;
use tokio_tungstenite ::{
accept_hdr_async ,
tungstenite ::{ handshake , Message } ,
} ;
use rocket ::{ futures ::StreamExt , Route } ;
use tokio ::sync ::mpsc ::Sender ;
use rocket_ws ::{ Message , WebSocket } ;
use crate ::{
auth ::{ ClientIp , WsAccessTokenHeader } ,
@ -30,7 +18,7 @@ use crate::{
use once_cell ::sync ::Lazy ;
static WS_USERS : Lazy < Arc < WebSocketUsers > > = Lazy ::new ( | | {
pub static WS_USERS : Lazy < Arc < WebSocketUsers > > = Lazy ::new ( | | {
Arc ::new ( WebSocketUsers {
map : Arc ::new ( dashmap ::DashMap ::new ( ) ) ,
} )
@ -47,8 +35,15 @@ use super::{
push_send_update , push_user_update ,
} ;
static NOTIFICATIONS_DISABLED : Lazy < bool > = Lazy ::new ( | | ! CONFIG . enable_websocket ( ) & & ! CONFIG . push_enabled ( ) ) ;
pub fn routes ( ) -> Vec < Route > {
if CONFIG . enable_websocket ( ) {
routes ! [ websockets_hub , anonymous_websockets_hub ]
} else {
info ! ( "WebSocket are disabled, realtime sync functionality will not work!" ) ;
routes ! [ ]
}
}
#[ derive(FromForm, Debug) ]
@ -108,7 +103,7 @@ impl Drop for WSAnonymousEntryMapGuard {
#[ get( " /hub?<data..> " ) ]
fn websockets_hub < ' r > (
ws : rocket_ws:: WebSocket,
ws : WebSocket,
data : WsAccessToken ,
ip : ClientIp ,
header_token : WsAccessTokenHeader ,
@ -192,11 +187,7 @@ fn websockets_hub<'r>(
}
#[ get( " /anonymous-hub?<token..> " ) ]
fn anonymous_websockets_hub < ' r > (
ws : rocket_ws ::WebSocket ,
token : String ,
ip : ClientIp ,
) -> Result < rocket_ws ::Stream ! [ ' r ] , Error > {
fn anonymous_websockets_hub < ' r > ( ws : WebSocket , token : String , ip : ClientIp ) -> Result < rocket_ws ::Stream ! [ ' r ] , Error > {
let addr = ip . ip ;
info ! ( "Accepting Anonymous Rocket WS connection from {addr}" ) ;
@ -349,13 +340,19 @@ impl WebSocketUsers {
// NOTE: The last modified date needs to be updated before calling these methods
pub async fn send_user_update ( & self , ut : UpdateType , user : & User ) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let data = create_update (
vec! [ ( "UserId" . into ( ) , user . uuid . clone ( ) . into ( ) ) , ( "Date" . into ( ) , serialize_date ( user . updated_at ) ) ] ,
ut ,
None ,
) ;
if CONFIG . enable_websocket ( ) {
self . send_update ( & user . uuid , & data ) . await ;
}
if CONFIG . push_enabled ( ) {
push_user_update ( ut , user ) ;
@ -363,13 +360,19 @@ impl WebSocketUsers {
}
pub async fn send_logout ( & self , user : & User , acting_device_uuid : Option < String > ) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let data = create_update (
vec! [ ( "UserId" . into ( ) , user . uuid . clone ( ) . into ( ) ) , ( "Date" . into ( ) , serialize_date ( user . updated_at ) ) ] ,
UpdateType ::LogOut ,
acting_device_uuid . clone ( ) ,
) ;
if CONFIG . enable_websocket ( ) {
self . send_update ( & user . uuid , & data ) . await ;
}
if CONFIG . push_enabled ( ) {
push_logout ( user , acting_device_uuid ) ;
@ -383,6 +386,10 @@ impl WebSocketUsers {
acting_device_uuid : & String ,
conn : & mut DbConn ,
) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let data = create_update (
vec! [
( "Id" . into ( ) , folder . uuid . clone ( ) . into ( ) ) ,
@ -393,7 +400,9 @@ impl WebSocketUsers {
Some ( acting_device_uuid . into ( ) ) ,
) ;
if CONFIG . enable_websocket ( ) {
self . send_update ( & folder . user_uuid , & data ) . await ;
}
if CONFIG . push_enabled ( ) {
push_folder_update ( ut , folder , acting_device_uuid , conn ) . await ;
@ -409,6 +418,10 @@ impl WebSocketUsers {
collection_uuids : Option < Vec < String > > ,
conn : & mut DbConn ,
) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let org_uuid = convert_option ( cipher . organization_uuid . clone ( ) ) ;
// Depending if there are collections provided or not, we need to have different values for the following variables.
// The user_uuid should be `null`, and the revision date should be set to now, else the clients won't sync the collection change.
@ -434,9 +447,11 @@ impl WebSocketUsers {
Some ( acting_device_uuid . into ( ) ) ,
) ;
if CONFIG . enable_websocket ( ) {
for uuid in user_uuids {
self . send_update ( uuid , & data ) . await ;
}
}
if CONFIG . push_enabled ( ) & & user_uuids . len ( ) = = 1 {
push_cipher_update ( ut , cipher , acting_device_uuid , conn ) . await ;
@ -451,6 +466,10 @@ impl WebSocketUsers {
acting_device_uuid : & String ,
conn : & mut DbConn ,
) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let user_uuid = convert_option ( send . user_uuid . clone ( ) ) ;
let data = create_update (
@ -463,9 +482,11 @@ impl WebSocketUsers {
None ,
) ;
if CONFIG . enable_websocket ( ) {
for uuid in user_uuids {
self . send_update ( uuid , & data ) . await ;
}
}
if CONFIG . push_enabled ( ) & & user_uuids . len ( ) = = 1 {
push_send_update ( ut , send , acting_device_uuid , conn ) . await ;
}
@ -478,12 +499,18 @@ impl WebSocketUsers {
acting_device_uuid : & String ,
conn : & mut DbConn ,
) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let data = create_update (
vec! [ ( "Id" . into ( ) , auth_request_uuid . clone ( ) . into ( ) ) , ( "UserId" . into ( ) , user_uuid . clone ( ) . into ( ) ) ] ,
UpdateType ::AuthRequest ,
Some ( acting_device_uuid . to_string ( ) ) ,
) ;
if CONFIG . enable_websocket ( ) {
self . send_update ( user_uuid , & data ) . await ;
}
if CONFIG . push_enabled ( ) {
push_auth_request ( user_uuid . to_string ( ) , auth_request_uuid . to_string ( ) , conn ) . await ;
@ -497,12 +524,18 @@ impl WebSocketUsers {
approving_device_uuid : String ,
conn : & mut DbConn ,
) {
// Skip any processing if both WebSockets and Push are not active
if * NOTIFICATIONS_DISABLED {
return ;
}
let data = create_update (
vec! [ ( "Id" . into ( ) , auth_response_uuid . to_owned ( ) . into ( ) ) , ( "UserId" . into ( ) , user_uuid . clone ( ) . into ( ) ) ] ,
UpdateType ::AuthRequestResponse ,
approving_device_uuid . clone ( ) . into ( ) ,
) ;
if CONFIG . enable_websocket ( ) {
self . send_update ( auth_response_uuid , & data ) . await ;
}
if CONFIG . push_enabled ( ) {
push_auth_response ( user_uuid . to_string ( ) , auth_response_uuid . to_string ( ) , approving_device_uuid , conn )
@ -526,6 +559,9 @@ impl AnonymousWebSocketSubscriptions {
}
pub async fn send_auth_response ( & self , user_uuid : & String , auth_response_uuid : & str ) {
if ! CONFIG . enable_websocket ( ) {
return ;
}
let data = create_anonymous_update (
vec! [ ( "Id" . into ( ) , auth_response_uuid . to_owned ( ) . into ( ) ) , ( "UserId" . into ( ) , user_uuid . clone ( ) . into ( ) ) ] ,
UpdateType ::AuthRequestResponse ,
@ -620,127 +656,3 @@ pub enum UpdateType {
pub type Notify < ' a > = & ' a rocket ::State < Arc < WebSocketUsers > > ;
pub type AnonymousNotify < ' a > = & ' a rocket ::State < Arc < AnonymousWebSocketSubscriptions > > ;
pub fn start_notification_server ( ) -> Arc < WebSocketUsers > {
let users = Arc ::clone ( & WS_USERS ) ;
if CONFIG . websocket_enabled ( ) {
let users2 = Arc ::< WebSocketUsers > ::clone ( & users ) ;
tokio ::spawn ( async move {
let addr = ( CONFIG . websocket_address ( ) , CONFIG . websocket_port ( ) ) ;
info ! ( "Starting WebSockets server on {}:{}" , addr . 0 , addr . 1 ) ;
let listener = TcpListener ::bind ( addr ) . await . expect ( "Can't listen on websocket port" ) ;
let ( shutdown_tx , mut shutdown_rx ) = tokio ::sync ::oneshot ::channel ::< ( ) > ( ) ;
CONFIG . set_ws_shutdown_handle ( shutdown_tx ) ;
loop {
tokio ::select ! {
Ok ( ( stream , addr ) ) = listener . accept ( ) = > {
tokio ::spawn ( handle_connection ( stream , Arc ::< WebSocketUsers > ::clone ( & users2 ) , addr ) ) ;
}
_ = & mut shutdown_rx = > {
break ;
}
}
}
info ! ( "Shutting down WebSockets server!" )
} ) ;
}
users
}
async fn handle_connection ( stream : TcpStream , users : Arc < WebSocketUsers > , addr : SocketAddr ) -> Result < ( ) , Error > {
let mut user_uuid : Option < String > = None ;
info ! ( "Accepting WS connection from {addr}" ) ;
// Accept connection, do initial handshake, validate auth token and get the user ID
use handshake ::server ::{ Request , Response } ;
let mut stream = accept_hdr_async ( stream , | req : & Request , res : Response | {
if let Some ( token ) = get_request_token ( req ) {
if let Ok ( claims ) = crate ::auth ::decode_login ( & token ) {
user_uuid = Some ( claims . sub ) ;
return Ok ( res ) ;
}
}
Err ( Response ::builder ( ) . status ( 401 ) . body ( None ) . unwrap ( ) )
} )
. await ? ;
let user_uuid = user_uuid . expect ( "User UUID should be set after the handshake" ) ;
let ( mut rx , guard ) = {
// Add a channel to send messages to this client to the map
let entry_uuid = uuid ::Uuid ::new_v4 ( ) ;
let ( tx , rx ) = tokio ::sync ::mpsc ::channel ::< Message > ( 100 ) ;
users . map . entry ( user_uuid . clone ( ) ) . or_default ( ) . push ( ( entry_uuid , tx ) ) ;
// Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
( rx , WSEntryMapGuard ::new ( users , user_uuid , entry_uuid , addr . ip ( ) ) )
} ;
let _guard = guard ;
let mut interval = tokio ::time ::interval ( Duration ::from_secs ( 15 ) ) ;
loop {
tokio ::select ! {
res = stream . next ( ) = > {
match res {
Some ( Ok ( message ) ) = > {
match message {
// Respond to any pings
Message ::Ping ( ping ) = > stream . send ( Message ::Pong ( ping ) ) . await ? ,
Message ::Pong ( _ ) = > { /* Ignored */ } ,
// We should receive an initial message with the protocol and version, and we will reply to it
Message ::Text ( ref message ) = > {
let msg = message . strip_suffix ( RECORD_SEPARATOR as char ) . unwrap_or ( message ) ;
if serde_json ::from_str ( msg ) . ok ( ) = = Some ( INITIAL_MESSAGE ) {
stream . send ( Message ::binary ( INITIAL_RESPONSE ) ) . await ? ;
continue ;
}
}
// Just echo anything else the client sends
_ = > stream . send ( message ) . await ? ,
}
}
_ = > break ,
}
}
res = rx . recv ( ) = > {
match res {
Some ( res ) = > stream . send ( res ) . await ? ,
None = > break ,
}
}
_ = interval . tick ( ) = > stream . send ( Message ::Ping ( create_ping ( ) ) ) . await ?
}
}
Ok ( ( ) )
}
fn get_request_token ( req : & handshake ::server ::Request ) -> Option < String > {
const ACCESS_TOKEN_KEY : & str = "access_token=" ;
if let Some ( Ok ( auth ) ) = req . headers ( ) . get ( "Authorization" ) . map ( | a | a . to_str ( ) ) {
if let Some ( token_part ) = auth . strip_prefix ( "Bearer " ) {
return Some ( token_part . to_owned ( ) ) ;
}
}
if let Some ( params ) = req . uri ( ) . query ( ) {
let params_iter = params . split ( '&' ) . take ( 1 ) ;
for val in params_iter {
if let Some ( stripped ) = val . strip_prefix ( ACCESS_TOKEN_KEY ) {
return Some ( stripped . to_owned ( ) ) ;
}
}
}
None
}