tide_jwt/
lib.rs

1use async_trait::async_trait;
2use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
3use serde::{de::DeserializeOwned, Serialize};
4use std::marker::PhantomData;
5use tide::{Middleware, Next, Request, Response, StatusCode};
6
7pub fn jwtsign<Claims: Serialize + DeserializeOwned + Send + Sync + 'static>(
8    claims: &Claims,
9    key: &EncodingKey,
10) -> Result<String, jsonwebtoken::errors::Error> {
11    encode(&Header::default(), claims, key)
12}
13
14pub fn jwtsign_secret<Claims: Serialize + DeserializeOwned + Send + Sync + 'static>(
15    claims: &Claims,
16    key: &str,
17) -> Result<String, jsonwebtoken::errors::Error> {
18    encode(
19        &Header::default(),
20        claims,
21        &EncodingKey::from_base64_secret(key)?,
22    )
23}
24
25pub fn jwtsign_with<Claims: Serialize + DeserializeOwned + Send + Sync + 'static>(
26    header: &Header,
27    claims: &Claims,
28    key: &EncodingKey,
29) -> Result<String, jsonwebtoken::errors::Error> {
30    encode(header, claims, key)
31}
32
33pub struct JwtAuthenticationDecoder<Claims: DeserializeOwned + Send + Sync + 'static> {
34    validation: Validation,
35    key: DecodingKey,
36    _claims: PhantomData<Claims>,
37}
38
39impl<Claims: DeserializeOwned + Send + Sync + 'static> JwtAuthenticationDecoder<Claims> {
40    pub fn default(key: DecodingKey) -> Self {
41        Self::new(Validation::default(), key)
42    }
43
44    pub fn new(validation: Validation, key: DecodingKey) -> Self {
45        Self {
46            validation,
47            key,
48            _claims: PhantomData::default(),
49        }
50    }
51}
52
53#[async_trait]
54impl<State, Claims> Middleware<State> for JwtAuthenticationDecoder<Claims>
55where
56    State: Clone + Send + Sync + 'static,
57    Claims: DeserializeOwned + Send + Sync + 'static,
58{
59    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> tide::Result {
60        let header = req.header("Authorization");
61        if header.is_none() {
62            return Ok(next.run(req).await);
63        }
64
65        let values: Vec<_> = header.unwrap().into_iter().collect();
66
67        if values.is_empty() {
68            return Ok(next.run(req).await);
69        }
70
71        if values.len() > 1 {
72            return Ok(Response::new(StatusCode::Unauthorized));
73        }
74
75        for value in values {
76            let value = value.as_str();
77            if !value.starts_with("Bearer") {
78                continue;
79            }
80
81            let token = &value["Bearer ".len()..];
82            println!("found authorization token: {token}");
83            let data = match decode::<Claims>(token, &self.key, &self.validation) {
84                Ok(c) => c,
85                Err(_) => {
86                    return Ok(Response::new(StatusCode::Unauthorized));
87                }
88            };
89
90            req.set_ext(data.claims);
91            break;
92        }
93
94        Ok(next.run(req).await)
95    }
96}