sa_token_core/oauth2.rs
1//! OAuth2 Authorization Code Flow Implementation | OAuth2 授权码模式实现
2//!
3//! ## Code Flow Logic | 代码流程逻辑
4//!
5//! ### Overall Architecture | 整体架构
6//!
7//! ```text
8//! ┌─────────────────────────────────────────────────────────────────┐
9//! │ OAuth2Manager │
10//! │ OAuth2 管理器核心 │
11//! ├─────────────────────────────────────────────────────────────────┤
12//! │ │
13//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
14//! │ │ Client Mgmt │ │ Auth Code │ │ Token Mgmt │ │
15//! │ │ 客户端管理 │ │ 授权码管理 │ │ 令牌管理 │ │
16//! │ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
17//! │ │ │ │ │
18//! │ ▼ ▼ ▼ │
19//! │ ┌────────────────────────────────────────────────┐ │
20//! │ │ Storage Backend (SaStorage) │ │
21//! │ │ 存储后端(Memory/Redis/Database) │ │
22//! │ └────────────────────────────────────────────────┘ │
23//! └─────────────────────────────────────────────────────────────────┘
24//! ```
25//!
26//! ### Core Processes | 核心流程
27//!
28//! #### 1. Authorization Code Flow | 授权码流程
29//!
30//! ```text
31//! User/Client OAuth2Manager Storage
32//! 用户/客户端 OAuth2管理器 存储
33//! │ │ │
34//! │ register_client() │ │
35//! │─────────────────────▶│ │
36//! │ │ store client │
37//! │ │─────────────────────▶│
38//! │ │ │
39//! │ authorize request │ │
40//! │─────────────────────▶│ │
41//! │ │ │
42//! │ generate_auth_code()│ │
43//! │ │ validate redirect │
44//! │ │ validate scope │
45//! │ │ create code │
46//! │ │ │
47//! │ │ store code (TTL) │
48//! │ │─────────────────────▶│
49//! │ │ │
50//! │ return code │ │
51//! │◀─────────────────────│ │
52//! │ │ │
53//! │ exchange_code() │ │
54//! │─────────────────────▶│ │
55//! │ │ verify client │
56//! │ │ consume code │
57//! │ │─────────────────────▶│
58//! │ │ delete code │
59//! │ │ │
60//! │ │ generate tokens │
61//! │ │ - access_token │
62//! │ │ - refresh_token │
63//! │ │ │
64//! │ │ store tokens (TTL) │
65//! │ │─────────────────────▶│
66//! │ │ │
67//! │ return tokens │ │
68//! │◀─────────────────────│ │
69//! │ │ │
70//! ```
71//!
72//! #### 2. Token Refresh Flow | 令牌刷新流程
73//!
74//! ```text
75//! Client OAuth2Manager Storage
76//! 客户端 OAuth2管理器 存储
77//! │ │ │
78//! │ refresh_token() │ │
79//! │─────────────────────▶│ │
80//! │ │ verify client │
81//! │ │ credentials │
82//! │ │ │
83//! │ │ get refresh_token │
84//! │ │─────────────────────▶│
85//! │ │ return data │
86//! │ │◀─────────────────────│
87//! │ │ │
88//! │ │ validate client_id │
89//! │ │ validate not expired│
90//! │ │ │
91//! │ │ generate new tokens │
92//! │ │ - new access_token │
93//! │ │ - new refresh_token │
94//! │ │ │
95//! │ │ store new tokens │
96//! │ │─────────────────────▶│
97//! │ │ │
98//! │ return new tokens │ │
99//! │◀─────────────────────│ │
100//! │ │ │
101//! ```
102//!
103//! ### Storage Keys | 存储键格式
104//!
105//! ```text
106//! oauth2:client:{client_id} - Client information | 客户端信息
107//! oauth2:code:{authorization_code} - Authorization code | 授权码 (TTL: 10 min)
108//! oauth2:token:{access_token} - Token info | 令牌信息 (TTL: 1 hour)
109//! oauth2:refresh:{refresh_token} - Refresh token | 刷新令牌 (TTL: 30 days)
110//! ```
111//!
112//! ### Security Validations | 安全验证
113//!
114//! ```text
115//! ┌────────────────────────────────────────────────────────┐
116//! │ 1. Client Verification | 客户端验证 │
117//! │ - client_id + client_secret match | 凭据匹配 │
118//! │ - client exists in registry | 客户端已注册 │
119//! ├────────────────────────────────────────────────────────┤
120//! │ 2. Redirect URI Validation | 回调URI验证 │
121//! │ - URI in whitelist | URI在白名单中 │
122//! │ - Exact match (no wildcards) | 精确匹配 │
123//! ├────────────────────────────────────────────────────────┤
124//! │ 3. Scope Validation | 权限范围验证 │
125//! │ - Requested scopes ⊆ client scopes | 请求范围子集 │
126//! │ - All scopes valid | 所有范围合法 │
127//! ├────────────────────────────────────────────────────────┤
128//! │ 4. Code Validation | 授权码验证 │
129//! │ - Code exists | 授权码存在 │
130//! │ - Not expired | 未过期 │
131//! │ - Single use (consumed after exchange) | 单次使用 │
132//! │ - Client ID match | 客户端ID匹配 │
133//! │ - Redirect URI match | 回调URI匹配 │
134//! ├────────────────────────────────────────────────────────┤
135//! │ 5. Token Validation | 令牌验证 │
136//! │ - Token exists | 令牌存在 │
137//! │ - Not expired | 未过期 │
138//! │ - Signature valid | 签名有效 │
139//! └────────────────────────────────────────────────────────┘
140//! ```
141//!
142//! ### Performance Considerations | 性能考虑
143//!
144//! 1. **Async Operations | 异步操作**
145//! - All storage operations are async | 所有存储操作异步
146//! - Non-blocking I/O | 非阻塞IO
147//!
148//! 2. **TTL Management | TTL管理**
149//! - Storage-level expiration | 存储层级过期
150//! - Automatic cleanup | 自动清理
151//! - No manual garbage collection | 无需手动垃圾回收
152//!
153//! 3. **Code Consumption | 授权码消费**
154//! - Read + Delete in one flow | 读取和删除一次性
155//! - Prevents replay attacks | 防止重放攻击
156//!
157//! ### Error Handling | 错误处理
158//!
159//! ```text
160//! Error Type When It Occurs | 发生时机
161//! ────────────────────────────────────────────────────────
162//! InvalidToken - Invalid client credentials | 无效客户端凭据
163//! - Code not found | 授权码不存在
164//! - Client ID mismatch | 客户端ID不匹配
165//! - Redirect URI mismatch | 回调URI不匹配
166//!
167//! TokenExpired - Authorization code expired | 授权码过期
168//! - Access token expired | 访问令牌过期
169//! - Refresh token expired | 刷新令牌过期
170//!
171//! StorageError - Storage operation failed | 存储操作失败
172//! - Network error | 网络错误
173//!
174//! SerializationError - JSON encode/decode failed | JSON序列化失败
175//! ```
176
177use std::sync::Arc;
178use chrono::{DateTime, Utc, Duration};
179use serde::{Deserialize, Serialize};
180use uuid::Uuid;
181use sa_token_adapter::storage::SaStorage;
182use crate::error::{SaTokenError, SaTokenResult};
183
184/// OAuth2 Client Information | OAuth2 客户端信息
185///
186/// Represents a registered OAuth2 client application with its credentials and configuration.
187/// 表示一个已注册的 OAuth2 客户端应用程序及其凭据和配置。
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct OAuth2Client {
190 /// Unique identifier for the client | 客户端的唯一标识符
191 pub client_id: String,
192
193 /// Secret key for client authentication | 客户端认证的密钥
194 pub client_secret: String,
195
196 /// Allowed redirect URIs (whitelist) | 允许的回调 URI(白名单)
197 pub redirect_uris: Vec<String>,
198
199 /// Supported grant types (e.g., "authorization_code", "refresh_token")
200 /// 支持的授权类型(例如:"authorization_code"、"refresh_token")
201 pub grant_types: Vec<String>,
202
203 /// Permitted scopes for this client | 此客户端允许的权限范围
204 pub scope: Vec<String>,
205}
206
207/// Authorization Code | 授权码
208///
209/// Temporary code issued after user authorization, exchanged for access token.
210/// 用户授权后颁发的临时代码,用于交换访问令牌。
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct AuthorizationCode {
213 /// The authorization code value | 授权码的值
214 pub code: String,
215
216 /// Client ID that requested the code | 请求授权码的客户端 ID
217 pub client_id: String,
218
219 /// User ID who authorized | 授权的用户 ID
220 pub user_id: String,
221
222 /// Redirect URI used in authorization request | 授权请求中使用的回调 URI
223 pub redirect_uri: String,
224
225 /// Granted scopes | 授予的权限范围
226 pub scope: Vec<String>,
227
228 /// Code creation timestamp | 授权码创建时间戳
229 pub created_at: DateTime<Utc>,
230
231 /// Code expiration timestamp | 授权码过期时间戳
232 pub expires_at: DateTime<Utc>,
233}
234
235/// Access Token Response | 访问令牌响应
236///
237/// Token response returned to the client after successful authorization.
238/// 成功授权后返回给客户端的令牌响应。
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct AccessToken {
241 /// The access token value | 访问令牌的值
242 pub access_token: String,
243
244 /// Token type (typically "Bearer") | 令牌类型(通常为 "Bearer")
245 pub token_type: String,
246
247 /// Token lifetime in seconds | 令牌有效期(秒)
248 pub expires_in: i64,
249
250 /// Optional refresh token for token renewal | 可选的刷新令牌,用于令牌续期
251 pub refresh_token: Option<String>,
252
253 /// Granted scopes | 授予的权限范围
254 pub scope: Vec<String>,
255}
256
257/// OAuth2 Token Information (for storage) | OAuth2 令牌信息(用于存储)
258///
259/// Internal structure for storing token details in the backend.
260/// 用于在后端存储令牌详细信息的内部结构。
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct OAuth2TokenInfo {
263 /// Access token value | 访问令牌值
264 pub access_token: String,
265
266 /// Client ID that owns this token | 拥有此令牌的客户端 ID
267 pub client_id: String,
268
269 /// User ID associated with this token | 与此令牌关联的用户 ID
270 pub user_id: String,
271
272 /// Granted scopes | 授予的权限范围
273 pub scope: Vec<String>,
274
275 /// Token creation timestamp | 令牌创建时间戳
276 pub created_at: DateTime<Utc>,
277
278 /// Token expiration timestamp | 令牌过期时间戳
279 pub expires_at: DateTime<Utc>,
280
281 /// Optional refresh token | 可选的刷新令牌
282 pub refresh_token: Option<String>,
283}
284
285/// OAuth2 Manager | OAuth2 管理器
286///
287/// Core manager for OAuth2 authorization code flow operations.
288/// OAuth2 授权码模式的核心管理器。
289///
290/// # Responsibilities | 职责
291/// - Client registration and verification | 客户端注册和验证
292/// - Authorization code generation and validation | 授权码生成和验证
293/// - Access token issuance and verification | 访问令牌颁发和验证
294/// - Refresh token management | 刷新令牌管理
295/// - Security validation (redirect URI, scope, etc.) | 安全验证(回调 URI、权限范围等)
296pub struct OAuth2Manager {
297 /// Storage backend for tokens and clients | 令牌和客户端的存储后端
298 storage: Arc<dyn SaStorage>,
299
300 /// Authorization code TTL in seconds (default: 600 = 10 minutes)
301 /// 授权码有效期(秒)(默认:600 = 10 分钟)
302 code_ttl: i64,
303
304 /// Access token TTL in seconds (default: 3600 = 1 hour)
305 /// 访问令牌有效期(秒)(默认:3600 = 1 小时)
306 token_ttl: i64,
307
308 /// Refresh token TTL in seconds (default: 2592000 = 30 days)
309 /// 刷新令牌有效期(秒)(默认:2592000 = 30 天)
310 refresh_token_ttl: i64,
311}
312
313impl OAuth2Manager {
314 /// Create a new OAuth2Manager with default TTL values
315 /// 使用默认 TTL 值创建新的 OAuth2Manager
316 ///
317 /// # Default TTL | 默认 TTL
318 /// - Authorization code: 600 seconds (10 minutes) | 授权码:600 秒(10 分钟)
319 /// - Access token: 3600 seconds (1 hour) | 访问令牌:3600 秒(1 小时)
320 /// - Refresh token: 2592000 seconds (30 days) | 刷新令牌:2592000 秒(30 天)
321 ///
322 /// # Arguments | 参数
323 /// * `storage` - Storage backend for persistence | 用于持久化的存储后端
324 pub fn new(storage: Arc<dyn SaStorage>) -> Self {
325 Self {
326 storage,
327 code_ttl: 600, // 10 minutes
328 token_ttl: 3600, // 1 hour
329 refresh_token_ttl: 2592000, // 30 days
330 }
331 }
332
333 /// Set custom TTL values for codes and tokens
334 /// 设置授权码和令牌的自定义 TTL 值
335 ///
336 /// # Arguments | 参数
337 /// * `code_ttl` - Authorization code TTL in seconds | 授权码 TTL(秒)
338 /// * `token_ttl` - Access token TTL in seconds | 访问令牌 TTL(秒)
339 /// * `refresh_token_ttl` - Refresh token TTL in seconds | 刷新令牌 TTL(秒)
340 ///
341 /// # Example | 示例
342 /// ```ignore
343 /// let oauth2 = OAuth2Manager::new(storage)
344 /// .with_ttl(300, 1800, 604800); // 5min, 30min, 7days
345 /// ```
346 pub fn with_ttl(mut self, code_ttl: i64, token_ttl: i64, refresh_token_ttl: i64) -> Self {
347 self.code_ttl = code_ttl;
348 self.token_ttl = token_ttl;
349 self.refresh_token_ttl = refresh_token_ttl;
350 self
351 }
352
353 /// Register a new OAuth2 client | 注册新的 OAuth2 客户端
354 ///
355 /// Stores client information in the backend for future authentication.
356 /// 将客户端信息存储在后端,用于未来的认证。
357 ///
358 /// # Arguments | 参数
359 /// * `client` - Client information to register | 要注册的客户端信息
360 ///
361 /// # Returns | 返回
362 /// * `Ok(())` on success | 成功时返回 `Ok(())`
363 /// * `Err(SaTokenError)` on storage or serialization error | 存储或序列化错误时返回错误
364 ///
365 /// # Example | 示例
366 /// ```ignore
367 /// let client = OAuth2Client {
368 /// client_id: "app_001".to_string(),
369 /// client_secret: "secret".to_string(),
370 /// redirect_uris: vec!["http://localhost/callback".to_string()],
371 /// grant_types: vec!["authorization_code".to_string()],
372 /// scope: vec!["read".to_string(), "write".to_string()],
373 /// };
374 /// oauth2.register_client(&client).await?;
375 /// ```
376 pub async fn register_client(&self, client: &OAuth2Client) -> SaTokenResult<()> {
377 let key = format!("oauth2:client:{}", client.client_id);
378 let value = serde_json::to_string(client)
379 .map_err(|e| SaTokenError::SerializationError(e))?;
380
381 self.storage.set(&key, &value, None).await
382 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
383
384 Ok(())
385 }
386
387 /// Retrieve client information by client ID | 通过客户端 ID 检索客户端信息
388 ///
389 /// # Arguments | 参数
390 /// * `client_id` - Client identifier | 客户端标识符
391 ///
392 /// # Returns | 返回
393 /// * `Ok(OAuth2Client)` if found | 找到时返回客户端信息
394 /// * `Err(OAuth2ClientNotFound)` if client doesn't exist | 客户端不存在时返回错误
395 pub async fn get_client(&self, client_id: &str) -> SaTokenResult<OAuth2Client> {
396 let key = format!("oauth2:client:{}", client_id);
397 let value = self.storage.get(&key).await
398 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
399 .ok_or_else(|| SaTokenError::OAuth2ClientNotFound)?;
400
401 serde_json::from_str(&value)
402 .map_err(|e| SaTokenError::SerializationError(e))
403 }
404
405 /// Verify client credentials | 验证客户端凭据
406 ///
407 /// Checks if the provided client_id and client_secret match.
408 /// 检查提供的 client_id 和 client_secret 是否匹配。
409 ///
410 /// # Arguments | 参数
411 /// * `client_id` - Client identifier | 客户端标识符
412 /// * `client_secret` - Client secret key | 客户端密钥
413 ///
414 /// # Returns | 返回
415 /// * `Ok(true)` if credentials are valid | 凭据有效时返回 `true`
416 /// * `Ok(false)` if credentials are invalid | 凭据无效时返回 `false`
417 /// * `Err(OAuth2ClientNotFound)` if client doesn't exist | 客户端不存在时返回错误
418 pub async fn verify_client(&self, client_id: &str, client_secret: &str) -> SaTokenResult<bool> {
419 let client = self.get_client(client_id).await?;
420 Ok(client.client_secret == client_secret)
421 }
422
423 /// Generate a new authorization code | 生成新的授权码
424 ///
425 /// Creates a temporary authorization code after user consent.
426 /// 在用户同意后创建临时授权码。
427 ///
428 /// # Arguments | 参数
429 /// * `client_id` - Client requesting authorization | 请求授权的客户端
430 /// * `user_id` - User granting authorization | 授予授权的用户
431 /// * `redirect_uri` - Callback URI for this authorization | 此授权的回调 URI
432 /// * `scope` - Granted permissions | 授予的权限
433 ///
434 /// # Returns | 返回
435 /// * `AuthorizationCode` with unique code and expiration | 带有唯一代码和过期时间的授权码
436 ///
437 /// # Note | 注意
438 /// This code must be stored using `store_authorization_code()` before returning to client.
439 /// 此代码必须在返回给客户端之前使用 `store_authorization_code()` 存储。
440 pub fn generate_authorization_code(
441 &self,
442 client_id: String,
443 user_id: String,
444 redirect_uri: String,
445 scope: Vec<String>,
446 ) -> AuthorizationCode {
447 let now = Utc::now();
448 let code = format!("code_{}", Uuid::new_v4().simple());
449
450 AuthorizationCode {
451 code,
452 client_id,
453 user_id,
454 redirect_uri,
455 scope,
456 created_at: now,
457 expires_at: now + Duration::seconds(self.code_ttl),
458 }
459 }
460
461 /// Store authorization code in backend | 在后端存储授权码
462 ///
463 /// Persists the authorization code with TTL for later exchange.
464 /// 使用 TTL 持久化授权码,以便稍后交换。
465 ///
466 /// # Arguments | 参数
467 /// * `auth_code` - Authorization code to store | 要存储的授权码
468 ///
469 /// # Storage Key Format | 存储键格式
470 /// `oauth2:code:{authorization_code}`
471 pub async fn store_authorization_code(&self, auth_code: &AuthorizationCode) -> SaTokenResult<()> {
472 let key = format!("oauth2:code:{}", auth_code.code);
473 let value = serde_json::to_string(auth_code)
474 .map_err(|e| SaTokenError::SerializationError(e))?;
475
476 let ttl = Some(std::time::Duration::from_secs(self.code_ttl as u64));
477 self.storage.set(&key, &value, ttl).await
478 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
479
480 Ok(())
481 }
482
483 pub async fn get_authorization_code(&self, code: &str) -> SaTokenResult<AuthorizationCode> {
484 let key = format!("oauth2:code:{}", code);
485 let value = self.storage.get(&key).await
486 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
487 .ok_or_else(|| SaTokenError::OAuth2CodeNotFound)?;
488
489 let auth_code: AuthorizationCode = serde_json::from_str(&value)
490 .map_err(|e| SaTokenError::SerializationError(e))?;
491
492 if Utc::now() > auth_code.expires_at {
493 self.storage.delete(&key).await.ok();
494 return Err(SaTokenError::TokenExpired);
495 }
496
497 Ok(auth_code)
498 }
499
500 pub async fn consume_authorization_code(&self, code: &str) -> SaTokenResult<AuthorizationCode> {
501 let auth_code = self.get_authorization_code(code).await?;
502 let key = format!("oauth2:code:{}", code);
503 self.storage.delete(&key).await
504 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
505 Ok(auth_code)
506 }
507
508 /// Exchange authorization code for access token | 用授权码换取访问令牌
509 ///
510 /// Core of the authorization code flow. Validates the code and issues tokens.
511 /// 授权码流程的核心。验证授权码并颁发令牌。
512 ///
513 /// # Validations | 验证
514 /// 1. Client credentials (client_id + client_secret) | 客户端凭据
515 /// 2. Authorization code exists and not expired | 授权码存在且未过期
516 /// 3. Client ID matches the code | 客户端 ID 与授权码匹配
517 /// 4. Redirect URI matches the code | 回调 URI 与授权码匹配
518 ///
519 /// # Arguments | 参数
520 /// * `code` - Authorization code | 授权码
521 /// * `client_id` - Client identifier | 客户端标识符
522 /// * `client_secret` - Client secret | 客户端密钥
523 /// * `redirect_uri` - Redirect URI used in authorization | 授权时使用的回调 URI
524 ///
525 /// # Returns | 返回
526 /// * `Ok(AccessToken)` with access_token and optional refresh_token | 带有访问令牌和可选刷新令牌
527 /// * `Err(OAuth2InvalidCredentials)` if client credentials invalid | 客户端凭据无效时
528 /// * `Err(OAuth2CodeNotFound)` if code not found or expired | 授权码未找到或已过期时
529 /// * `Err(OAuth2ClientIdMismatch)` if client ID doesn't match | 客户端 ID 不匹配时
530 /// * `Err(OAuth2RedirectUriMismatch)` if redirect URI doesn't match | 回调 URI 不匹配时
531 ///
532 /// # Security | 安全性
533 /// The authorization code is consumed (deleted) after use to prevent replay attacks.
534 /// 授权码在使用后被消费(删除),以防止重放攻击。
535 pub async fn exchange_code_for_token(
536 &self,
537 code: &str,
538 client_id: &str,
539 client_secret: &str,
540 redirect_uri: &str,
541 ) -> SaTokenResult<AccessToken> {
542 if !self.verify_client(client_id, client_secret).await? {
543 return Err(SaTokenError::OAuth2InvalidCredentials);
544 }
545
546 let auth_code = self.consume_authorization_code(code).await?;
547
548 if auth_code.client_id != client_id {
549 return Err(SaTokenError::OAuth2ClientIdMismatch);
550 }
551
552 if auth_code.redirect_uri != redirect_uri {
553 return Err(SaTokenError::OAuth2RedirectUriMismatch);
554 }
555
556 self.generate_access_token(&auth_code.client_id, &auth_code.user_id, auth_code.scope).await
557 }
558
559 pub async fn generate_access_token(
560 &self,
561 client_id: &str,
562 user_id: &str,
563 scope: Vec<String>,
564 ) -> SaTokenResult<AccessToken> {
565 let now = Utc::now();
566 let access_token = format!("at_{}", Uuid::new_v4().simple());
567 let refresh_token = format!("rt_{}", Uuid::new_v4().simple());
568
569 let token_info = OAuth2TokenInfo {
570 access_token: access_token.clone(),
571 client_id: client_id.to_string(),
572 user_id: user_id.to_string(),
573 scope: scope.clone(),
574 created_at: now,
575 expires_at: now + Duration::seconds(self.token_ttl),
576 refresh_token: Some(refresh_token.clone()),
577 };
578
579 let key = format!("oauth2:token:{}", access_token);
580 let value = serde_json::to_string(&token_info)
581 .map_err(|e| SaTokenError::SerializationError(e))?;
582
583 let ttl = Some(std::time::Duration::from_secs(self.token_ttl as u64));
584 self.storage.set(&key, &value, ttl).await
585 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
586
587 let refresh_key = format!("oauth2:refresh:{}", refresh_token);
588 let refresh_value = serde_json::json!({
589 "user_id": user_id,
590 "client_id": client_id,
591 "scope": scope,
592 }).to_string();
593
594 let refresh_ttl = Some(std::time::Duration::from_secs(self.refresh_token_ttl as u64));
595 self.storage.set(&refresh_key, &refresh_value, refresh_ttl).await
596 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
597
598 Ok(AccessToken {
599 access_token,
600 token_type: "Bearer".to_string(),
601 expires_in: self.token_ttl,
602 refresh_token: Some(refresh_token),
603 scope,
604 })
605 }
606
607 /// Verify access token and retrieve token information | 验证访问令牌并检索令牌信息
608 ///
609 /// Checks if the access token is valid and not expired.
610 /// 检查访问令牌是否有效且未过期。
611 ///
612 /// # Arguments | 参数
613 /// * `access_token` - Access token to verify | 要验证的访问令牌
614 ///
615 /// # Returns | 返回
616 /// * `Ok(OAuth2TokenInfo)` if token is valid | 令牌有效时返回令牌信息
617 /// * `Err(OAuth2AccessTokenNotFound)` if token not found | 令牌未找到时
618 /// * `Err(TokenExpired)` if token has expired | 令牌已过期时
619 ///
620 /// # Note | 注意
621 /// Expired tokens are automatically cleaned up from storage.
622 /// 过期的令牌会自动从存储中清理。
623 pub async fn verify_access_token(&self, access_token: &str) -> SaTokenResult<OAuth2TokenInfo> {
624 let key = format!("oauth2:token:{}", access_token);
625 let value = self.storage.get(&key).await
626 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
627 .ok_or_else(|| SaTokenError::OAuth2AccessTokenNotFound)?;
628
629 let token_info: OAuth2TokenInfo = serde_json::from_str(&value)
630 .map_err(|e| SaTokenError::SerializationError(e))?;
631
632 if Utc::now() > token_info.expires_at {
633 self.storage.delete(&key).await.ok();
634 return Err(SaTokenError::TokenExpired);
635 }
636
637 Ok(token_info)
638 }
639
640 /// Refresh access token using refresh token | 使用刷新令牌刷新访问令牌
641 ///
642 /// Issues a new access token (and optionally a new refresh token) when the old one expires.
643 /// 当旧令牌过期时颁发新的访问令牌(以及可选的新刷新令牌)。
644 ///
645 /// # Validations | 验证
646 /// 1. Client credentials are valid | 客户端凭据有效
647 /// 2. Refresh token exists and belongs to the client | 刷新令牌存在且属于该客户端
648 /// 3. Client ID matches the refresh token | 客户端 ID 与刷新令牌匹配
649 ///
650 /// # Arguments | 参数
651 /// * `refresh_token` - Refresh token | 刷新令牌
652 /// * `client_id` - Client identifier | 客户端标识符
653 /// * `client_secret` - Client secret | 客户端密钥
654 ///
655 /// # Returns | 返回
656 /// * `Ok(AccessToken)` with new access_token and refresh_token | 新的访问令牌和刷新令牌
657 /// * `Err(OAuth2InvalidCredentials)` if credentials invalid | 凭据无效时
658 /// * `Err(OAuth2RefreshTokenNotFound)` if refresh token not found | 刷新令牌未找到时
659 /// * `Err(OAuth2ClientIdMismatch)` if client ID doesn't match | 客户端 ID 不匹配时
660 pub async fn refresh_access_token(
661 &self,
662 refresh_token: &str,
663 client_id: &str,
664 client_secret: &str,
665 ) -> SaTokenResult<AccessToken> {
666 if !self.verify_client(client_id, client_secret).await? {
667 return Err(SaTokenError::OAuth2InvalidCredentials);
668 }
669
670 let key = format!("oauth2:refresh:{}", refresh_token);
671 let value = self.storage.get(&key).await
672 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
673 .ok_or_else(|| SaTokenError::OAuth2RefreshTokenNotFound)?;
674
675 let data: serde_json::Value = serde_json::from_str(&value)
676 .map_err(|e| SaTokenError::SerializationError(e))?;
677
678 let stored_client_id = data["client_id"].as_str()
679 .ok_or_else(|| SaTokenError::OAuth2InvalidRefreshToken)?;
680
681 if stored_client_id != client_id {
682 return Err(SaTokenError::OAuth2ClientIdMismatch);
683 }
684
685 let user_id = data["user_id"].as_str()
686 .ok_or_else(|| SaTokenError::OAuth2InvalidRefreshToken)?;
687
688 let scope: Vec<String> = data["scope"].as_array()
689 .ok_or_else(|| SaTokenError::OAuth2InvalidScope)?
690 .iter()
691 .filter_map(|v| v.as_str().map(|s| s.to_string()))
692 .collect();
693
694 self.generate_access_token(client_id, user_id, scope).await
695 }
696
697 /// Revoke an access token or refresh token | 撤销访问令牌或刷新令牌
698 ///
699 /// Deletes the token from storage, making it immediately invalid.
700 /// 从存储中删除令牌,使其立即失效。
701 ///
702 /// # Arguments | 参数
703 /// * `token` - Token to revoke (access or refresh) | 要撤销的令牌(访问或刷新)
704 ///
705 /// # Use Cases | 使用场景
706 /// - User logout | 用户登出
707 /// - Security breach | 安全漏洞
708 /// - Client revocation | 客户端撤销
709 pub async fn revoke_token(&self, token: &str) -> SaTokenResult<()> {
710 let access_key = format!("oauth2:token:{}", token);
711 let refresh_key = format!("oauth2:refresh:{}", token);
712
713 self.storage.delete(&access_key).await.ok();
714 self.storage.delete(&refresh_key).await.ok();
715
716 Ok(())
717 }
718
719 /// Validate redirect URI against client's whitelist | 根据客户端白名单验证回调 URI
720 ///
721 /// Security check to prevent redirect URI hijacking.
722 /// 安全检查以防止回调 URI 劫持。
723 ///
724 /// # Arguments | 参数
725 /// * `client` - Client information with registered URIs | 带有注册 URI 的客户端信息
726 /// * `redirect_uri` - URI to validate | 要验证的 URI
727 ///
728 /// # Returns | 返回
729 /// * `true` if URI is in the whitelist | URI 在白名单中时返回 `true`
730 /// * `false` if URI is not allowed | URI 不被允许时返回 `false`
731 pub fn validate_redirect_uri(&self, client: &OAuth2Client, redirect_uri: &str) -> bool {
732 client.redirect_uris.iter().any(|uri| uri == redirect_uri)
733 }
734
735 /// Validate requested scopes against client's permitted scopes | 根据客户端允许的范围验证请求的权限范围
736 ///
737 /// Ensures requested scopes are a subset of client's permitted scopes.
738 /// 确保请求的权限范围是客户端允许范围的子集。
739 ///
740 /// # Arguments | 参数
741 /// * `client` - Client information with permitted scopes | 带有允许权限范围的客户端信息
742 /// * `requested_scope` - Scopes being requested | 正在请求的权限范围
743 ///
744 /// # Returns | 返回
745 /// * `true` if all requested scopes are permitted | 所有请求的权限范围都被允许时返回 `true`
746 /// * `false` if any requested scope is not permitted | 任何请求的权限范围不被允许时返回 `false`
747 pub fn validate_scope(&self, client: &OAuth2Client, requested_scope: &[String]) -> bool {
748 requested_scope.iter().all(|s| client.scope.contains(s))
749 }
750}
751
752#[cfg(test)]
753mod tests {
754 use super::*;
755 use sa_token_storage_memory::MemoryStorage;
756
757 #[tokio::test]
758 async fn test_oauth2_authorization_code_flow() {
759 let storage = Arc::new(MemoryStorage::new());
760 let oauth2 = OAuth2Manager::new(storage);
761
762 let client = OAuth2Client {
763 client_id: "test_client".to_string(),
764 client_secret: "test_secret".to_string(),
765 redirect_uris: vec!["http://localhost:3000/callback".to_string()],
766 grant_types: vec!["authorization_code".to_string()],
767 scope: vec!["read".to_string(), "write".to_string()],
768 };
769
770 oauth2.register_client(&client).await.unwrap();
771
772 let auth_code = oauth2.generate_authorization_code(
773 "test_client".to_string(),
774 "user_123".to_string(),
775 "http://localhost:3000/callback".to_string(),
776 vec!["read".to_string()],
777 );
778
779 oauth2.store_authorization_code(&auth_code).await.unwrap();
780
781 let token = oauth2.exchange_code_for_token(
782 &auth_code.code,
783 "test_client",
784 "test_secret",
785 "http://localhost:3000/callback",
786 ).await.unwrap();
787
788 assert_eq!(token.token_type, "Bearer");
789 assert!(token.refresh_token.is_some());
790
791 let token_info = oauth2.verify_access_token(&token.access_token).await.unwrap();
792 assert_eq!(token_info.user_id, "user_123");
793 assert_eq!(token_info.client_id, "test_client");
794 }
795
796 #[tokio::test]
797 async fn test_refresh_token() {
798 let storage = Arc::new(MemoryStorage::new());
799 let oauth2 = OAuth2Manager::new(storage);
800
801 let client = OAuth2Client {
802 client_id: "test_client".to_string(),
803 client_secret: "test_secret".to_string(),
804 redirect_uris: vec!["http://localhost:3000/callback".to_string()],
805 grant_types: vec!["authorization_code".to_string(), "refresh_token".to_string()],
806 scope: vec!["read".to_string()],
807 };
808
809 oauth2.register_client(&client).await.unwrap();
810
811 let token = oauth2.generate_access_token(
812 "test_client",
813 "user_123",
814 vec!["read".to_string()],
815 ).await.unwrap();
816
817 let refresh_token = token.refresh_token.as_ref().unwrap();
818 let new_token = oauth2.refresh_access_token(
819 refresh_token,
820 "test_client",
821 "test_secret",
822 ).await.unwrap();
823
824 assert_ne!(new_token.access_token, token.access_token);
825 }
826}
827