From 61515160a7688fa4ba2b7a7c5ddea3df2a1c1611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Garc=C3=ADa?= Date: Thu, 14 Mar 2019 00:17:36 +0100 Subject: [PATCH] Allow changing error codes and create an empty error. Return 404 instead of 400 when no accounts breached. --- src/api/core/mod.rs | 17 +++++++++-------- src/error.rs | 24 +++++++++++++++++++++--- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs index 2aab77b5..4abe31a4 100644 --- a/src/api/core/mod.rs +++ b/src/api/core/mod.rs @@ -33,10 +33,10 @@ use rocket::Route; use rocket_contrib::json::Json; use serde_json::Value; -use crate::db::DbConn; - use crate::api::{EmptyResult, JsonResult, JsonUpcase}; use crate::auth::Headers; +use crate::db::DbConn; +use crate::error::Error; #[put("/devices/identifier//clear-token")] fn clear_device_token(uuid: String) -> EmptyResult { @@ -137,12 +137,13 @@ fn hibp_breach(username: String) -> JsonResult { use reqwest::{header::USER_AGENT, Client}; - let value: Value = Client::new() - .get(&url) - .header(USER_AGENT, user_agent) - .send()? - .error_for_status()? - .json()?; + let res = Client::new().get(&url).header(USER_AGENT, user_agent).send()?; + + // If we get a 404, return a 404, it means no breached accounts + if res.status() == 404 { + return Err(Error::empty().with_code(404)); + } + let value: Value = res.error_for_status()?.json()?; Ok(Json(value)) } diff --git a/src/error.rs b/src/error.rs index 15014f1e..be8907e0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,16 +5,18 @@ use std::error::Error as StdError; macro_rules! make_error { ( $( $name:ident ( $ty:ty ): $src_fn:expr, $usr_msg_fun:expr ),+ $(,)? ) => { + const BAD_REQUEST: u16 = 400; + #[derive(Display)] pub enum ErrorKind { $($name( $ty )),+ } - pub struct Error { message: String, error: ErrorKind } + pub struct Error { message: String, error: ErrorKind, error_code: u16 } $(impl From<$ty> for Error { fn from(err: $ty) -> Self { Error::from((stringify!($name), err)) } })+ $(impl> From<(S, $ty)> for Error { fn from(val: (S, $ty)) -> Self { - Error { message: val.0.into(), error: ErrorKind::$name(val.1) } + Error { message: val.0.into(), error: ErrorKind::$name(val.1), error_code: BAD_REQUEST } } })+ impl StdError for Error { @@ -43,12 +45,17 @@ use std::time::SystemTimeError as TimeErr; use u2f::u2ferror::U2fError as U2fErr; use yubico::yubicoerror::YubicoError as YubiErr; +#[derive(Display, Serialize)] +pub struct Empty {} + // Error struct // Contains a String error message, meant for the user and an enum variant, with an error of different types. // // After the variant itself, there are two expressions. The first one indicates whether the error contains a source error (that we pretty print). // The second one contains the function used to obtain the response sent to the client make_error! { + // Just an empty error + EmptyError(Empty): _no_source, _serialize, // Used to represent err! calls SimpleError(String): _no_source, _api_error, // Used for special return values, like 2FA errors @@ -80,10 +87,19 @@ impl Error { (usr_msg, log_msg.into()).into() } + pub fn empty() -> Self { + Empty {}.into() + } + pub fn with_msg>(mut self, msg: M) -> Self { self.message = msg.into(); self } + + pub fn with_code(mut self, code: u16) -> Self { + self.error_code = code; + self + } } pub trait MapResult { @@ -142,8 +158,10 @@ impl<'r> Responder<'r> for Error { let usr_msg = format!("{}", self); error!("{:#?}", self); + let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest); + Response::build() - .status(Status::BadRequest) + .status(code) .header(ContentType::JSON) .sized_body(Cursor::new(usr_msg)) .ok()