1use chrono::{DateTime, Duration, Utc};
35use jsonwebtoken::{
36 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
37};
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use std::collections::HashMap;
41
42use crate::error::{SaTokenError, SaTokenResult};
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46pub enum JwtAlgorithm {
47 HS256,
49 HS384,
51 HS512,
53 RS256,
55 RS384,
57 RS512,
59 ES256,
61 ES384,
63}
64
65impl Default for JwtAlgorithm {
66 fn default() -> Self {
67 Self::HS256
68 }
69}
70
71impl From<JwtAlgorithm> for Algorithm {
72 fn from(alg: JwtAlgorithm) -> Self {
73 match alg {
74 JwtAlgorithm::HS256 => Algorithm::HS256,
75 JwtAlgorithm::HS384 => Algorithm::HS384,
76 JwtAlgorithm::HS512 => Algorithm::HS512,
77 JwtAlgorithm::RS256 => Algorithm::RS256,
78 JwtAlgorithm::RS384 => Algorithm::RS384,
79 JwtAlgorithm::RS512 => Algorithm::RS512,
80 JwtAlgorithm::ES256 => Algorithm::ES256,
81 JwtAlgorithm::ES384 => Algorithm::ES384,
82 }
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct JwtClaims {
92 #[serde(rename = "sub")]
94 pub login_id: String,
95
96 #[serde(skip_serializing_if = "Option::is_none")]
98 pub iss: Option<String>,
99
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub aud: Option<String>,
103
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub exp: Option<i64>,
107
108 #[serde(skip_serializing_if = "Option::is_none")]
110 pub nbf: Option<i64>,
111
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub iat: Option<i64>,
115
116 #[serde(skip_serializing_if = "Option::is_none")]
118 pub jti: Option<String>,
119
120 #[serde(skip_serializing_if = "Option::is_none")]
124 pub login_type: Option<String>,
125
126 #[serde(skip_serializing_if = "Option::is_none")]
128 pub device: Option<String>,
129
130 #[serde(skip_serializing_if = "HashMap::is_empty")]
132 pub extra: HashMap<String, Value>,
133}
134
135impl JwtClaims {
136 pub fn new(login_id: impl Into<String>) -> Self {
142 let now = Utc::now().timestamp();
143 Self {
144 login_id: login_id.into(),
145 iss: None,
146 aud: None,
147 exp: None,
148 nbf: None,
149 iat: Some(now),
150 jti: None,
151 login_type: Some("default".to_string()),
152 device: None,
153 extra: HashMap::new(),
154 }
155 }
156
157 pub fn set_expiration(&mut self, seconds: i64) -> &mut Self {
163 let exp_time = Utc::now() + Duration::seconds(seconds);
164 self.exp = Some(exp_time.timestamp());
165 self
166 }
167
168 pub fn set_expiration_at(&mut self, datetime: DateTime<Utc>) -> &mut Self {
170 self.exp = Some(datetime.timestamp());
171 self
172 }
173
174 pub fn set_issuer(&mut self, issuer: impl Into<String>) -> &mut Self {
176 self.iss = Some(issuer.into());
177 self
178 }
179
180 pub fn set_audience(&mut self, audience: impl Into<String>) -> &mut Self {
182 self.aud = Some(audience.into());
183 self
184 }
185
186 pub fn set_jti(&mut self, jti: impl Into<String>) -> &mut Self {
188 self.jti = Some(jti.into());
189 self
190 }
191
192 pub fn set_login_type(&mut self, login_type: impl Into<String>) -> &mut Self {
194 self.login_type = Some(login_type.into());
195 self
196 }
197
198 pub fn set_device(&mut self, device: impl Into<String>) -> &mut Self {
200 self.device = Some(device.into());
201 self
202 }
203
204 pub fn add_claim(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
206 self.extra.insert(key.into(), value);
207 self
208 }
209
210 pub fn get_claim(&self, key: &str) -> Option<&Value> {
212 self.extra.get(key)
213 }
214
215 pub fn set_claims(&mut self, claims: HashMap<String, Value>) -> &mut Self {
217 self.extra = claims;
218 self
219 }
220
221 pub fn get_claims(&self) -> &HashMap<String, Value> {
223 &self.extra
224 }
225
226 pub fn is_expired(&self) -> bool {
228 if let Some(exp) = self.exp {
229 let now = Utc::now().timestamp();
230 now >= exp
231 } else {
232 false
233 }
234 }
235
236 pub fn remaining_time(&self) -> Option<i64> {
238 self.exp.map(|exp| {
239 let now = Utc::now().timestamp();
240 (exp - now).max(0)
241 })
242 }
243}
244
245#[derive(Clone)]
250pub struct JwtManager {
251 secret: String,
253
254 algorithm: JwtAlgorithm,
256
257 issuer: Option<String>,
259
260 audience: Option<String>,
262}
263
264impl JwtManager {
265 pub fn new(secret: impl Into<String>) -> Self {
271 Self {
272 secret: secret.into(),
273 algorithm: JwtAlgorithm::HS256,
274 issuer: None,
275 audience: None,
276 }
277 }
278
279 pub fn with_algorithm(secret: impl Into<String>, algorithm: JwtAlgorithm) -> Self {
281 Self {
282 secret: secret.into(),
283 algorithm,
284 issuer: None,
285 audience: None,
286 }
287 }
288
289 pub fn set_issuer(mut self, issuer: impl Into<String>) -> Self {
291 self.issuer = Some(issuer.into());
292 self
293 }
294
295 pub fn set_audience(mut self, audience: impl Into<String>) -> Self {
297 self.audience = Some(audience.into());
298 self
299 }
300
301 pub fn generate(&self, claims: &JwtClaims) -> SaTokenResult<String> {
311 let mut final_claims = claims.clone();
312
313 if self.issuer.is_some() && final_claims.iss.is_none() {
316 final_claims.iss = self.issuer.clone();
317 }
318 if self.audience.is_some() && final_claims.aud.is_none() {
319 final_claims.aud = self.audience.clone();
320 }
321
322 let header = Header::new(self.algorithm.into());
323 let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
324
325 encode(&header, &final_claims, &encoding_key).map_err(|e| {
326 SaTokenError::InvalidToken(format!("Failed to generate JWT: {}", e))
327 })
328 }
329
330 pub fn validate(&self, token: &str) -> SaTokenResult<JwtClaims> {
340 let mut validation = Validation::new(self.algorithm.into());
341
342 validation.validate_exp = true;
344
345 validation.leeway = 0;
347
348 if let Some(ref iss) = self.issuer {
350 validation.set_issuer(&[iss]);
351 }
352 if let Some(ref aud) = self.audience {
353 validation.set_audience(&[aud]);
354 }
355
356 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
357
358 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
359 match e.kind() {
360 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
361 SaTokenError::TokenExpired
362 }
363 _ => SaTokenError::InvalidToken(format!("JWT validation failed: {}", e)),
364 }
365 })?;
366
367 Ok(token_data.claims)
368 }
369
370 pub fn decode_without_validation(&self, token: &str) -> SaTokenResult<JwtClaims> {
375 let mut validation = Validation::new(self.algorithm.into());
376 validation.insecure_disable_signature_validation();
377 validation.validate_exp = false;
378
379 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
380
381 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
382 SaTokenError::InvalidToken(format!("Failed to decode JWT: {}", e))
383 })?;
384
385 Ok(token_data.claims)
386 }
387
388 pub fn refresh(&self, token: &str, extend_seconds: i64) -> SaTokenResult<String> {
398 let mut claims = self.validate(token)?;
399
400 claims.set_expiration(extend_seconds);
402
403 claims.iat = Some(Utc::now().timestamp());
405
406 self.generate(&claims)
407 }
408
409 pub fn extract_login_id(&self, token: &str) -> SaTokenResult<String> {
414 let claims = self.decode_without_validation(token)?;
415 Ok(claims.login_id)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Secret shared by every manager in these tests.
    const SECRET: &str = "test-secret-key";

    /// Claims for `user_123` that expire `ttl` seconds from now.
    fn user_claims(ttl: i64) -> JwtClaims {
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration(ttl);
        claims
    }

    #[test]
    fn test_jwt_claims_creation() {
        let mut claims = user_claims(3600);
        claims
            .set_issuer("sa-token")
            .add_claim("role", json!("admin"));

        assert_eq!(claims.login_id, "user_123");
        assert!(claims.exp.is_some());
        assert_eq!(claims.iss.as_deref(), Some("sa-token"));
        assert_eq!(claims.get_claim("role"), Some(&json!("admin")));
    }

    #[test]
    fn test_jwt_generate_and_validate() {
        let manager = JwtManager::new(SECRET);

        let token = manager.generate(&user_claims(3600)).unwrap();
        assert!(!token.is_empty());

        let decoded = manager.validate(&token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_jwt_expired() {
        let manager = JwtManager::new(SECRET);

        // Expired ten seconds ago, well past the zero-leeway window.
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration_at(Utc::now() - Duration::seconds(10));

        let token = manager.generate(&claims).unwrap();
        let outcome = manager.validate(&token);

        assert!(
            matches!(outcome, Err(SaTokenError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[test]
    fn test_jwt_refresh() {
        let manager = JwtManager::new(SECRET);

        let original = manager.generate(&user_claims(3600)).unwrap();
        let refreshed = manager.refresh(&original, 7200).unwrap();
        assert_ne!(original, refreshed);

        let decoded = manager.validate(&refreshed).unwrap();
        assert_eq!(decoded.login_id, "user_123");
    }

    #[test]
    fn test_jwt_custom_claims() {
        let manager = JwtManager::new(SECRET);

        let mut claims = user_claims(3600);
        claims
            .add_claim("role", json!("admin"))
            .add_claim("permissions", json!(["read", "write"]));

        let token = manager.generate(&claims).unwrap();
        let decoded = manager.validate(&token).unwrap();

        assert_eq!(decoded.get_claim("role"), Some(&json!("admin")));
        assert_eq!(
            decoded.get_claim("permissions"),
            Some(&json!(["read", "write"]))
        );
    }

    #[test]
    fn test_extract_login_id() {
        let manager = JwtManager::new(SECRET);

        let token = manager.generate(&user_claims(3600)).unwrap();

        assert_eq!(manager.extract_login_id(&token).unwrap(), "user_123");
    }
}
529