1use chrono::{DateTime, Duration, Utc};
35use jsonwebtoken::{
36 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
37};
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use std::collections::HashMap;
41
42use crate::error::{SaTokenError, SaTokenResult};
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46pub enum JwtAlgorithm {
47 HS256,
49 HS384,
51 HS512,
53 RS256,
55 RS384,
57 RS512,
59 ES256,
61 ES384,
63}
64
65impl Default for JwtAlgorithm {
66 fn default() -> Self {
67 Self::HS256
68 }
69}
70
71impl From<JwtAlgorithm> for Algorithm {
72 fn from(alg: JwtAlgorithm) -> Self {
73 match alg {
74 JwtAlgorithm::HS256 => Algorithm::HS256,
75 JwtAlgorithm::HS384 => Algorithm::HS384,
76 JwtAlgorithm::HS512 => Algorithm::HS512,
77 JwtAlgorithm::RS256 => Algorithm::RS256,
78 JwtAlgorithm::RS384 => Algorithm::RS384,
79 JwtAlgorithm::RS512 => Algorithm::RS512,
80 JwtAlgorithm::ES256 => Algorithm::ES256,
81 JwtAlgorithm::ES384 => Algorithm::ES384,
82 }
83 }
84}
85
/// Claims carried in a JWT payload.
///
/// Registered claims are typed fields. Every `Option` field is omitted from the serialized
/// payload when it is `None`. Payload keys that match no named field are collected into
/// [`extra`](Self::extra) through `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Login identifier, serialized under the registered `sub` (subject) key.
    #[serde(rename = "sub")]
    pub login_id: String,

    /// Issuer (`iss`). `JwtManager::generate` fills it from the manager's issuer when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,

    /// Audience (`aud`). `JwtManager::generate` fills it from the manager's audience when
    /// `None`.
    // NOTE(review): RFC 7519 also allows `aud` to be an array of strings; such tokens
    // presumably fail to deserialize into `Option<String>` — confirm if foreign tokens matter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,

    /// Expiration time (`exp`) as a Unix timestamp in seconds; `None` encodes no expiration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<i64>,

    /// Not-before time (`nbf`); presumably Unix seconds like `exp` — nothing in this file
    /// sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<i64>,

    /// Issued-at time (`iat`) as a Unix timestamp in seconds; set by `JwtClaims::new` and
    /// reset by `JwtManager::refresh`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<i64>,

    /// JWT ID (`jti`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,

    /// Login type (custom claim); `JwtClaims::new` initializes it to `"default"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub login_type: Option<String>,

    /// Device tag (custom claim), if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,

    /// All other payload keys (custom claims), flattened into the top level of the payload.
    // NOTE(review): a key colliding with a named field (e.g. `sub`) would presumably be
    // serialized twice — avoid such keys.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
134
135impl JwtClaims {
136 pub fn new(login_id: impl Into<String>) -> Self {
142 let now = Utc::now().timestamp();
143 Self {
144 login_id: login_id.into(),
145 iss: None,
146 aud: None,
147 exp: None,
148 nbf: None,
149 iat: Some(now),
150 jti: None,
151 login_type: Some("default".to_string()),
152 device: None,
153 extra: HashMap::new(),
154 }
155 }
156
157 pub fn set_expiration(&mut self, seconds: i64) -> &mut Self {
163 let exp_time = Utc::now() + Duration::seconds(seconds);
164 self.exp = Some(exp_time.timestamp());
165 self
166 }
167
168 pub fn set_expiration_at(&mut self, datetime: DateTime<Utc>) -> &mut Self {
170 self.exp = Some(datetime.timestamp());
171 self
172 }
173
174 pub fn set_issuer(&mut self, issuer: impl Into<String>) -> &mut Self {
176 self.iss = Some(issuer.into());
177 self
178 }
179
180 pub fn set_audience(&mut self, audience: impl Into<String>) -> &mut Self {
182 self.aud = Some(audience.into());
183 self
184 }
185
186 pub fn set_jti(&mut self, jti: impl Into<String>) -> &mut Self {
188 self.jti = Some(jti.into());
189 self
190 }
191
192 pub fn set_login_type(&mut self, login_type: impl Into<String>) -> &mut Self {
194 self.login_type = Some(login_type.into());
195 self
196 }
197
198 pub fn set_device(&mut self, device: impl Into<String>) -> &mut Self {
200 self.device = Some(device.into());
201 self
202 }
203
204 pub fn add_claim(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
206 self.extra.insert(key.into(), value);
207 self
208 }
209
210 pub fn get_claim(&self, key: &str) -> Option<&Value> {
212 self.extra.get(key)
213 }
214
215 pub fn is_expired(&self) -> bool {
217 if let Some(exp) = self.exp {
218 let now = Utc::now().timestamp();
219 now >= exp
220 } else {
221 false
222 }
223 }
224
225 pub fn remaining_time(&self) -> Option<i64> {
227 self.exp.map(|exp| {
228 let now = Utc::now().timestamp();
229 (exp - now).max(0)
230 })
231 }
232}
233
234#[derive(Clone)]
239pub struct JwtManager {
240 secret: String,
242
243 algorithm: JwtAlgorithm,
245
246 issuer: Option<String>,
248
249 audience: Option<String>,
251}
252
253impl JwtManager {
254 pub fn new(secret: impl Into<String>) -> Self {
260 Self {
261 secret: secret.into(),
262 algorithm: JwtAlgorithm::HS256,
263 issuer: None,
264 audience: None,
265 }
266 }
267
268 pub fn with_algorithm(secret: impl Into<String>, algorithm: JwtAlgorithm) -> Self {
270 Self {
271 secret: secret.into(),
272 algorithm,
273 issuer: None,
274 audience: None,
275 }
276 }
277
278 pub fn set_issuer(mut self, issuer: impl Into<String>) -> Self {
280 self.issuer = Some(issuer.into());
281 self
282 }
283
284 pub fn set_audience(mut self, audience: impl Into<String>) -> Self {
286 self.audience = Some(audience.into());
287 self
288 }
289
290 pub fn generate(&self, claims: &JwtClaims) -> SaTokenResult<String> {
300 let mut final_claims = claims.clone();
301
302 if self.issuer.is_some() && final_claims.iss.is_none() {
305 final_claims.iss = self.issuer.clone();
306 }
307 if self.audience.is_some() && final_claims.aud.is_none() {
308 final_claims.aud = self.audience.clone();
309 }
310
311 let header = Header::new(self.algorithm.into());
312 let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
313
314 encode(&header, &final_claims, &encoding_key).map_err(|e| {
315 SaTokenError::InvalidToken(format!("Failed to generate JWT: {}", e))
316 })
317 }
318
319 pub fn validate(&self, token: &str) -> SaTokenResult<JwtClaims> {
329 let mut validation = Validation::new(self.algorithm.into());
330
331 validation.validate_exp = true;
333
334 validation.leeway = 0;
336
337 if let Some(ref iss) = self.issuer {
339 validation.set_issuer(&[iss]);
340 }
341 if let Some(ref aud) = self.audience {
342 validation.set_audience(&[aud]);
343 }
344
345 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
346
347 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
348 match e.kind() {
349 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
350 SaTokenError::TokenExpired
351 }
352 _ => SaTokenError::InvalidToken(format!("JWT validation failed: {}", e)),
353 }
354 })?;
355
356 Ok(token_data.claims)
357 }
358
359 pub fn decode_without_validation(&self, token: &str) -> SaTokenResult<JwtClaims> {
364 let mut validation = Validation::new(self.algorithm.into());
365 validation.insecure_disable_signature_validation();
366 validation.validate_exp = false;
367
368 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
369
370 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
371 SaTokenError::InvalidToken(format!("Failed to decode JWT: {}", e))
372 })?;
373
374 Ok(token_data.claims)
375 }
376
377 pub fn refresh(&self, token: &str, extend_seconds: i64) -> SaTokenResult<String> {
387 let mut claims = self.validate(token)?;
388
389 claims.set_expiration(extend_seconds);
391
392 claims.iat = Some(Utc::now().timestamp());
394
395 self.generate(&claims)
396 }
397
398 pub fn extract_login_id(&self, token: &str) -> SaTokenResult<String> {
403 let claims = self.decode_without_validation(token)?;
404 Ok(claims.login_id)
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret for every manager built in this module.
    const SECRET: &str = "test-secret-key";

    #[test]
    fn test_jwt_claims_creation() {
        let mut c = JwtClaims::new("user_123");
        c.set_expiration(3600)
            .set_issuer("sa-token")
            .add_claim("role", serde_json::json!("admin"));

        assert_eq!(c.login_id, "user_123");
        assert!(c.exp.is_some());
        assert_eq!(c.iss.as_deref(), Some("sa-token"));
        assert_eq!(c.get_claim("role"), Some(&serde_json::json!("admin")));
    }

    #[test]
    fn test_jwt_generate_and_validate() {
        let manager = JwtManager::new(SECRET);
        let mut c = JwtClaims::new("user_123");
        c.set_expiration(3600);

        let token = manager.generate(&c).unwrap();
        assert!(!token.is_empty());

        let back = manager.validate(&token).unwrap();
        assert_eq!(back.login_id, "user_123");
        assert!(!back.is_expired());
    }

    #[test]
    fn test_jwt_expired() {
        let manager = JwtManager::new(SECRET);
        let mut c = JwtClaims::new("user_123");
        // Ten seconds in the past: well beyond the zero-second leeway `validate` uses.
        c.set_expiration_at(Utc::now() - Duration::seconds(10));

        let token = manager.generate(&c).unwrap();
        let outcome = manager.validate(&token);
        assert!(outcome.is_err());

        if !matches!(outcome, Err(SaTokenError::TokenExpired)) {
            panic!("Expected TokenExpired error");
        }
    }

    #[test]
    fn test_jwt_refresh() {
        let manager = JwtManager::new(SECRET);
        let mut c = JwtClaims::new("user_123");
        c.set_expiration(3600);

        let first = manager.generate(&c).unwrap();
        // A longer lifetime changes `exp`, so the refreshed token must differ.
        let second = manager.refresh(&first, 7200).unwrap();
        assert_ne!(first, second);

        assert_eq!(manager.validate(&second).unwrap().login_id, "user_123");
    }

    #[test]
    fn test_jwt_custom_claims() {
        let manager = JwtManager::new(SECRET);
        let mut c = JwtClaims::new("user_123");
        c.set_expiration(3600)
            .add_claim("role", serde_json::json!("admin"))
            .add_claim("permissions", serde_json::json!(["read", "write"]));

        let back = manager.validate(&manager.generate(&c).unwrap()).unwrap();

        assert_eq!(back.get_claim("role"), Some(&serde_json::json!("admin")));
        assert_eq!(
            back.get_claim("permissions"),
            Some(&serde_json::json!(["read", "write"]))
        );
    }

    #[test]
    fn test_extract_login_id() {
        let manager = JwtManager::new(SECRET);
        let mut c = JwtClaims::new("user_123");
        c.set_expiration(3600);

        let token = manager.generate(&c).unwrap();
        assert_eq!(manager.extract_login_id(&token).unwrap(), "user_123");
    }
}
518