diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs index 2efd6b88..514eb875 100644 --- a/src/api/core/accounts.rs +++ b/src/api/core/accounts.rs @@ -76,10 +76,8 @@ fn register(data: JsonUpcase, conn: DbConn) -> EmptyResult { Some(token) => token, None => err!("No valid invite token") }; - let claims: InviteJWTClaims = match decode_invite_jwt(&token) { - Ok(claims) => claims, - Err(msg) => err!("Invalid claim: {:#?}", msg), - }; + + let claims: InviteJWTClaims = decode_invite_jwt(&token)?; if &claims.email == &data.Email { user } else { diff --git a/src/api/core/organizations.rs b/src/api/core/organizations.rs index 0e15a880..9794cc91 100644 --- a/src/api/core/organizations.rs +++ b/src/api/core/organizations.rs @@ -522,10 +522,7 @@ fn accept_invite(_org_id: String, _org_user_id: String, data: JsonUpcase claims, - Err(msg) => err!("Invalid claim: {:#?}", msg), - }; + let claims: InviteJWTClaims = decode_invite_jwt(&token)?; match User::find_by_mail(&claims.email, &conn) { Some(_) => { diff --git a/src/auth.rs b/src/auth.rs index 03490f94..df85697f 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -7,6 +7,7 @@ use chrono::Duration; use jsonwebtoken::{self, Algorithm, Header}; use serde::ser::Serialize; +use crate::error::{Error, MapResult}; use crate::CONFIG; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; @@ -31,11 +32,11 @@ lazy_static! { pub fn encode_jwt(claims: &T) -> String { match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) { Ok(token) => token, - Err(e) => panic!("Error encoding jwt {}", e) + Err(e) => panic!("Error encoding jwt {}", e), } } -pub fn decode_jwt(token: &str) -> Result { +pub fn decode_jwt(token: &str) -> Result { let validation = jsonwebtoken::Validation { leeway: 30, // 30 seconds validate_exp: true, @@ -47,16 +48,12 @@ pub fn decode_jwt(token: &str) -> Result { algorithms: vec![JWT_ALGORITHM], }; - match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) { - Ok(decoded) => Ok(decoded.claims), - Err(msg) => { - error!("Error validating jwt - {:#?}", msg); - Err(msg.to_string()) - } - } + jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) + .map(|d| d.claims) + .map_res("Error decoding login JWT") } -pub fn decode_invite_jwt(token: &str) -> Result { +pub fn decode_invite_jwt(token: &str) -> Result { let validation = jsonwebtoken::Validation { leeway: 30, // 30 seconds validate_exp: true, @@ -68,13 +65,9 @@ pub fn decode_invite_jwt(token: &str) -> Result { algorithms: vec![JWT_ALGORITHM], }; - match jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) { - Ok(decoded) => Ok(decoded.claims), - Err(msg) => { - error!("Error validating jwt - {:#?}", msg); - Err(msg.to_string()) - } - } + jsonwebtoken::decode(token, &PUBLIC_RSA_KEY, &validation) + .map(|d| d.claims) + .map_res("Error decoding invite JWT") } #[derive(Debug, Serialize, Deserialize)] @@ -150,7 +143,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers { CONFIG.domain.clone() } else if let Some(referer) = headers.get_one("Referer") { referer.to_string() - } else { + } else { // Try to guess from the headers use std::env; @@ -185,7 +178,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers { // Check JWT token is valid and get device and user from it let claims: JWTClaims = match decode_jwt(access_token) { Ok(claims) => claims, - Err(_) => err_handler!("Invalid claim") + Err(_) => err_handler!("Invalid claim"), }; let device_uuid = claims.device; @@ -193,17 +186,17 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers { let conn = match request.guard::() { Outcome::Success(conn) => conn, - _ => err_handler!("Error getting DB") + _ => err_handler!("Error getting DB"), }; let device = match Device::find_by_uuid(&device_uuid, &conn) { Some(device) => device, - None => err_handler!("Invalid device id") + None => err_handler!("Invalid device id"), }; let user = match User::find_by_uuid(&user_uuid, &conn) { Some(user) => user, - None => err_handler!("Device has no user associated") + None => err_handler!("Device has no user associated"), }; if user.security_stamp != claims.sstamp { @@ -248,11 +241,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders { None => err_handler!("The current user isn't member of the organization") }; - Outcome::Success(Self{ + Outcome::Success(Self { host: headers.host, device: headers.device, user: headers.user, - org_user_type: { + org_user_type: { if let Some(org_usr_type) = UserOrgType::from_i32(org_user.type_) { org_usr_type } else { // This should only happen if the DB is corrupted @@ -260,7 +253,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders { } }, }) - }, + } _ => err_handler!("Error getting the organization id"), } } diff --git a/src/error.rs b/src/error.rs index 905bf63c..a42cccbf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,14 +44,15 @@ macro_rules! make_error { }; } -use diesel::result::{Error as DieselError, QueryResult}; -use serde_json::{Value, Error as SerError}; +use diesel::result::Error as DieselError; +use jsonwebtoken::errors::Error as JwtError; +use serde_json::{Error as SerError, Value}; use u2f::u2ferror::U2fError as U2fErr; // Error struct // Each variant has two elements, the first is an error of different types, used for logging purposes // The second is a String, and it's contents are displayed to the user when the error occurs. Inside the macro, this is represented as _ -// +// // After the variant itself, there are two expressions. The first one is a bool to indicate whether the error cause will be printed to the log. // The second one contains the function used to obtain the response sent to the client make_error! { @@ -63,6 +64,7 @@ make_error! { DbError(DieselError, _): true, _api_error, U2fError(U2fErr, _): true, _api_error, SerdeError(SerError, _): true, _api_error, + JWTError(JwtError, _): true, _api_error, //WsError(ws::Error, _): true, _api_error, } @@ -73,19 +75,25 @@ impl Error { } pub trait MapResult { - fn map_res(self, msg: &str) -> Result<(), E>; + fn map_res(self, msg: &str) -> Result; +} + +impl> MapResult for Result { + fn map_res(self, msg: &str) -> Result { + self.map_err(Into::into).map_err(|e| e.with_msg(msg)) + } } -impl MapResult<(), Error> for QueryResult { +impl> MapResult<(), Error> for Result { fn map_res(self, msg: &str) -> Result<(), Error> { - self.and(Ok(())).map_err(Error::from).map_err(|e| e.with_msg(msg)) + self.and(Ok(())).map_res(msg) } } use serde::Serialize; use std::any::Any; -fn _serialize(e: &impl Serialize, _: &impl Any) -> String { +fn _serialize(e: &impl Serialize, _msg: &str) -> String { serde_json::to_string(e).unwrap() } @@ -102,7 +110,7 @@ fn _api_error(_: &impl Any, msg: &str) -> String { "Object": "error" }); - _serialize(&json, &false) + _serialize(&json, "") } //