1use chrono::{DateTime, Duration, Utc};
35use jsonwebtoken::{
36 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
37};
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use std::collections::HashMap;
41
42use crate::error::{SaTokenError, SaTokenResult};
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46#[derive(Default)]
47pub enum JwtAlgorithm {
48 #[default]
50 HS256,
51 HS384,
53 HS512,
55 RS256,
57 RS384,
59 RS512,
61 ES256,
63 ES384,
65}
66
67
68impl From<JwtAlgorithm> for Algorithm {
69 fn from(alg: JwtAlgorithm) -> Self {
70 match alg {
71 JwtAlgorithm::HS256 => Algorithm::HS256,
72 JwtAlgorithm::HS384 => Algorithm::HS384,
73 JwtAlgorithm::HS512 => Algorithm::HS512,
74 JwtAlgorithm::RS256 => Algorithm::RS256,
75 JwtAlgorithm::RS384 => Algorithm::RS384,
76 JwtAlgorithm::RS512 => Algorithm::RS512,
77 JwtAlgorithm::ES256 => Algorithm::ES256,
78 JwtAlgorithm::ES384 => Algorithm::ES384,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct JwtClaims {
89 #[serde(rename = "sub")]
91 pub login_id: String,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
95 pub iss: Option<String>,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub aud: Option<String>,
100
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub exp: Option<i64>,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub nbf: Option<i64>,
108
109 #[serde(skip_serializing_if = "Option::is_none")]
111 pub iat: Option<i64>,
112
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub jti: Option<String>,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
121 pub login_type: Option<String>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub device: Option<String>,
126
127 #[serde(default)]
129 #[serde(skip_serializing_if = "HashMap::is_empty")]
130 pub extra: HashMap<String, Value>,
131}
132
133impl JwtClaims {
134 pub fn new(login_id: impl Into<String>) -> Self {
140 let now = Utc::now().timestamp();
141 Self {
142 login_id: login_id.into(),
143 iss: None,
144 aud: None,
145 exp: None,
146 nbf: None,
147 iat: Some(now),
148 jti: None,
149 login_type: Some("default".to_string()),
150 device: None,
151 extra: HashMap::new(),
152 }
153 }
154
155 pub fn set_expiration(&mut self, seconds: i64) -> &mut Self {
161 let exp_time = Utc::now() + Duration::seconds(seconds);
162 self.exp = Some(exp_time.timestamp());
163 self
164 }
165
166 pub fn set_expiration_at(&mut self, datetime: DateTime<Utc>) -> &mut Self {
168 self.exp = Some(datetime.timestamp());
169 self
170 }
171
172 pub fn set_issuer(&mut self, issuer: impl Into<String>) -> &mut Self {
174 self.iss = Some(issuer.into());
175 self
176 }
177
178 pub fn set_audience(&mut self, audience: impl Into<String>) -> &mut Self {
180 self.aud = Some(audience.into());
181 self
182 }
183
184 pub fn set_jti(&mut self, jti: impl Into<String>) -> &mut Self {
186 self.jti = Some(jti.into());
187 self
188 }
189
190 pub fn set_login_type(&mut self, login_type: impl Into<String>) -> &mut Self {
192 self.login_type = Some(login_type.into());
193 self
194 }
195
196 pub fn set_device(&mut self, device: impl Into<String>) -> &mut Self {
198 self.device = Some(device.into());
199 self
200 }
201
202 pub fn add_claim(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
204 self.extra.insert(key.into(), value);
205 self
206 }
207
208 pub fn get_claim(&self, key: &str) -> Option<&Value> {
210 self.extra.get(key)
211 }
212
213 pub fn set_claims(&mut self, claims: HashMap<String, Value>) -> &mut Self {
215 self.extra = claims;
216 self
217 }
218
219 pub fn get_claims(&self) -> &HashMap<String, Value> {
221 &self.extra
222 }
223
224 pub fn is_expired(&self) -> bool {
226 if let Some(exp) = self.exp {
227 let now = Utc::now().timestamp();
228 now >= exp
229 } else {
230 false
231 }
232 }
233
234 pub fn remaining_time(&self) -> Option<i64> {
236 self.exp.map(|exp| {
237 let now = Utc::now().timestamp();
238 (exp - now).max(0)
239 })
240 }
241}
242
243#[derive(Clone)]
248pub struct JwtManager {
249 secret: String,
251
252 algorithm: JwtAlgorithm,
254
255 issuer: Option<String>,
257
258 audience: Option<String>,
260}
261
262impl JwtManager {
263 pub fn new(secret: impl Into<String>) -> Self {
269 Self {
270 secret: secret.into(),
271 algorithm: JwtAlgorithm::HS256,
272 issuer: None,
273 audience: None,
274 }
275 }
276
277 pub fn with_algorithm(secret: impl Into<String>, algorithm: JwtAlgorithm) -> Self {
279 Self {
280 secret: secret.into(),
281 algorithm,
282 issuer: None,
283 audience: None,
284 }
285 }
286
287 pub fn set_issuer(mut self, issuer: impl Into<String>) -> Self {
289 self.issuer = Some(issuer.into());
290 self
291 }
292
293 pub fn set_audience(mut self, audience: impl Into<String>) -> Self {
295 self.audience = Some(audience.into());
296 self
297 }
298
299 pub fn generate(&self, claims: &JwtClaims) -> SaTokenResult<String> {
309 let mut final_claims = claims.clone();
310
311 if self.issuer.is_some() && final_claims.iss.is_none() {
314 final_claims.iss = self.issuer.clone();
315 }
316 if self.audience.is_some() && final_claims.aud.is_none() {
317 final_claims.aud = self.audience.clone();
318 }
319
320 let header = Header::new(self.algorithm.into());
321 let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
322
323 encode(&header, &final_claims, &encoding_key).map_err(|e| {
324 SaTokenError::InvalidToken(format!("Failed to generate JWT: {}", e))
325 })
326 }
327
328 pub fn validate(&self, token: &str) -> SaTokenResult<JwtClaims> {
338 let mut validation = Validation::new(self.algorithm.into());
339
340 validation.validate_exp = true;
342
343 validation.leeway = 0;
345
346 if let Some(ref iss) = self.issuer {
348 validation.set_issuer(&[iss]);
349 }
350 if let Some(ref aud) = self.audience {
351 validation.set_audience(&[aud]);
352 }
353
354 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
355
356 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
357 match e.kind() {
358 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
359 SaTokenError::TokenExpired
360 }
361 _ => SaTokenError::InvalidToken(format!("JWT validation failed: {}", e)),
362 }
363 })?;
364
365 Ok(token_data.claims)
366 }
367
368 pub fn decode_without_validation(&self, token: &str) -> SaTokenResult<JwtClaims> {
373 let token_data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(token)
376 .map_err(|e| SaTokenError::InvalidToken(format!("Failed to decode JWT: {}", e)))?;
377
378 Ok(token_data.claims)
379 }
380
381 pub fn refresh(&self, token: &str, extend_seconds: i64) -> SaTokenResult<String> {
391 let mut claims = self.validate(token)?;
392
393 claims.set_expiration(extend_seconds);
395
396 claims.iat = Some(Utc::now().timestamp());
398
399 self.generate(&claims)
400 }
401
402 pub fn extract_login_id(&self, token: &str) -> SaTokenResult<String> {
407 let claims = self.decode_without_validation(token)?;
408 Ok(claims.login_id)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    const SECRET: &str = "test-secret-key";

    /// Claims for `login_id` that expire one hour from now.
    fn claims_valid_for_an_hour(login_id: &str) -> JwtClaims {
        let mut claims = JwtClaims::new(login_id);
        claims.set_expiration(3600);
        claims
    }

    #[test]
    fn test_jwt_claims_creation() {
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration(3600);
        claims.set_issuer("sa-token");
        claims.add_claim("role", serde_json::json!("admin"));

        assert_eq!(claims.login_id, "user_123");
        assert!(claims.exp.is_some());
        assert_eq!(claims.iss, Some("sa-token".to_string()));
        assert_eq!(claims.get_claim("role"), Some(&serde_json::json!("admin")));
    }

    #[test]
    fn test_claims_remaining_time_and_expiry() {
        let mut claims = claims_valid_for_an_hour("user_123");
        let remaining = claims.remaining_time().unwrap();
        assert!((3590..=3600).contains(&remaining), "unexpected remaining time {remaining}");
        assert!(!claims.is_expired());

        claims.set_expiration_at(Utc::now() - Duration::seconds(10));
        assert_eq!(claims.remaining_time(), Some(0));
        assert!(claims.is_expired());
    }

    #[test]
    fn test_jwt_generate_and_validate() {
        let jwt_manager = JwtManager::new(SECRET);

        let token = jwt_manager.generate(&claims_valid_for_an_hour("user_123")).unwrap();
        assert!(!token.is_empty());

        let decoded = jwt_manager.validate(&token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_jwt_expired() {
        let jwt_manager = JwtManager::new(SECRET);

        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration_at(Utc::now() - Duration::seconds(10));
        let token = jwt_manager.generate(&claims).unwrap();

        let result = jwt_manager.validate(&token);
        assert!(
            matches!(result, Err(SaTokenError::TokenExpired)),
            "Expected TokenExpired error, got {result:?}"
        );
    }

    #[test]
    fn test_jwt_rejects_wrong_secret() {
        let token = JwtManager::new(SECRET)
            .generate(&claims_valid_for_an_hour("user_123"))
            .unwrap();

        let result = JwtManager::new("another-secret").validate(&token);
        assert!(
            matches!(result, Err(SaTokenError::InvalidToken(_))),
            "Expected InvalidToken error, got {result:?}"
        );
    }

    #[test]
    fn test_jwt_rejects_swapped_signature() {
        let jwt_manager = JwtManager::new(SECRET);
        let user_token = jwt_manager.generate(&claims_valid_for_an_hour("user_123")).unwrap();
        let admin_token = jwt_manager.generate(&claims_valid_for_an_hour("admin")).unwrap();

        // The admin header+payload paired with the user's signature must not verify.
        let (admin_signing_input, _) = admin_token.rsplit_once('.').unwrap();
        let (_, user_signature) = user_token.rsplit_once('.').unwrap();
        let forged = format!("{admin_signing_input}.{user_signature}");

        let result = jwt_manager.validate(&forged);
        assert!(
            matches!(result, Err(SaTokenError::InvalidToken(_))),
            "Expected InvalidToken error, got {result:?}"
        );
    }

    #[test]
    fn test_jwt_issuer_is_applied_and_checked() {
        let issuer_a = JwtManager::new(SECRET).set_issuer("issuer-a");
        let issuer_b = JwtManager::new(SECRET).set_issuer("issuer-b");

        let token = issuer_a.generate(&claims_valid_for_an_hour("user_123")).unwrap();

        let decoded = issuer_a.validate(&token).unwrap();
        assert_eq!(decoded.iss.as_deref(), Some("issuer-a"));

        let result = issuer_b.validate(&token);
        assert!(
            matches!(result, Err(SaTokenError::InvalidToken(_))),
            "Expected InvalidToken error, got {result:?}"
        );
    }

    #[test]
    fn test_jwt_refresh() {
        let jwt_manager = JwtManager::new(SECRET);

        let original_token = jwt_manager.generate(&claims_valid_for_an_hour("user_123")).unwrap();
        let original_exp = jwt_manager.validate(&original_token).unwrap().exp.unwrap();

        let new_token = jwt_manager.refresh(&original_token, 7200).unwrap();
        assert_ne!(original_token, new_token);

        let decoded = jwt_manager.validate(&new_token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
        assert!(decoded.exp.unwrap() > original_exp, "refresh must push exp forward");
    }

    #[test]
    fn test_jwt_custom_claims() {
        let jwt_manager = JwtManager::new(SECRET);

        let mut claims = claims_valid_for_an_hour("user_123");
        claims.add_claim("role", serde_json::json!("admin"));
        claims.add_claim("permissions", serde_json::json!(["read", "write"]));

        let token = jwt_manager.generate(&claims).unwrap();
        let decoded = jwt_manager.validate(&token).unwrap();

        assert_eq!(decoded.get_claim("role"), Some(&serde_json::json!("admin")));
        assert_eq!(
            decoded.get_claim("permissions"),
            Some(&serde_json::json!(["read", "write"]))
        );
    }

    #[test]
    fn test_extract_login_id() {
        let jwt_manager = JwtManager::new(SECRET);

        let token = jwt_manager.generate(&claims_valid_for_an_hour("user_123")).unwrap();
        let login_id = jwt_manager.extract_login_id(&token).unwrap();

        assert_eq!(login_id, "user_123");
    }
}
522