Skip to main content

wae_authentication/jwt/
service.rs

1//! JWT 服务实现
2
3use crate::jwt::{AccessTokenClaims, JwtAlgorithm, JwtClaims, JwtConfig, JwtError, JwtResult, RefreshTokenClaims};
4use chrono::Utc;
5use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
6use serde::{Serialize, de::DeserializeOwned};
7
/// JWT service.
///
/// Bundles the JWT configuration with the encoding and decoding keys that
/// `JwtService::new` derives from it.
#[derive(Debug, Clone)]
pub struct JwtService {
    /// Configuration the service was created with.
    config: JwtConfig,
    /// Key handed to `encode` when tokens are generated.
    encoding_key: EncodingKey,
    /// Key handed to `decode` when tokens are verified.
    decoding_key: DecodingKey,
}
15
16impl JwtService {
17    /// 创建新的 JWT 服务
18    ///
19    /// # Arguments
20    /// * `config` - JWT 配置
21    pub fn new(config: JwtConfig) -> JwtResult<Self> {
22        let (encoding_key, decoding_key) = Self::create_keys(&config)?;
23        Ok(Self { config, encoding_key, decoding_key })
24    }
25
26    fn create_keys(config: &JwtConfig) -> JwtResult<(EncodingKey, DecodingKey)> {
27        match config.algorithm {
28            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
29                let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
30                let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
31                Ok((encoding_key, decoding_key))
32            }
33            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 | JwtAlgorithm::ES256 | JwtAlgorithm::ES384 => {
34                let encoding_key =
35                    EncodingKey::from_rsa_pem(config.secret.as_bytes()).map_err(|e| JwtError::KeyError(e.to_string()))?;
36
37                let public_key = config
38                    .public_key
39                    .as_ref()
40                    .ok_or_else(|| JwtError::KeyError("public key is required for asymmetric algorithms".into()))?;
41
42                let decoding_key =
43                    DecodingKey::from_rsa_pem(public_key.as_bytes()).map_err(|e| JwtError::KeyError(e.to_string()))?;
44
45                Ok((encoding_key, decoding_key))
46            }
47        }
48    }
49
50    /// 生成令牌
51    ///
52    /// # Arguments
53    /// * `claims` - JWT Claims
54    pub fn generate_token<T: Serialize>(&self, claims: &T) -> JwtResult<String> {
55        let header = Header::new(self.config.algorithm.into());
56        encode(&header, claims, &self.encoding_key).map_err(Into::into)
57    }
58
59    /// 验证令牌
60    ///
61    /// # Arguments
62    /// * `token` - JWT 令牌
63    pub fn verify_token<T: DeserializeOwned>(&self, token: &str) -> JwtResult<T> {
64        let mut validation = Validation::new(self.config.algorithm.into());
65
66        if self.config.validate_issuer {
67            if let Some(ref issuer) = self.config.issuer {
68                validation.set_issuer(&[issuer]);
69            }
70        }
71
72        if self.config.validate_audience {
73            if let Some(ref audience) = self.config.audience {
74                validation.set_audience(&[audience]);
75            }
76        }
77
78        validation.leeway = self.config.leeway_seconds as u64;
79
80        let token_data = decode::<T>(token, &self.decoding_key, &validation)?;
81        Ok(token_data.claims)
82    }
83
84    /// 生成访问令牌
85    ///
86    /// # Arguments
87    /// * `user_id` - 用户 ID
88    /// * `claims` - 访问令牌 Claims
89    pub fn generate_access_token(&self, user_id: &str, claims: AccessTokenClaims) -> JwtResult<String> {
90        let mut jwt_claims = JwtClaims::new(user_id, self.config.access_token_expires_in());
91
92        if let Some(ref issuer) = self.config.issuer {
93            jwt_claims = jwt_claims.with_issuer(issuer.clone());
94        }
95
96        if let Some(ref audience) = self.config.audience {
97            jwt_claims = jwt_claims.with_audience(audience.clone());
98        }
99
100        jwt_claims.custom.insert("user_id".to_string(), serde_json::to_value(&claims.user_id).unwrap());
101
102        if let Some(ref username) = claims.username {
103            jwt_claims.custom.insert("username".to_string(), serde_json::to_value(username).unwrap());
104        }
105
106        if !claims.roles.is_empty() {
107            jwt_claims.custom.insert("roles".to_string(), serde_json::to_value(&claims.roles).unwrap());
108        }
109
110        if !claims.permissions.is_empty() {
111            jwt_claims.custom.insert("permissions".to_string(), serde_json::to_value(&claims.permissions).unwrap());
112        }
113
114        if let Some(ref session_id) = claims.session_id {
115            jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
116        }
117
118        self.generate_token(&jwt_claims)
119    }
120
121    /// 验证访问令牌
122    ///
123    /// # Arguments
124    /// * `token` - 访问令牌
125    pub fn verify_access_token(&self, token: &str) -> JwtResult<JwtClaims> {
126        self.verify_token(token)
127    }
128
129    /// 生成刷新令牌
130    ///
131    /// # Arguments
132    /// * `user_id` - 用户 ID
133    /// * `session_id` - 会话 ID
134    /// * `version` - 令牌版本
135    pub fn generate_refresh_token(&self, user_id: &str, session_id: &str, version: u32) -> JwtResult<String> {
136        let mut jwt_claims = JwtClaims::new(user_id, self.config.refresh_token_expires_in());
137
138        if let Some(ref issuer) = self.config.issuer {
139            jwt_claims = jwt_claims.with_issuer(issuer.clone());
140        }
141
142        jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
143        jwt_claims.custom.insert("version".to_string(), serde_json::to_value(version).unwrap());
144        jwt_claims.custom.insert("token_type".to_string(), serde_json::to_value("refresh").unwrap());
145
146        self.generate_token(&jwt_claims)
147    }
148
149    /// 验证刷新令牌
150    ///
151    /// # Arguments
152    /// * `token` - 刷新令牌
153    pub fn verify_refresh_token(&self, token: &str) -> JwtResult<RefreshTokenClaims> {
154        let claims: JwtClaims = self.verify_token(token)?;
155
156        let token_type: Option<String> = claims.custom.get("token_type").and_then(|v| serde_json::from_value(v.clone()).ok());
157
158        if token_type.as_deref() != Some("refresh") {
159            return Err(JwtError::InvalidToken("not a refresh token".into()));
160        }
161
162        let session_id: String = claims
163            .custom
164            .get("session_id")
165            .and_then(|v| serde_json::from_value(v.clone()).ok())
166            .ok_or_else(|| JwtError::MissingClaim("session_id".into()))?;
167
168        let version: u32 = claims.custom.get("version").and_then(|v| serde_json::from_value(v.clone()).ok()).unwrap_or(0);
169
170        Ok(RefreshTokenClaims { user_id: claims.sub, session_id, version })
171    }
172
173    /// 解码令牌(不验证签名)
174    ///
175    /// # Arguments
176    /// * `token` - JWT 令牌
177    pub fn decode_unchecked<T: DeserializeOwned>(&self, token: &str) -> JwtResult<T> {
178        let mut validation = Validation::new(self.config.algorithm.into());
179        validation.insecure_disable_signature_validation();
180
181        let token_data = decode::<T>(token, &self.decoding_key, &validation)?;
182        Ok(token_data.claims)
183    }
184
185    /// 获取令牌剩余有效期
186    ///
187    /// # Arguments
188    /// * `token` - JWT 令牌
189    pub fn get_remaining_ttl(&self, token: &str) -> JwtResult<i64> {
190        let claims: JwtClaims = self.verify_token(token)?;
191        let now = Utc::now().timestamp();
192        let remaining = claims.exp - now;
193        Ok(remaining.max(0))
194    }
195
196    /// 检查令牌是否即将过期
197    ///
198    /// # Arguments
199    /// * `token` - JWT 令牌
200    /// * `threshold_seconds` - 阈值(秒)
201    pub fn is_token_expiring_soon(&self, token: &str, threshold_seconds: i64) -> JwtResult<bool> {
202        let remaining = self.get_remaining_ttl(token)?;
203        Ok(remaining < threshold_seconds)
204    }
205
    /// Returns a reference to the configuration this service was created with.
    pub fn config(&self) -> &JwtConfig {
        &self.config
    }
210}
211
/// Token pair (access token + refresh token).
#[derive(Debug, Clone)]
pub struct TokenPair {
    /// Access token.
    pub access_token: String,

    /// Refresh token.
    pub refresh_token: String,

    /// Token type.
    pub token_type: String,

    /// Access token lifetime (seconds).
    pub expires_in: i64,

    /// Refresh token lifetime (seconds).
    pub refresh_expires_in: i64,
}
230
231impl JwtService {
232    /// 生成令牌对
233    ///
234    /// # Arguments
235    /// * `user_id` - 用户 ID
236    /// * `access_claims` - 访问令牌 Claims
237    /// * `session_id` - 会话 ID
238    pub fn generate_token_pair(
239        &self,
240        user_id: &str,
241        access_claims: AccessTokenClaims,
242        session_id: &str,
243    ) -> JwtResult<TokenPair> {
244        let access_token = self.generate_access_token(user_id, access_claims)?;
245        let refresh_token = self.generate_refresh_token(user_id, session_id, 0)?;
246
247        Ok(TokenPair {
248            access_token,
249            refresh_token,
250            token_type: "Bearer".to_string(),
251            expires_in: self.config.access_token_expires_in(),
252            refresh_expires_in: self.config.refresh_token_expires_in(),
253        })
254    }
255}
256
257/// 便捷函数:创建默认 JWT 服务
258pub fn default_jwt_service() -> JwtResult<JwtService> {
259    JwtService::new(JwtConfig::default())
260}
261
262/// 便捷函数:使用密钥创建 JWT 服务
263pub fn jwt_service_with_secret(secret: impl Into<String>) -> JwtResult<JwtService> {
264    JwtService::new(JwtConfig::new(secret))
265}