1pub mod error;
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8pub type Result<T> = std::result::Result<T, Error>;
9
10#[derive(Debug, Deserialize)]
12pub struct JwtCfg {
13 pub access_secret: String,
14 pub refresh_secret: String,
15 pub audience: String,
16 pub access_token_duration: usize,
17 pub refresh_token_duration: usize,
18 pub access_key_validate_exp: bool,
19 pub refresh_key_validate_exp: bool,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
24pub struct Claims {
25 pub aud: String,
26 pub sub: String,
27 pub exp: usize,
28 pub iat: usize,
29}
30
31impl Claims {
32 pub fn new(aud: String, sub: String, exp: usize, iat: usize) -> Self {
34 Self { aud, sub, exp, iat }
35 }
36}
37
/// Selects which key pair, validation rules and lifetime a token operation uses.
enum TokenKind {
    /// Token signed/verified with the access secret.
    Access,
    /// Token signed/verified with the refresh secret.
    // NOTE: `Refesh` is a misspelling of `Refresh`; renaming it requires updating every
    // `TokenKind::Refesh` reference in the `Jwt` impl at the same time.
    Refesh,
}
43
44#[derive(Clone)]
46pub struct Jwt {
47 header: Header,
48 encoding_access_key: EncodingKey,
49 encoding_refresh_key: EncodingKey,
50 decoding_access_key: DecodingKey,
51 decoding_refresh_key: DecodingKey,
52 validation_access_key: Validation,
53 validation_refresh_key: Validation,
54 aud: String,
55 access_token_duration: usize,
56 refresh_token_duration: usize,
57}
58
59impl Jwt {
60 pub fn new(cfg: JwtCfg) -> Self {
70 let header = Header::default();
71 let encoding_access_key = EncodingKey::from_secret(cfg.access_secret.as_bytes());
72 let encoding_refresh_key = EncodingKey::from_secret(cfg.refresh_secret.as_bytes());
73 let decoding_access_key = DecodingKey::from_secret(cfg.access_secret.as_bytes());
74 let decoding_refresh_key = DecodingKey::from_secret(cfg.refresh_secret.as_bytes());
75 let mut validation_access_key = Validation::default();
76 validation_access_key.set_audience(std::slice::from_ref(&cfg.audience));
77 let mut validation_refresh_key = validation_access_key.clone();
78 validation_access_key.validate_exp = cfg.access_key_validate_exp;
79 validation_refresh_key.validate_exp = cfg.refresh_key_validate_exp;
80 validation_refresh_key.required_spec_claims.clear();
81 Self {
82 header,
83 encoding_access_key,
84 encoding_refresh_key,
85 decoding_access_key,
86 decoding_refresh_key,
87 validation_access_key,
88 validation_refresh_key,
89 aud: cfg.audience,
90 access_token_duration: cfg.access_token_duration,
91 refresh_token_duration: cfg.refresh_token_duration,
92 }
93 }
94
95 pub fn generate_token_pair(&self, sub: String) -> Result<(String, String)> {
105 let access_token = self.generate_token(&TokenKind::Access, &sub)?;
106 let refresh_token = self.generate_token(&TokenKind::Refesh, &sub)?;
107 Ok((access_token, refresh_token))
108 }
109
110 pub fn generate_access_token(&self, sub: String) -> Result<String> {
120 self.generate_token(&TokenKind::Access, &sub)
121 }
122
123 pub fn refresh_access_token(&self, refresh_token: &str) -> Result<String> {
133 let claims = self.validate_refresh_token(refresh_token)?;
134 self.generate_access_token(claims.sub)
135 }
136
137 pub fn validate_access_token(&self, token: &str) -> Result<Claims> {
147 self.validate_token(&TokenKind::Access, token)
148 .map(|data| data.claims)
149 }
150
151 pub fn validate_refresh_token(&self, token: &str) -> Result<Claims> {
161 self.validate_token(&TokenKind::Refesh, token)
162 .map(|data| data.claims)
163 }
164
165 fn generate_token(&self, kind: &TokenKind, sub: &str) -> Result<String> {
176 let duration = self.get_token_duration(kind);
177 let (iat, exp) = self.generate_timestamps(duration);
178 let key = self.select_encoding_key(kind);
179 let claims = self.create_claims(sub, iat, exp);
180 encode(&self.header, &claims, key).map_err(|e| Error::AuthError(e.to_string().into()))
181 }
182
183 fn validate_token(&self, kind: &TokenKind, token: &str) -> Result<TokenData<Claims>> {
194 let (key, validation) = self.select_decoding_key_and_validation(kind);
195 decode::<Claims>(token, key, validation).map_err(|e| Error::AuthError(e.to_string().into()))
196 }
197
198 fn get_token_duration(&self, kind: &TokenKind) -> usize {
208 match kind {
209 TokenKind::Access => self.access_token_duration,
210 TokenKind::Refesh => self.refresh_token_duration,
211 }
212 }
213
214 fn generate_timestamps(&self, duration: usize) -> (usize, usize) {
224 generate_expired_time(duration)
225 }
226
227 fn select_encoding_key(&self, kind: &TokenKind) -> &EncodingKey {
237 match kind {
238 TokenKind::Access => &self.encoding_access_key,
239 TokenKind::Refesh => &self.encoding_refresh_key,
240 }
241 }
242
243 fn create_claims(&self, sub: &str, iat: usize, exp: usize) -> Claims {
255 Claims::new(self.aud.clone(), sub.to_string(), exp, iat)
256 }
257
258 fn select_decoding_key_and_validation(&self, kind: &TokenKind) -> (&DecodingKey, &Validation) {
268 match kind {
269 TokenKind::Access => (&self.decoding_access_key, &self.validation_access_key),
270 TokenKind::Refesh => (&self.decoding_refresh_key, &self.validation_refresh_key),
271 }
272 }
273}
274
275fn generate_expired_time(duration: usize) -> (usize, usize) {
285 let now = Utc::now();
286 let iat = now.timestamp() as usize;
287 let exp = (now + Duration::seconds(duration as i64)).timestamp() as usize;
288 (iat, exp)
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Jwt` with fixed test secrets and the given audience.
    fn setup_jwt_with_audience(audience: &str) -> Jwt {
        Jwt::new(JwtCfg {
            access_secret: "access_secret".to_string(),
            refresh_secret: "refresh_secret".to_string(),
            audience: audience.to_string(),
            access_token_duration: 3600,
            refresh_token_duration: 86400,
            access_key_validate_exp: true,
            refresh_key_validate_exp: true,
        })
    }

    fn setup_jwt() -> Jwt {
        setup_jwt_with_audience("test_audience")
    }

    /// Asserts that `result` failed specifically with `Error::AuthError`.
    fn assert_auth_error<T: std::fmt::Debug>(result: Result<T>) {
        let err = result.expect_err("expected token operation to fail");
        assert!(matches!(err, Error::AuthError(_)), "Expected AuthError");
    }

    #[test]
    fn test_generate_token_pair() {
        let jwt = setup_jwt();
        let (access_token, refresh_token) =
            jwt.generate_token_pair("test_sub".to_string()).unwrap();

        assert!(!access_token.is_empty());
        assert!(!refresh_token.is_empty());
    }

    #[test]
    fn test_generate_access_token() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token("test_sub".to_string()).unwrap();

        assert!(!access_token.is_empty());
    }

    #[test]
    fn test_validate_access_token() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token("test_sub".to_string()).unwrap();

        let claims = jwt.validate_access_token(&access_token).unwrap();
        assert_eq!(claims.aud, "test_audience");
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_validate_refresh_token() {
        let jwt = setup_jwt();
        let (_, refresh_token) = jwt.generate_token_pair("test_sub".to_string()).unwrap();

        let claims = jwt.validate_refresh_token(&refresh_token).unwrap();
        assert_eq!(claims.aud, "test_audience");
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_expired_access_token() {
        use std::time::{Duration as StdDuration, SystemTime, UNIX_EPOCH};

        let jwt = setup_jwt();
        let seconds_ago = |secs: u64| {
            (SystemTime::now() - StdDuration::from_secs(secs))
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs() as usize
        };
        // Expired an hour ago: well beyond any default validation leeway.
        let claims = Claims::new(
            "test_audience".to_string(),
            "test_sub".to_string(),
            seconds_ago(3600),
            seconds_ago(7200),
        );
        let access_token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret("access_secret".as_ref()),
        )
        .unwrap();

        assert_auth_error(jwt.validate_access_token(&access_token));
    }

    #[test]
    fn test_invalid_access_token() {
        let jwt = setup_jwt();

        assert_auth_error(jwt.validate_access_token("invalid_token"));
    }

    #[test]
    fn test_refresh_access_token() {
        let jwt = setup_jwt();
        let (_, refresh_token) = jwt.generate_token_pair("test_sub".to_string()).unwrap();

        let new_access_token = jwt.refresh_access_token(&refresh_token).unwrap();

        assert!(!new_access_token.is_empty());
        // The refreshed token must itself be a valid access token for the same subject.
        let claims = jwt.validate_access_token(&new_access_token).unwrap();
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_token_kinds_are_not_interchangeable() {
        let jwt = setup_jwt();
        let (access_token, refresh_token) =
            jwt.generate_token_pair("test_sub".to_string()).unwrap();

        // Each kind is signed with its own secret, so cross-validation must fail.
        assert_auth_error(jwt.validate_refresh_token(&access_token));
        assert_auth_error(jwt.validate_access_token(&refresh_token));
        assert_auth_error(jwt.refresh_access_token(&access_token));
    }

    #[test]
    fn test_wrong_audience_is_rejected() {
        let jwt = setup_jwt();
        let foreign = setup_jwt_with_audience("other_audience");
        let (access_token, refresh_token) =
            foreign.generate_token_pair("test_sub".to_string()).unwrap();

        // Same secrets, different `aud`: the audience check alone must reject these.
        assert_auth_error(jwt.validate_access_token(&access_token));
        assert_auth_error(jwt.validate_refresh_token(&refresh_token));
    }
}