1use chrono::{DateTime, Duration, Utc};
35use jsonwebtoken::{
36 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
37};
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use std::collections::HashMap;
41
42use crate::error::{SaTokenError, SaTokenResult};
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46#[derive(Default)]
47pub enum JwtAlgorithm {
48 #[default]
50 HS256,
51 HS384,
53 HS512,
55 RS256,
57 RS384,
59 RS512,
61 ES256,
63 ES384,
65}
66
67
68impl From<JwtAlgorithm> for Algorithm {
69 fn from(alg: JwtAlgorithm) -> Self {
70 match alg {
71 JwtAlgorithm::HS256 => Algorithm::HS256,
72 JwtAlgorithm::HS384 => Algorithm::HS384,
73 JwtAlgorithm::HS512 => Algorithm::HS512,
74 JwtAlgorithm::RS256 => Algorithm::RS256,
75 JwtAlgorithm::RS384 => Algorithm::RS384,
76 JwtAlgorithm::RS512 => Algorithm::RS512,
77 JwtAlgorithm::ES256 => Algorithm::ES256,
78 JwtAlgorithm::ES384 => Algorithm::ES384,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct JwtClaims {
89 #[serde(rename = "sub")]
91 pub login_id: String,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
95 pub iss: Option<String>,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub aud: Option<String>,
100
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub exp: Option<i64>,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub nbf: Option<i64>,
108
109 #[serde(skip_serializing_if = "Option::is_none")]
111 pub iat: Option<i64>,
112
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub jti: Option<String>,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
121 pub login_type: Option<String>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub device: Option<String>,
126
127 #[serde(default)]
129 #[serde(skip_serializing_if = "HashMap::is_empty")]
130 pub extra: HashMap<String, Value>,
131}
132
133impl JwtClaims {
134 pub fn new(login_id: impl Into<String>) -> Self {
140 let now = Utc::now().timestamp();
141 Self {
142 login_id: login_id.into(),
143 iss: None,
144 aud: None,
145 exp: None,
146 nbf: None,
147 iat: Some(now),
148 jti: None,
149 login_type: Some("default".to_string()),
150 device: None,
151 extra: HashMap::new(),
152 }
153 }
154
155 pub fn set_expiration(&mut self, seconds: i64) -> &mut Self {
161 let exp_time = Utc::now() + Duration::seconds(seconds);
162 self.exp = Some(exp_time.timestamp());
163 self
164 }
165
166 pub fn set_expiration_at(&mut self, datetime: DateTime<Utc>) -> &mut Self {
168 self.exp = Some(datetime.timestamp());
169 self
170 }
171
172 pub fn set_issuer(&mut self, issuer: impl Into<String>) -> &mut Self {
174 self.iss = Some(issuer.into());
175 self
176 }
177
178 pub fn set_audience(&mut self, audience: impl Into<String>) -> &mut Self {
180 self.aud = Some(audience.into());
181 self
182 }
183
184 pub fn set_jti(&mut self, jti: impl Into<String>) -> &mut Self {
186 self.jti = Some(jti.into());
187 self
188 }
189
190 pub fn set_login_type(&mut self, login_type: impl Into<String>) -> &mut Self {
192 self.login_type = Some(login_type.into());
193 self
194 }
195
196 pub fn set_device(&mut self, device: impl Into<String>) -> &mut Self {
198 self.device = Some(device.into());
199 self
200 }
201
202 pub fn add_claim(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
204 self.extra.insert(key.into(), value);
205 self
206 }
207
208 pub fn get_claim(&self, key: &str) -> Option<&Value> {
210 self.extra.get(key)
211 }
212
213 pub fn set_claims(&mut self, claims: HashMap<String, Value>) -> &mut Self {
215 self.extra = claims;
216 self
217 }
218
219 pub fn get_claims(&self) -> &HashMap<String, Value> {
221 &self.extra
222 }
223
224 pub fn is_expired(&self) -> bool {
226 if let Some(exp) = self.exp {
227 let now = Utc::now().timestamp();
228 now >= exp
229 } else {
230 false
231 }
232 }
233
234 pub fn remaining_time(&self) -> Option<i64> {
236 self.exp.map(|exp| {
237 let now = Utc::now().timestamp();
238 (exp - now).max(0)
239 })
240 }
241}
242
243#[derive(Clone)]
248pub struct JwtManager {
249 secret: String,
251
252 algorithm: JwtAlgorithm,
254
255 issuer: Option<String>,
257
258 audience: Option<String>,
260}
261
262impl JwtManager {
263 pub fn new(secret: impl Into<String>) -> Self {
269 Self {
270 secret: secret.into(),
271 algorithm: JwtAlgorithm::HS256,
272 issuer: None,
273 audience: None,
274 }
275 }
276
277 pub fn with_algorithm(secret: impl Into<String>, algorithm: JwtAlgorithm) -> Self {
279 Self {
280 secret: secret.into(),
281 algorithm,
282 issuer: None,
283 audience: None,
284 }
285 }
286
287 pub fn set_issuer(mut self, issuer: impl Into<String>) -> Self {
289 self.issuer = Some(issuer.into());
290 self
291 }
292
293 pub fn set_audience(mut self, audience: impl Into<String>) -> Self {
295 self.audience = Some(audience.into());
296 self
297 }
298
299 pub fn generate(&self, claims: &JwtClaims) -> SaTokenResult<String> {
309 let mut final_claims = claims.clone();
310
311 if self.issuer.is_some() && final_claims.iss.is_none() {
314 final_claims.iss = self.issuer.clone();
315 }
316 if self.audience.is_some() && final_claims.aud.is_none() {
317 final_claims.aud = self.audience.clone();
318 }
319
320 let header = Header::new(self.algorithm.into());
321 let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
322
323 encode(&header, &final_claims, &encoding_key).map_err(|e| {
324 SaTokenError::InvalidToken(format!("Failed to generate JWT: {}", e))
325 })
326 }
327
328 pub fn validate(&self, token: &str) -> SaTokenResult<JwtClaims> {
338 let mut validation = Validation::new(self.algorithm.into());
339
340 validation.validate_exp = true;
342
343 validation.leeway = 0;
345
346 if let Some(ref iss) = self.issuer {
348 validation.set_issuer(&[iss]);
349 }
350 if let Some(ref aud) = self.audience {
351 validation.set_audience(&[aud]);
352 }
353
354 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
355
356 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
357 match e.kind() {
358 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
359 SaTokenError::TokenExpired
360 }
361 _ => SaTokenError::InvalidToken(format!("JWT validation failed: {}", e)),
362 }
363 })?;
364
365 Ok(token_data.claims)
366 }
367
368 pub fn decode_without_validation(&self, token: &str) -> SaTokenResult<JwtClaims> {
373 let mut validation = Validation::new(self.algorithm.into());
374 validation.insecure_disable_signature_validation();
375 validation.validate_exp = false;
376
377 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
378
379 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
380 SaTokenError::InvalidToken(format!("Failed to decode JWT: {}", e))
381 })?;
382
383 Ok(token_data.claims)
384 }
385
386 pub fn refresh(&self, token: &str, extend_seconds: i64) -> SaTokenResult<String> {
396 let mut claims = self.validate(token)?;
397
398 claims.set_expiration(extend_seconds);
400
401 claims.iat = Some(Utc::now().timestamp());
403
404 self.generate(&claims)
405 }
406
407 pub fn extract_login_id(&self, token: &str) -> SaTokenResult<String> {
412 let claims = self.decode_without_validation(token)?;
413 Ok(claims.login_id)
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared HMAC secret used by every manager in these tests.
    const SECRET: &str = "test-secret-key";

    /// Claims for `user_123` that expire one hour from now.
    fn claims_expiring_in_an_hour() -> JwtClaims {
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration(3600);
        claims
    }

    #[test]
    fn test_jwt_claims_creation() {
        let mut claims = JwtClaims::new("user_123");
        claims
            .set_expiration(3600)
            .set_issuer("sa-token")
            .add_claim("role", serde_json::json!("admin"));

        assert_eq!(claims.login_id, "user_123");
        assert!(claims.exp.is_some());
        assert_eq!(claims.iss.as_deref(), Some("sa-token"));
        assert_eq!(claims.get_claim("role"), Some(&serde_json::json!("admin")));
    }

    #[test]
    fn test_jwt_generate_and_validate() {
        let manager = JwtManager::new(SECRET);

        let token = manager.generate(&claims_expiring_in_an_hour()).unwrap();
        assert!(!token.is_empty());

        let decoded = manager.validate(&token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_jwt_expired() {
        let manager = JwtManager::new(SECRET);

        // Already ten seconds past expiry at the moment the token is minted.
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration_at(Utc::now() - Duration::seconds(10));
        let token = manager.generate(&claims).unwrap();

        let result = manager.validate(&token);
        assert!(result.is_err());
        assert!(
            matches!(result, Err(SaTokenError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[test]
    fn test_jwt_refresh() {
        let manager = JwtManager::new(SECRET);

        let original_token = manager.generate(&claims_expiring_in_an_hour()).unwrap();
        let new_token = manager.refresh(&original_token, 7200).unwrap();
        assert_ne!(original_token, new_token);

        let decoded = manager.validate(&new_token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
    }

    #[test]
    fn test_jwt_custom_claims() {
        let manager = JwtManager::new(SECRET);

        let mut claims = claims_expiring_in_an_hour();
        claims
            .add_claim("role", serde_json::json!("admin"))
            .add_claim("permissions", serde_json::json!(["read", "write"]));

        let token = manager.generate(&claims).unwrap();
        let decoded = manager.validate(&token).unwrap();

        assert_eq!(decoded.get_claim("role"), Some(&serde_json::json!("admin")));
        assert_eq!(
            decoded.get_claim("permissions"),
            Some(&serde_json::json!(["read", "write"]))
        );
    }

    #[test]
    fn test_extract_login_id() {
        let manager = JwtManager::new(SECRET);

        let token = manager.generate(&claims_expiring_in_an_hour()).unwrap();

        assert_eq!(manager.extract_login_id(&token).unwrap(), "user_123");
    }
}
527