1use std::fmt;
2
3use chrono::{Duration, TimeDelta, Utc};
4use jsonwebtoken::{decode, encode, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
5use serde::{Deserialize, Serialize};
6
7use crate::error::AuthError;
8
9#[derive(Serialize)]
11pub struct TokenPair {
12 pub access: String,
13 pub refresh: String
14}
15
16impl fmt::Debug for TokenPair {
17 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18 f.debug_struct("TokenPair")
19 .field("access", &"***")
20 .field("refresh", &"***")
21 .finish()
22 }
23}
24
25pub struct JwtTokenSettings {
27 pub access_tokens_secret: String,
28 pub access_tokens_lifetime: TimeDelta,
29 pub refresh_tokens_secret: String,
30 pub refresh_tokens_lifetime: TimeDelta
31}
32
33#[derive(Debug, Serialize, Deserialize)]
35pub(crate) struct Claims {
36 pub(crate) sub: String,
37 pub(crate) exp: usize,
38 pub(crate) roles: Vec<String>
39}
40
41pub(crate) fn generate_token(user_id: i32, roles: &Vec<String>, alg: Algorithm, expiration: Duration, key: &[u8]) -> Result<String, AuthError> {
42 let exp = Utc::now()
43 .checked_add_signed(expiration)
44 .unwrap()
45 .timestamp() as usize;
46
47 let claims = Claims {
48 sub: user_id.to_string(),
49 exp,
50 roles: roles.iter().map(|s| s.to_string()).collect()
51 };
52
53 encode(&Header::new(alg), &claims, &EncodingKey::from_secret(key))
54 .map_err(|err| AuthError::Internal(format!("couldn't generate jwt: {err}")))
55}
56
57pub(crate) fn decode_token(token: &str, alg: Algorithm, key: &[u8]) -> Result<TokenData<Claims>, AuthError> {
58 const BEARER_START: &str = "Bearer ";
59 let token = token.strip_prefix(BEARER_START).unwrap_or(token);
60
61 decode::<Claims>(&token, &DecodingKey::from_secret(&key), &Validation::new(alg))
62 .map_err(|err| {
63 match err.kind() {
64 ErrorKind::ExpiredSignature => AuthError::Unathorized,
65 _ => AuthError::InvalidCredentials,
66 }
67 })
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret used to sign and verify every token in these tests.
    const SECRET: &[u8] = b"m4HsuPraSekretp455W00rd";
    /// Id of the user all test tokens are issued for.
    const USER_ID: i32 = 1;
    /// Foreign payload segment spliced into a token to simulate tampering.
    const SPOOFED_PAYLOAD: &str = "eyJzdWIiOiIyIiwiaWF0IjoxNzE4OTU1NjAxfQ";

    /// Roles attached to tokens whose claims are inspected after decoding.
    fn sample_roles() -> Vec<String> {
        vec![String::from("admin"), String::from("adm")]
    }

    /// Signs an HS256 token for `USER_ID`, panicking when generation fails.
    fn issue_token(roles: Vec<String>, lifetime: TimeDelta) -> String {
        generate_token(USER_ID, &roles, Algorithm::HS256, lifetime, SECRET)
            .expect("token generation should succeed")
    }

    /// Checks that the decoded claims describe the sample user and roles.
    fn assert_sample_claims(decoded: &TokenData<Claims>) {
        assert_eq!("1", decoded.claims.sub);
        for expected in ["admin", "adm"] {
            assert!(decoded.claims.roles.iter().any(|role| role == expected));
        }
    }

    #[test]
    fn generate_token_test() {
        let roles = sample_roles();

        let generated =
            generate_token(USER_ID, &roles, Algorithm::HS256, TimeDelta::seconds(10), SECRET);

        assert!(generated.is_ok());
        assert_ne!("", generated.unwrap())
    }

    #[test]
    fn decode_token_test() {
        let token = issue_token(sample_roles(), TimeDelta::seconds(10));

        let decoded = decode_token(&token, Algorithm::HS256, SECRET)
            .expect("a freshly issued token should decode");

        assert_sample_claims(&decoded);
    }

    #[test]
    fn decode_token_with_bearer_test() {
        let header_value = format!("Bearer {}", issue_token(sample_roles(), TimeDelta::seconds(10)));

        let decoded = decode_token(&header_value, Algorithm::HS256, SECRET)
            .expect("a token with a Bearer prefix should decode");

        assert_sample_claims(&decoded);
    }

    #[test]
    fn decode_token_0_expired_token_0_invalid() {
        // Issued two minutes in the past, so it is already expired.
        let token = issue_token(Vec::new(), TimeDelta::minutes(-2));

        let error = decode_token(&token, Algorithm::HS256, SECRET)
            .expect_err("an expired token must be rejected");

        assert!(error.to_string().contains("Unathorized"))
    }

    #[test]
    fn decode_token_0_spoofed_token_0_invalid() {
        let token = issue_token(Vec::new(), TimeDelta::seconds(10));

        // Keep the genuine header and signature, but swap in a foreign payload.
        let segments: Vec<&str> = token.split('.').collect();
        let spoofed_token = [segments[0], SPOOFED_PAYLOAD, segments[2]].join(".");

        let error = decode_token(&spoofed_token, Algorithm::HS256, SECRET)
            .expect_err("a tampered token must be rejected");

        assert!(error.to_string().contains("Invalid credentials"))
    }
}