1use chrono::{DateTime, Duration, Utc};
42use jsonwebtoken::{
43 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
44};
45use serde::{Deserialize, Serialize};
46use serde_json::Value;
47use std::collections::HashMap;
48
49use crate::error::{SaTokenError, SaTokenResult};
50
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
53#[derive(Default)]
54pub enum JwtAlgorithm {
55 #[default]
57 HS256,
58 HS384,
60 HS512,
62 RS256,
64 RS384,
66 RS512,
68 ES256,
70 ES384,
72}
73
74
75impl From<JwtAlgorithm> for Algorithm {
76 fn from(alg: JwtAlgorithm) -> Self {
77 match alg {
78 JwtAlgorithm::HS256 => Algorithm::HS256,
79 JwtAlgorithm::HS384 => Algorithm::HS384,
80 JwtAlgorithm::HS512 => Algorithm::HS512,
81 JwtAlgorithm::RS256 => Algorithm::RS256,
82 JwtAlgorithm::RS384 => Algorithm::RS384,
83 JwtAlgorithm::RS512 => Algorithm::RS512,
84 JwtAlgorithm::ES256 => Algorithm::ES256,
85 JwtAlgorithm::ES384 => Algorithm::ES384,
86 }
87 }
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct JwtClaims {
96 #[serde(rename = "sub")]
98 pub login_id: String,
99
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub iss: Option<String>,
103
104 #[serde(skip_serializing_if = "Option::is_none")]
106 pub aud: Option<String>,
107
108 #[serde(skip_serializing_if = "Option::is_none")]
110 pub exp: Option<i64>,
111
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub nbf: Option<i64>,
115
116 #[serde(skip_serializing_if = "Option::is_none")]
118 pub iat: Option<i64>,
119
120 #[serde(skip_serializing_if = "Option::is_none")]
122 pub jti: Option<String>,
123
124 #[serde(skip_serializing_if = "Option::is_none")]
128 pub login_type: Option<String>,
129
130 #[serde(skip_serializing_if = "Option::is_none")]
132 pub device: Option<String>,
133
134 #[serde(default)]
136 #[serde(skip_serializing_if = "HashMap::is_empty")]
137 pub extra: HashMap<String, Value>,
138}
139
140impl JwtClaims {
141 pub fn new(login_id: impl Into<String>) -> Self {
147 let now = Utc::now().timestamp();
148 Self {
149 login_id: login_id.into(),
150 iss: None,
151 aud: None,
152 exp: None,
153 nbf: None,
154 iat: Some(now),
155 jti: None,
156 login_type: Some("default".to_string()),
157 device: None,
158 extra: HashMap::new(),
159 }
160 }
161
162 pub fn set_expiration(&mut self, seconds: i64) -> &mut Self {
168 let exp_time = Utc::now() + Duration::seconds(seconds);
169 self.exp = Some(exp_time.timestamp());
170 self
171 }
172
173 pub fn set_expiration_at(&mut self, datetime: DateTime<Utc>) -> &mut Self {
175 self.exp = Some(datetime.timestamp());
176 self
177 }
178
179 pub fn set_issuer(&mut self, issuer: impl Into<String>) -> &mut Self {
181 self.iss = Some(issuer.into());
182 self
183 }
184
185 pub fn set_audience(&mut self, audience: impl Into<String>) -> &mut Self {
187 self.aud = Some(audience.into());
188 self
189 }
190
191 pub fn set_jti(&mut self, jti: impl Into<String>) -> &mut Self {
193 self.jti = Some(jti.into());
194 self
195 }
196
197 pub fn set_login_type(&mut self, login_type: impl Into<String>) -> &mut Self {
199 self.login_type = Some(login_type.into());
200 self
201 }
202
203 pub fn set_device(&mut self, device: impl Into<String>) -> &mut Self {
205 self.device = Some(device.into());
206 self
207 }
208
209 pub fn add_claim(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
211 self.extra.insert(key.into(), value);
212 self
213 }
214
215 pub fn get_claim(&self, key: &str) -> Option<&Value> {
217 self.extra.get(key)
218 }
219
220 pub fn set_claims(&mut self, claims: HashMap<String, Value>) -> &mut Self {
222 self.extra = claims;
223 self
224 }
225
226 pub fn get_claims(&self) -> &HashMap<String, Value> {
228 &self.extra
229 }
230
231 pub fn is_expired(&self) -> bool {
233 if let Some(exp) = self.exp {
234 let now = Utc::now().timestamp();
235 now >= exp
236 } else {
237 false
238 }
239 }
240
241 pub fn remaining_time(&self) -> Option<i64> {
243 self.exp.map(|exp| {
244 let now = Utc::now().timestamp();
245 (exp - now).max(0)
246 })
247 }
248}
249
250#[derive(Clone)]
255pub struct JwtManager {
256 secret: String,
258
259 algorithm: JwtAlgorithm,
261
262 issuer: Option<String>,
264
265 audience: Option<String>,
267}
268
269impl JwtManager {
270 pub fn new(secret: impl Into<String>) -> Self {
276 Self {
277 secret: secret.into(),
278 algorithm: JwtAlgorithm::HS256,
279 issuer: None,
280 audience: None,
281 }
282 }
283
284 pub fn with_algorithm(secret: impl Into<String>, algorithm: JwtAlgorithm) -> Self {
286 Self {
287 secret: secret.into(),
288 algorithm,
289 issuer: None,
290 audience: None,
291 }
292 }
293
294 pub fn set_issuer(mut self, issuer: impl Into<String>) -> Self {
296 self.issuer = Some(issuer.into());
297 self
298 }
299
300 pub fn set_audience(mut self, audience: impl Into<String>) -> Self {
302 self.audience = Some(audience.into());
303 self
304 }
305
306 pub fn generate(&self, claims: &JwtClaims) -> SaTokenResult<String> {
316 let mut final_claims = claims.clone();
317
318 if self.issuer.is_some() && final_claims.iss.is_none() {
321 final_claims.iss = self.issuer.clone();
322 }
323 if self.audience.is_some() && final_claims.aud.is_none() {
324 final_claims.aud = self.audience.clone();
325 }
326
327 let header = Header::new(self.algorithm.into());
328 let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
329
330 encode(&header, &final_claims, &encoding_key).map_err(|e| {
331 SaTokenError::InvalidToken(format!("Failed to generate JWT: {}", e))
332 })
333 }
334
335 pub fn validate(&self, token: &str) -> SaTokenResult<JwtClaims> {
345 let mut validation = Validation::new(self.algorithm.into());
346
347 validation.validate_exp = true;
349
350 validation.leeway = 0;
352
353 if let Some(ref iss) = self.issuer {
355 validation.set_issuer(&[iss]);
356 }
357 if let Some(ref aud) = self.audience {
358 validation.set_audience(&[aud]);
359 }
360
361 let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
362
363 let token_data = decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
364 match e.kind() {
365 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
366 SaTokenError::TokenExpired
367 }
368 _ => SaTokenError::InvalidToken(format!("JWT validation failed: {}", e)),
369 }
370 })?;
371
372 Ok(token_data.claims)
373 }
374
375 pub fn decode_without_validation(&self, token: &str) -> SaTokenResult<JwtClaims> {
380 let token_data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(token)
383 .map_err(|e| SaTokenError::InvalidToken(format!("Failed to decode JWT: {}", e)))?;
384
385 Ok(token_data.claims)
386 }
387
388 pub fn refresh(&self, token: &str, extend_seconds: i64) -> SaTokenResult<String> {
398 let mut claims = self.validate(token)?;
399
400 claims.set_expiration(extend_seconds);
402
403 claims.iat = Some(Utc::now().timestamp());
405
406 self.generate(&claims)
407 }
408
409 pub fn extract_login_id(&self, token: &str) -> SaTokenResult<String> {
414 let claims = self.decode_without_validation(token)?;
415 Ok(claims.login_id)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Signing secret shared by every manager built in these tests.
    const SECRET: &str = "test-secret-key";

    /// Builds claims for `user_123` that expire `ttl` seconds from now.
    fn claims_with_ttl(ttl: i64) -> JwtClaims {
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration(ttl);
        claims
    }

    /// Setters populate the matching fields and custom claims are retrievable.
    #[test]
    fn test_jwt_claims_creation() {
        let mut claims = JwtClaims::new("user_123");
        claims
            .set_expiration(3600)
            .set_issuer("sa-token")
            .add_claim("role", serde_json::json!("admin"));

        assert_eq!(claims.login_id, "user_123");
        assert!(claims.exp.is_some());
        assert_eq!(claims.iss, Some("sa-token".to_string()));
        assert_eq!(claims.get_claim("role"), Some(&serde_json::json!("admin")));
    }

    /// A freshly generated token validates and round-trips the login id.
    #[test]
    fn test_jwt_generate_and_validate() {
        let manager = JwtManager::new(SECRET);

        let token = manager.generate(&claims_with_ttl(3600)).unwrap();
        assert!(!token.is_empty());

        let decoded = manager.validate(&token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
        assert!(!decoded.is_expired());
    }

    /// A token whose `exp` lies in the past is rejected as `TokenExpired`.
    #[test]
    fn test_jwt_expired() {
        let manager = JwtManager::new(SECRET);

        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration_at(Utc::now() - Duration::seconds(10));

        let token = manager.generate(&claims).unwrap();

        let result = manager.validate(&token);
        assert!(result.is_err());
        assert!(
            matches!(result, Err(SaTokenError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    /// Refreshing yields a different token that still validates.
    #[test]
    fn test_jwt_refresh() {
        let manager = JwtManager::new(SECRET);

        let original_token = manager.generate(&claims_with_ttl(3600)).unwrap();

        let new_token = manager.refresh(&original_token, 7200).unwrap();
        assert_ne!(original_token, new_token);

        let decoded = manager.validate(&new_token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
    }

    /// Custom claims survive an encode/decode round trip.
    #[test]
    fn test_jwt_custom_claims() {
        let manager = JwtManager::new(SECRET);

        let mut claims = claims_with_ttl(3600);
        claims
            .add_claim("role", serde_json::json!("admin"))
            .add_claim("permissions", serde_json::json!(["read", "write"]));

        let token = manager.generate(&claims).unwrap();
        let decoded = manager.validate(&token).unwrap();

        assert_eq!(decoded.get_claim("role"), Some(&serde_json::json!("admin")));
        assert_eq!(
            decoded.get_claim("permissions"),
            Some(&serde_json::json!(["read", "write"]))
        );
    }

    /// The login id can be read back from a generated token.
    #[test]
    fn test_extract_login_id() {
        let manager = JwtManager::new(SECRET);

        let token = manager.generate(&claims_with_ttl(3600)).unwrap();
        let login_id = manager.extract_login_id(&token).unwrap();

        assert_eq!(login_id, "user_123");
    }
}
529