Skip to main content

wae_authentication/jwt/
service.rs

//! JWT service implementation (JWT 服务实现).
2
3use crate::jwt::{
4    AccessTokenClaims, JwtAlgorithm, JwtClaims, JwtConfig, JwtHeader, RefreshTokenClaims, decode_jwt, encode_jwt,
5};
6use chrono::Utc;
7use serde::{Serialize, de::DeserializeOwned};
8use wae_types::WaeError;
9
10/// JWT 结果类型
11pub type JwtResult<T> = Result<T, WaeError>;
12
13/// JWT 服务
14#[derive(Debug, Clone)]
15pub struct JwtService {
16    config: JwtConfig,
17}
18
19impl JwtService {
20    /// 创建新的 JWT 服务
21    ///
22    /// # Arguments
23    /// * `config` - JWT 配置
24    pub fn new(config: JwtConfig) -> JwtResult<Self> {
25        Ok(Self { config })
26    }
27
28    /// 获取算法名称
29    fn algorithm_name(&self) -> &'static str {
30        match self.config.algorithm {
31            JwtAlgorithm::HS256 => "HS256",
32            JwtAlgorithm::HS384 => "HS384",
33            JwtAlgorithm::HS512 => "HS512",
34            _ => unimplemented!("Only HS256, HS384, HS512 are supported in custom implementation"),
35        }
36    }
37
38    /// 生成令牌
39    ///
40    /// # Arguments
41    /// * `claims` - JWT Claims
42    pub fn generate_token<T: Serialize>(&self, claims: &T) -> JwtResult<String> {
43        let header = JwtHeader::new(self.algorithm_name());
44        let secret = self.config.secret.as_bytes();
45        encode_jwt(&header, claims, secret).map_err(|e| e.into())
46    }
47
48    /// 验证令牌
49    ///
50    /// # Arguments
51    /// * `token` - JWT 令牌
52    pub fn verify_token<T: DeserializeOwned + 'static>(&self, token: &str) -> JwtResult<T> {
53        let secret = self.config.secret.as_bytes();
54        let claims: T = decode_jwt(token, secret, true)?;
55
56        let now = Utc::now().timestamp();
57        let leeway = self.config.leeway_seconds;
58
59        if let Some(validated_claims) = (&claims as &dyn std::any::Any).downcast_ref::<JwtClaims>() {
60            if validated_claims.exp + leeway < now {
61                return Err(WaeError::token_expired());
62            }
63            if let Some(nbf) = validated_claims.nbf {
64                if nbf - leeway > now {
65                    return Err(WaeError::token_not_valid_yet());
66                }
67            }
68            if self.config.validate_issuer {
69                if let Some(ref issuer) = self.config.issuer {
70                    if validated_claims.iss.as_ref() != Some(issuer) {
71                        return Err(WaeError::invalid_claim("issuer"));
72                    }
73                }
74            }
75            if self.config.validate_audience {
76                if let Some(ref audience) = self.config.audience {
77                    if validated_claims.aud.as_ref() != Some(audience) {
78                        return Err(WaeError::invalid_audience());
79                    }
80                }
81            }
82        }
83
84        Ok(claims)
85    }
86
87    /// 生成访问令牌
88    ///
89    /// # Arguments
90    /// * `user_id` - 用户 ID
91    /// * `claims` - 访问令牌 Claims
92    pub fn generate_access_token(&self, user_id: &str, claims: AccessTokenClaims) -> JwtResult<String> {
93        let mut jwt_claims = JwtClaims::new(user_id, self.config.access_token_expires_in());
94
95        if let Some(ref issuer) = self.config.issuer {
96            jwt_claims = jwt_claims.with_issuer(issuer.clone());
97        }
98
99        if let Some(ref audience) = self.config.audience {
100            jwt_claims = jwt_claims.with_audience(audience.clone());
101        }
102
103        jwt_claims.custom.insert("user_id".to_string(), serde_json::to_value(&claims.user_id).unwrap());
104
105        if let Some(ref username) = claims.username {
106            jwt_claims.custom.insert("username".to_string(), serde_json::to_value(username).unwrap());
107        }
108
109        if !claims.roles.is_empty() {
110            jwt_claims.custom.insert("roles".to_string(), serde_json::to_value(&claims.roles).unwrap());
111        }
112
113        if !claims.permissions.is_empty() {
114            jwt_claims.custom.insert("permissions".to_string(), serde_json::to_value(&claims.permissions).unwrap());
115        }
116
117        if let Some(ref session_id) = claims.session_id {
118            jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
119        }
120
121        self.generate_token(&jwt_claims)
122    }
123
124    /// 验证访问令牌
125    ///
126    /// # Arguments
127    /// * `token` - 访问令牌
128    pub fn verify_access_token(&self, token: &str) -> JwtResult<JwtClaims> {
129        self.verify_token(token)
130    }
131
132    /// 生成刷新令牌
133    ///
134    /// # Arguments
135    /// * `user_id` - 用户 ID
136    /// * `session_id` - 会话 ID
137    /// * `version` - 令牌版本
138    pub fn generate_refresh_token(&self, user_id: &str, session_id: &str, version: u32) -> JwtResult<String> {
139        let mut jwt_claims = JwtClaims::new(user_id, self.config.refresh_token_expires_in());
140
141        if let Some(ref issuer) = self.config.issuer {
142            jwt_claims = jwt_claims.with_issuer(issuer.clone());
143        }
144
145        jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
146        jwt_claims.custom.insert("version".to_string(), serde_json::to_value(version).unwrap());
147        jwt_claims.custom.insert("token_type".to_string(), serde_json::to_value("refresh").unwrap());
148
149        self.generate_token(&jwt_claims)
150    }
151
152    /// 验证刷新令牌
153    ///
154    /// # Arguments
155    /// * `token` - 刷新令牌
156    pub fn verify_refresh_token(&self, token: &str) -> JwtResult<RefreshTokenClaims> {
157        let claims: JwtClaims = self.verify_token(token)?;
158
159        let token_type: Option<String> = claims.custom.get("token_type").and_then(|v| serde_json::from_value(v.clone()).ok());
160
161        if token_type.as_deref() != Some("refresh") {
162            return Err(WaeError::invalid_token("not a refresh token"));
163        }
164
165        let session_id: String = claims
166            .custom
167            .get("session_id")
168            .and_then(|v| serde_json::from_value(v.clone()).ok())
169            .ok_or_else(|| WaeError::missing_claim("session_id"))?;
170
171        let version: u32 = claims.custom.get("version").and_then(|v| serde_json::from_value(v.clone()).ok()).unwrap_or(0);
172
173        Ok(RefreshTokenClaims { user_id: claims.sub, session_id, version })
174    }
175
176    /// 解码令牌(不验证签名)
177    ///
178    /// # Arguments
179    /// * `token` - JWT 令牌
180    pub fn decode_unchecked<T: DeserializeOwned>(&self, token: &str) -> JwtResult<T> {
181        let secret = self.config.secret.as_bytes();
182        decode_jwt(token, secret, false).map_err(|e| e.into())
183    }
184
185    /// 获取令牌剩余有效期
186    ///
187    /// # Arguments
188    /// * `token` - JWT 令牌
189    pub fn get_remaining_ttl(&self, token: &str) -> JwtResult<i64> {
190        let claims: JwtClaims = self.verify_token(token)?;
191        let now = Utc::now().timestamp();
192        let remaining = claims.exp - now;
193        Ok(remaining.max(0))
194    }
195
196    /// 检查令牌是否即将过期
197    ///
198    /// # Arguments
199    /// * `token` - JWT 令牌
200    /// * `threshold_seconds` - 阈值(秒)
201    pub fn is_token_expiring_soon(&self, token: &str, threshold_seconds: i64) -> JwtResult<bool> {
202        let remaining = self.get_remaining_ttl(token)?;
203        Ok(remaining < threshold_seconds)
204    }
205
206    /// 获取配置
207    pub fn config(&self) -> &JwtConfig {
208        &self.config
209    }
210}
211
/// An access token together with its matching refresh token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenPair {
    /// Signed access token.
    pub access_token: String,

    /// Signed refresh token.
    pub refresh_token: String,

    /// Token type to present to clients (the service fills in `"Bearer"`).
    pub token_type: String,

    /// Access-token lifetime, in seconds.
    pub expires_in: i64,

    /// Refresh-token lifetime, in seconds.
    pub refresh_expires_in: i64,
}
230
231impl JwtService {
232    /// 生成令牌对
233    ///
234    /// # Arguments
235    /// * `user_id` - 用户 ID
236    /// * `access_claims` - 访问令牌 Claims
237    /// * `session_id` - 会话 ID
238    pub fn generate_token_pair(
239        &self,
240        user_id: &str,
241        access_claims: AccessTokenClaims,
242        session_id: &str,
243    ) -> JwtResult<TokenPair> {
244        let access_token = self.generate_access_token(user_id, access_claims)?;
245        let refresh_token = self.generate_refresh_token(user_id, session_id, 0)?;
246
247        Ok(TokenPair {
248            access_token,
249            refresh_token,
250            token_type: "Bearer".to_string(),
251            expires_in: self.config.access_token_expires_in(),
252            refresh_expires_in: self.config.refresh_token_expires_in(),
253        })
254    }
255
256    /// 轮换刷新令牌
257    ///
258    /// 使用旧的刷新令牌生成新的令牌对,并递增版本号
259    ///
260    /// # Arguments
261    /// * `refresh_token` - 旧的刷新令牌
262    /// * `new_access_claims` - 新的访问令牌 Claims(可选)
263    pub fn rotate_token_pair(&self, refresh_token: &str, new_access_claims: Option<AccessTokenClaims>) -> JwtResult<TokenPair> {
264        let refresh_claims = self.verify_refresh_token(refresh_token)?;
265
266        let new_version = refresh_claims.version + 1;
267        let access_claims = new_access_claims.unwrap_or_else(|| {
268            AccessTokenClaims::new(refresh_claims.user_id.clone()).with_session_id(refresh_claims.session_id.clone())
269        });
270
271        let access_token = self.generate_access_token(&refresh_claims.user_id, access_claims)?;
272        let new_refresh_token =
273            self.generate_refresh_token(&refresh_claims.user_id, &refresh_claims.session_id, new_version)?;
274
275        Ok(TokenPair {
276            access_token,
277            refresh_token: new_refresh_token,
278            token_type: "Bearer".to_string(),
279            expires_in: self.config.access_token_expires_in(),
280            refresh_expires_in: self.config.refresh_token_expires_in(),
281        })
282    }
283}
284
285/// 便捷函数:创建默认 JWT 服务
286pub fn default_jwt_service() -> JwtResult<JwtService> {
287    JwtService::new(JwtConfig::default())
288}
289
290/// 便捷函数:使用密钥创建 JWT 服务
291pub fn jwt_service_with_secret(secret: impl Into<String>) -> JwtResult<JwtService> {
292    JwtService::new(JwtConfig::new(secret))
293}