1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
//! JSON Web Token Middleware

use std::{future::Future, marker::PhantomData, pin::Pin, fmt::Debug};

use viz_core::{
    http::{
        header::{HeaderValue, WWW_AUTHENTICATE},
        StatusCode,
    },
    Context, Middleware, Response, Result,
};

use viz_utils::tracing;

#[cfg(feature = "jwt-header")]
use viz_core::http::headers::{
    authorization::{Authorization, Bearer},
    HeaderMapExt,
};

#[cfg(feature = "jwt-query")]
#[cfg(not(all(feature = "jwt-header", feature = "jwt-param", feature = "jwt-cookie")))]
use std::collections::HashMap;
#[cfg(feature = "jwt-query")]
#[cfg(not(all(feature = "jwt-header", feature = "jwt-param", feature = "jwt-cookie")))]
use viz_core::types::QueryContextExt;

#[cfg(feature = "jwt-param")]
#[cfg(not(all(feature = "jwt-header", feature = "jwt-query", feature = "jwt-cookie")))]
use viz_core::types::ParamsContextExt;

#[cfg(feature = "jwt-cookie")]
#[cfg(not(all(feature = "jwt-header", feature = "jwt-query", feature = "jwt-param")))]
use viz_core::types::CookieContextExt;

use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::de::DeserializeOwned;

pub use jsonwebtoken;

/// JWT Middleware
///
/// Holds the configuration used to locate and verify a JSON Web Token on an
/// incoming request. `T` is the claims type the token is decoded into; it is
/// only carried as a marker and never stored.
///
/// No trait bounds are placed on `T` here: `#[derive(Debug)]` already adds
/// `T: Debug` to the generated `Debug` impl, and the `impl` blocks state the
/// bounds they need themselves.
#[derive(Debug)]
pub struct JWTMiddleware<T> {
    /// Name under which the token is looked up (query key, route parameter or
    /// cookie name). Only present when the header source is disabled and at
    /// least one of the other sources is enabled.
    #[cfg(not(feature = "jwt-header"))]
    #[cfg(any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie"))]
    n: String,
    /// Secret the decoding key is built from.
    /// NOTE: the derived `Debug` output includes this field verbatim.
    s: String,
    /// Validation rules handed to `jsonwebtoken::decode`.
    v: Validation,
    /// Marker tying the middleware to its claims type.
    t: PhantomData<T>,
}

impl<T> JWTMiddleware<T>
where
    T: DeserializeOwned + Sync + Send + 'static + Debug,
{
    /// Creates a JWT middleware with default settings.
    ///
    /// Defaults: the secret is the literal `"secret"`, the validation is
    /// `Validation::default()`, and (when a non-header token source is
    /// compiled in) the token is looked up under the name `"token"`.
    ///
    /// The default secret is only a placeholder; set a real one with
    /// [`secret`](Self::secret).
    pub fn new() -> Self {
        Self {
            #[cfg(not(feature = "jwt-header"))]
            #[cfg(any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie"))]
            n: "token".to_owned(),
            s: "secret".to_owned(),
            v: Validation::default(),
            t: PhantomData,
        }
    }

    /// Sets the secret used to build the decoding key and returns the middleware.
    pub fn secret(mut self, secret: &str) -> Self {
        self.s = secret.to_owned();
        self
    }

    /// Sets the validation rules applied when decoding and returns the middleware.
    pub fn validation(mut self, validation: Validation) -> Self {
        self.v = validation;
        self
    }

    /// Sets the name (query key, route parameter or cookie name) under which
    /// the token is looked up and returns the middleware.
    ///
    /// Only available when the header source is disabled and at least one of
    /// the other sources is enabled.
    #[cfg(not(feature = "jwt-header"))]
    #[cfg(any(feature = "jwt-query", feature = "jwt-param", feature = "jwt-cookie"))]
    pub fn name(mut self, name: &str) -> Self {
        self.n = name.to_owned();
        self
    }

    /// Verifies the request's token and either continues the chain or rejects.
    ///
    /// * Token found and decoded: the decoded token data is stored in the
    ///   context extensions and the result of `cx.next()` is returned.
    /// * Token found but decoding fails: the error is logged and a
    ///   `401 Unauthorized` response is returned.
    /// * No token found: a `400 Bad Request` response is returned.
    ///
    /// Rejections carry the error text in the `WWW-Authenticate` header.
    ///
    /// # Errors
    ///
    /// Propagates errors from the downstream handlers, or from building the
    /// header value.
    // `self` is skipped as well as `cx`: its `Debug` output contains the
    // signing secret, which must not be recorded on every request span.
    #[tracing::instrument(skip(self, cx))]
    async fn run(&self, cx: &mut Context) -> Result<Response> {
        let (status, error) = match self.get(cx) {
            Some(raw) => {
                let key = DecodingKey::from_secret(self.s.as_bytes());
                match decode::<T>(&raw, &key, &self.v) {
                    Ok(token) => {
                        cx.extensions_mut().insert(token);
                        return cx.next().await;
                    }
                    Err(e) => {
                        tracing::error!("JWT error: {}", e);
                        (StatusCode::UNAUTHORIZED, "Invalid or expired JWT")
                    }
                }
            }
            None => (StatusCode::BAD_REQUEST, "Missing or malformed JWT"),
        };

        let mut res: Response = status.into();
        res.headers_mut().insert(WWW_AUTHENTICATE, HeaderValue::from_str(error)?);
        Ok(res)
    }

    /// Extracts the raw token string from the request.
    ///
    /// Exactly one source is compiled in, picked by feature priority:
    /// `jwt-header` (bearer `Authorization` header), then `jwt-query`, then
    /// `jwt-param`, then `jwt-cookie`. With none enabled this always returns
    /// `None`.
    // With no source feature enabled the body is just `None`, leaving `cx` unused.
    #[allow(unused_variables)]
    fn get(&self, cx: &mut Context) -> Option<String> {
        cfg_if::cfg_if! {
            if #[cfg(feature = "jwt-header")] {
                cx.headers()
                    .typed_get::<Authorization<Bearer>>()
                    .map(|auth| auth.0.token().to_owned())
            } else if #[cfg(feature = "jwt-query")] {
                cx.query::<HashMap<String, String>>()
                    .ok()?
                    .get(&self.n)
                    .cloned()
            } else if #[cfg(feature = "jwt-param")] {
                cx.param(&self.n).ok()
            } else if #[cfg(feature = "jwt-cookie")] {
                // A cookie's `Display` output is the whole `name=value` pair,
                // which is not a decodable JWT; only the value is the token.
                cx.cookie(&self.n).map(|c| c.value().to_owned())
            } else {
                None
            }
        }
    }
}

/// Adapts [`JWTMiddleware`] to the framework's `Middleware` trait so it can be
/// placed in a handler chain.
impl<'a, T> Middleware<'a, Context> for JWTMiddleware<T>
where
    T: Debug + DeserializeOwned + Send + Sync + 'static,
{
    type Output = Result<Response>;

    /// Starts token verification for `cx` and hands back the pending work as a
    /// boxed, pinned future borrowing both the middleware and the context.
    #[must_use]
    fn call(
        &'a self,
        cx: &'a mut Context,
    ) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>> {
        // The concrete future type of `run` is unnameable; it is erased to a
        // trait object at the return site.
        let verification = self.run(cx);
        Box::pin(verification)
    }
}