some_auth/
jwt.rs

1use std::fmt;
2
3use chrono::{Duration, TimeDelta, Utc};
4use jsonwebtoken::{decode, encode, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
5use serde::{Deserialize, Serialize};
6
7use crate::error::AuthError;
8
9/// Access-Refresh token pair
10#[derive(Serialize)]
11pub struct TokenPair {
12    pub access: String,
13    pub refresh: String
14}
15
16impl fmt::Debug for TokenPair {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        f.debug_struct("TokenPair")
19            .field("access", &"***")
20            .field("refresh", &"***")
21            .finish()
22    }
23}
24
25/// Jwt settings for [`UserService`] configuration
26pub struct JwtTokenSettings {
27    pub access_tokens_secret: String,
28    pub access_tokens_lifetime: TimeDelta,
29    pub refresh_tokens_secret: String,
30    pub refresh_tokens_lifetime: TimeDelta
31}
32
33/// Token's claims
34#[derive(Debug, Serialize, Deserialize)]
35pub(crate) struct Claims {
36    pub(crate) sub: String,
37    pub(crate) exp: usize,
38    pub(crate) roles: Vec<String>
39}
40
41pub(crate) fn generate_token(user_id: i32, roles: &Vec<String>, alg: Algorithm, expiration: Duration, key: &[u8]) -> Result<String, AuthError> {
42    let exp = Utc::now()
43        .checked_add_signed(expiration)
44        .unwrap()
45        .timestamp() as usize;
46
47    let claims = Claims {
48        sub: user_id.to_string(),
49        exp,
50        roles: roles.iter().map(|s| s.to_string()).collect()
51    };
52
53    encode(&Header::new(alg), &claims, &EncodingKey::from_secret(key))
54        .map_err(|err| AuthError::Internal(format!("couldn't generate jwt: {err}")))
55}
56
57pub(crate) fn decode_token(token: &str, alg: Algorithm, key: &[u8]) -> Result<TokenData<Claims>, AuthError> {
58    const BEARER_START: &str = "Bearer ";
59    let token = token.strip_prefix(BEARER_START).unwrap_or(token);
60
61    decode::<Claims>(&token, &DecodingKey::from_secret(&key), &Validation::new(alg))
62        .map_err(|err| {
63            match err.kind() {                
64                ErrorKind::ExpiredSignature => AuthError::Unathorized,
65                _ => AuthError::InvalidCredentials,
66            }
67        })
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// HMAC secret shared by every test in this module.
    const KEY: &[u8] = b"m4HsuPraSekretp455W00rd";

    /// Roles attached to the tokens issued in the happy-path tests.
    fn admin_roles() -> Vec<String> {
        vec!["admin".to_string(), "adm".to_string()]
    }

    /// Issues an HS256 token for user `1`, signed with [`KEY`].
    fn issue_token(roles: Vec<String>, lifetime: TimeDelta) -> String {
        generate_token(1, &roles, Algorithm::HS256, lifetime, KEY)
            .expect("token generation must succeed")
    }

    /// Asserts that `claims` belong to user `1` and carry both admin roles.
    fn assert_admin_claims(claims: &Claims) {
        assert_eq!("1", claims.sub);
        assert!(claims.roles.iter().any(|r| r == "admin"));
        assert!(claims.roles.iter().any(|r| r == "adm"));
    }

    #[test]
    fn generate_token_test() {
        // Arrange
        let user_roles = admin_roles();

        // Act
        let generate_token_res =
            generate_token(1, &user_roles, Algorithm::HS256, TimeDelta::seconds(10), KEY);

        // Assert
        let token = generate_token_res.expect("token generation must succeed");
        assert_ne!("", token);
    }

    #[test]
    fn decode_token_test() {
        // Arrange
        let token = issue_token(admin_roles(), TimeDelta::seconds(10));

        // Act
        let decoded_token = decode_token(&token, Algorithm::HS256, KEY);

        // Assert
        let decoded_token = decoded_token.expect("a freshly issued token must decode");
        assert_admin_claims(&decoded_token.claims);
    }

    #[test]
    fn decode_token_with_bearer_test() {
        // Arrange
        let token = format!("Bearer {}", issue_token(admin_roles(), TimeDelta::seconds(10)));

        // Act
        let decoded_token = decode_token(&token, Algorithm::HS256, KEY);

        // Assert
        let decoded_token = decoded_token.expect("a `Bearer`-prefixed token must decode");
        assert_admin_claims(&decoded_token.claims);
    }

    #[test]
    fn decode_token_0_expired_token_0_invalid() {
        // Arrange
        let token = issue_token(vec![], TimeDelta::minutes(-2));

        // Act
        let decoded_token = decode_token(&token, Algorithm::HS256, KEY);

        // Assert
        let err = decoded_token.expect_err("an expired token must be rejected");
        assert!(err.to_string().contains("Unathorized"));
    }

    #[test]
    fn decode_token_0_spoofed_token_0_invalid() {
        // Arrange
        let token = issue_token(vec![], TimeDelta::seconds(10));
        // {"sub":"2","iat":1718955601}
        let spoofed_part = "eyJzdWIiOiIyIiwiaWF0IjoxNzE4OTU1NjAxfQ";

        // Act
        let token_parts: Vec<_> = token.split('.').collect();
        let spoofed_token = format!("{}.{}.{}", token_parts[0], spoofed_part, token_parts[2]);
        let decoded_token = decode_token(&spoofed_token, Algorithm::HS256, KEY);

        // Assert
        let err = decoded_token.expect_err("a token with a tampered payload must be rejected");
        assert!(err.to_string().contains("Invalid credentials"));
    }

    #[test]
    fn decode_token_0_wrong_key_0_invalid() {
        // Arrange
        let token = issue_token(admin_roles(), TimeDelta::seconds(10));

        // Act
        let decoded_token = decode_token(&token, Algorithm::HS256, b"another-secret");

        // Assert
        let err = decoded_token.expect_err("a token signed with another key must be rejected");
        assert!(err.to_string().contains("Invalid credentials"));
    }
}