Jake Howard 4 years ago
parent 6b1daeba05
commit 49af9cf4f5
No known key found for this signature in database
GPG Key ID: 57AFB45680EDD477

@ -120,7 +120,7 @@ fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
} }
// Server WebSocket handler // Server WebSocket handler
pub struct WSHandler { pub struct WsHandler {
out: Sender, out: Sender,
user_uuid: Option<String>, user_uuid: Option<String>,
users: WebSocketUsers, users: WebSocketUsers,
@ -140,7 +140,7 @@ const PING: Token = Token(1);
const ACCESS_TOKEN_KEY: &str = "access_token="; const ACCESS_TOKEN_KEY: &str = "access_token=";
impl WSHandler { impl WsHandler {
fn err(&self, msg: &'static str) -> ws::Result<()> { fn err(&self, msg: &'static str) -> ws::Result<()> {
self.out.close(ws::CloseCode::Invalid)?; self.out.close(ws::CloseCode::Invalid)?;
@ -176,7 +176,7 @@ impl WSHandler {
} }
} }
impl Handler for WSHandler { impl Handler for WsHandler {
fn on_open(&mut self, hs: Handshake) -> ws::Result<()> { fn on_open(&mut self, hs: Handshake) -> ws::Result<()> {
// Path == "/notifications/hub?id=<id>==&access_token=<access_token>" // Path == "/notifications/hub?id=<id>==&access_token=<access_token>"
// //
@ -240,13 +240,13 @@ impl Handler for WSHandler {
} }
} }
struct WSFactory { struct WsFactory {
pub users: WebSocketUsers, pub users: WebSocketUsers,
} }
impl WSFactory { impl WsFactory {
pub fn init() -> Self { pub fn init() -> Self {
WSFactory { WsFactory {
users: WebSocketUsers { users: WebSocketUsers {
map: Arc::new(CHashMap::new()), map: Arc::new(CHashMap::new()),
}, },
@ -254,11 +254,11 @@ impl WSFactory {
} }
} }
impl Factory for WSFactory { impl Factory for WsFactory {
type Handler = WSHandler; type Handler = WsHandler;
fn connection_made(&mut self, out: Sender) -> Self::Handler { fn connection_made(&mut self, out: Sender) -> Self::Handler {
WSHandler { WsHandler {
out, out,
user_uuid: None, user_uuid: None,
users: self.users.clone(), users: self.users.clone(),
@ -405,7 +405,7 @@ use rocket::State;
pub type Notify<'a> = State<'a, WebSocketUsers>; pub type Notify<'a> = State<'a, WebSocketUsers>;
pub fn start_notification_server() -> WebSocketUsers { pub fn start_notification_server() -> WebSocketUsers {
let factory = WSFactory::init(); let factory = WsFactory::init();
let users = factory.users.clone(); let users = factory.users.clone();
if CONFIG.websocket_enabled() { if CONFIG.websocket_enabled() {

@ -58,28 +58,28 @@ fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Err
.map_res("Error decoding JWT") .map_res("Error decoding JWT")
} }
pub fn decode_login(token: &str) -> Result<LoginJWTClaims, Error> { pub fn decode_login(token: &str) -> Result<LoginJwtClaims, Error> {
decode_jwt(token, JWT_LOGIN_ISSUER.to_string()) decode_jwt(token, JWT_LOGIN_ISSUER.to_string())
} }
pub fn decode_invite(token: &str) -> Result<InviteJWTClaims, Error> { pub fn decode_invite(token: &str) -> Result<InviteJwtClaims, Error> {
decode_jwt(token, JWT_INVITE_ISSUER.to_string()) decode_jwt(token, JWT_INVITE_ISSUER.to_string())
} }
pub fn decode_delete(token: &str) -> Result<DeleteJWTClaims, Error> { pub fn decode_delete(token: &str) -> Result<DeleteJwtClaims, Error> {
decode_jwt(token, JWT_DELETE_ISSUER.to_string()) decode_jwt(token, JWT_DELETE_ISSUER.to_string())
} }
pub fn decode_verify_email(token: &str) -> Result<VerifyEmailJWTClaims, Error> { pub fn decode_verify_email(token: &str) -> Result<VerifyEmailJwtClaims, Error> {
decode_jwt(token, JWT_VERIFYEMAIL_ISSUER.to_string()) decode_jwt(token, JWT_VERIFYEMAIL_ISSUER.to_string())
} }
pub fn decode_admin(token: &str) -> Result<AdminJWTClaims, Error> { pub fn decode_admin(token: &str) -> Result<AdminJwtClaims, Error> {
decode_jwt(token, JWT_ADMIN_ISSUER.to_string()) decode_jwt(token, JWT_ADMIN_ISSUER.to_string())
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct LoginJWTClaims { pub struct LoginJwtClaims {
// Not before // Not before
pub nbf: i64, pub nbf: i64,
// Expiration time // Expiration time
@ -110,7 +110,7 @@ pub struct LoginJWTClaims {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct InviteJWTClaims { pub struct InviteJwtClaims {
// Not before // Not before
pub nbf: i64, pub nbf: i64,
// Expiration time // Expiration time
@ -132,9 +132,9 @@ pub fn generate_invite_claims(
org_id: Option<String>, org_id: Option<String>,
user_org_id: Option<String>, user_org_id: Option<String>,
invited_by_email: Option<String>, invited_by_email: Option<String>,
) -> InviteJWTClaims { ) -> InviteJwtClaims {
let time_now = Utc::now().naive_utc(); let time_now = Utc::now().naive_utc();
InviteJWTClaims { InviteJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + Duration::days(5)).timestamp(), exp: (time_now + Duration::days(5)).timestamp(),
iss: JWT_INVITE_ISSUER.to_string(), iss: JWT_INVITE_ISSUER.to_string(),
@ -147,7 +147,7 @@ pub fn generate_invite_claims(
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct DeleteJWTClaims { pub struct DeleteJwtClaims {
// Not before // Not before
pub nbf: i64, pub nbf: i64,
// Expiration time // Expiration time
@ -158,9 +158,9 @@ pub struct DeleteJWTClaims {
pub sub: String, pub sub: String,
} }
pub fn generate_delete_claims(uuid: String) -> DeleteJWTClaims { pub fn generate_delete_claims(uuid: String) -> DeleteJwtClaims {
let time_now = Utc::now().naive_utc(); let time_now = Utc::now().naive_utc();
DeleteJWTClaims { DeleteJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + Duration::days(5)).timestamp(), exp: (time_now + Duration::days(5)).timestamp(),
iss: JWT_DELETE_ISSUER.to_string(), iss: JWT_DELETE_ISSUER.to_string(),
@ -169,7 +169,7 @@ pub fn generate_delete_claims(uuid: String) -> DeleteJWTClaims {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct VerifyEmailJWTClaims { pub struct VerifyEmailJwtClaims {
// Not before // Not before
pub nbf: i64, pub nbf: i64,
// Expiration time // Expiration time
@ -180,9 +180,9 @@ pub struct VerifyEmailJWTClaims {
pub sub: String, pub sub: String,
} }
pub fn generate_verify_email_claims(uuid: String) -> DeleteJWTClaims { pub fn generate_verify_email_claims(uuid: String) -> DeleteJwtClaims {
let time_now = Utc::now().naive_utc(); let time_now = Utc::now().naive_utc();
DeleteJWTClaims { DeleteJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + Duration::days(5)).timestamp(), exp: (time_now + Duration::days(5)).timestamp(),
iss: JWT_VERIFYEMAIL_ISSUER.to_string(), iss: JWT_VERIFYEMAIL_ISSUER.to_string(),
@ -191,7 +191,7 @@ pub fn generate_verify_email_claims(uuid: String) -> DeleteJWTClaims {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct AdminJWTClaims { pub struct AdminJwtClaims {
// Not before // Not before
pub nbf: i64, pub nbf: i64,
// Expiration time // Expiration time
@ -202,9 +202,9 @@ pub struct AdminJWTClaims {
pub sub: String, pub sub: String,
} }
pub fn generate_admin_claims() -> AdminJWTClaims { pub fn generate_admin_claims() -> AdminJwtClaims {
let time_now = Utc::now().naive_utc(); let time_now = Utc::now().naive_utc();
AdminJWTClaims { AdminJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + Duration::minutes(20)).timestamp(), exp: (time_now + Duration::minutes(20)).timestamp(),
iss: JWT_ADMIN_ISSUER.to_string(), iss: JWT_ADMIN_ISSUER.to_string(),

@ -80,8 +80,8 @@ impl Device {
let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect(); let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect();
// Create the JWT claims struct, to send to the client // Create the JWT claims struct, to send to the client
use crate::auth::{encode_jwt, LoginJWTClaims, DEFAULT_VALIDITY, JWT_LOGIN_ISSUER}; use crate::auth::{encode_jwt, LoginJwtClaims, DEFAULT_VALIDITY, JWT_LOGIN_ISSUER};
let claims = LoginJWTClaims { let claims = LoginJwtClaims {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + *DEFAULT_VALIDITY).timestamp(), exp: (time_now + *DEFAULT_VALIDITY).timestamp(),
iss: JWT_LOGIN_ISSUER.to_string(), iss: JWT_LOGIN_ISSUER.to_string(),
@ -117,7 +117,7 @@ impl Device {
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
crate::util::retry( crate::util::retry(
|| diesel::replace_into(devices::table).values(DeviceDb::to_db(self)).execute(conn), || diesel::replace_into(devices::table).values(DeviceDb::to_db(self)).execute(conn),

@ -38,11 +38,11 @@ use diesel::ConnectionError as DieselConErr;
use diesel_migrations::RunMigrationsError as DieselMigErr; use diesel_migrations::RunMigrationsError as DieselMigErr;
use diesel::r2d2::PoolError as R2d2Err; use diesel::r2d2::PoolError as R2d2Err;
use handlebars::RenderError as HbErr; use handlebars::RenderError as HbErr;
use jsonwebtoken::errors::Error as JWTErr; use jsonwebtoken::errors::Error as JwtErr;
use regex::Error as RegexErr; use regex::Error as RegexErr;
use reqwest::Error as ReqErr; use reqwest::Error as ReqErr;
use serde_json::{Error as SerdeErr, Value}; use serde_json::{Error as SerdeErr, Value};
use std::io::Error as IOErr; use std::io::Error as IoErr;
use std::time::SystemTimeError as TimeErr; use std::time::SystemTimeError as TimeErr;
use u2f::u2ferror::U2fError as U2fErr; use u2f::u2ferror::U2fError as U2fErr;
@ -72,10 +72,10 @@ make_error! {
R2d2Error(R2d2Err): _has_source, _api_error, R2d2Error(R2d2Err): _has_source, _api_error,
U2fError(U2fErr): _has_source, _api_error, U2fError(U2fErr): _has_source, _api_error,
SerdeError(SerdeErr): _has_source, _api_error, SerdeError(SerdeErr): _has_source, _api_error,
JWTError(JWTErr): _has_source, _api_error, JWtError(JwtErr): _has_source, _api_error,
TemplError(HbErr): _has_source, _api_error, TemplError(HbErr): _has_source, _api_error,
//WsError(ws::Error): _has_source, _api_error, //WsError(ws::Error): _has_source, _api_error,
IOError(IOErr): _has_source, _api_error, IoError(IoErr): _has_source, _api_error,
TimeError(TimeErr): _has_source, _api_error, TimeError(TimeErr): _has_source, _api_error,
ReqError(ReqErr): _has_source, _api_error, ReqError(ReqErr): _has_source, _api_error,
RegexError(RegexErr): _has_source, _api_error, RegexError(RegexErr): _has_source, _api_error,

@ -326,7 +326,7 @@ fn launch_rocket(extra_debug: bool) {
.manage(pool) .manage(pool)
.manage(api::start_notification_server()) .manage(api::start_notification_server())
.attach(util::AppHeaders()) .attach(util::AppHeaders())
.attach(util::CORS()) .attach(util::Cors())
.attach(util::BetterLogging(extra_debug)) .attach(util::BetterLogging(extra_debug))
.launch(); .launch();

@ -38,9 +38,9 @@ impl Fairing for AppHeaders {
} }
} }
pub struct CORS(); pub struct Cors();
impl CORS { impl Cors {
fn get_header(headers: &HeaderMap, name: &str) -> String { fn get_header(headers: &HeaderMap, name: &str) -> String {
match headers.get_one(name) { match headers.get_one(name) {
Some(h) => h.to_string(), Some(h) => h.to_string(),
@ -51,7 +51,7 @@ impl CORS {
// Check a request's `Origin` header against the list of allowed origins. // Check a request's `Origin` header against the list of allowed origins.
// If a match exists, return it. Otherwise, return None. // If a match exists, return it. Otherwise, return None.
fn get_allowed_origin(headers: &HeaderMap) -> Option<String> { fn get_allowed_origin(headers: &HeaderMap) -> Option<String> {
let origin = CORS::get_header(headers, "Origin"); let origin = Cors::get_header(headers, "Origin");
let domain_origin = CONFIG.domain_origin(); let domain_origin = CONFIG.domain_origin();
let safari_extension_origin = "file://"; let safari_extension_origin = "file://";
if origin == domain_origin || origin == safari_extension_origin { if origin == domain_origin || origin == safari_extension_origin {
@ -62,10 +62,10 @@ impl CORS {
} }
} }
impl Fairing for CORS { impl Fairing for Cors {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
name: "CORS", name: "Cors",
kind: Kind::Response, kind: Kind::Response,
} }
} }
@ -73,14 +73,14 @@ impl Fairing for CORS {
fn on_response(&self, request: &Request, response: &mut Response) { fn on_response(&self, request: &Request, response: &mut Response) {
let req_headers = request.headers(); let req_headers = request.headers();
if let Some(origin) = CORS::get_allowed_origin(req_headers) { if let Some(origin) = Cors::get_allowed_origin(req_headers) {
response.set_header(Header::new("Access-Control-Allow-Origin", origin)); response.set_header(Header::new("Access-Control-Allow-Origin", origin));
} }
// Preflight request // Preflight request
if request.method() == Method::Options { if request.method() == Method::Options {
let req_allow_headers = CORS::get_header(req_headers, "Access-Control-Request-Headers"); let req_allow_headers = Cors::get_header(req_headers, "Access-Control-Request-Headers");
let req_allow_method = CORS::get_header(req_headers, "Access-Control-Request-Method"); let req_allow_method = Cors::get_header(req_headers, "Access-Control-Request-Method");
response.set_header(Header::new("Access-Control-Allow-Methods", req_allow_method)); response.set_header(Header::new("Access-Control-Allow-Methods", req_allow_method));
response.set_header(Header::new("Access-Control-Allow-Headers", req_allow_headers)); response.set_header(Header::new("Access-Control-Allow-Headers", req_allow_headers));

Loading…
Cancel
Save