sa_token_core/oauth2.rs
1//! OAuth2 Authorization Code Flow Implementation | OAuth2 授权码模式实现
2//!
3//! ## Code Flow Logic | 代码流程逻辑
4//!
5//! ### Overall Architecture | 整体架构
6//!
7//! ```text
8//! ┌─────────────────────────────────────────────────────────────────┐
9//! │ OAuth2Manager │
10//! │ OAuth2 管理器核心 │
11//! ├─────────────────────────────────────────────────────────────────┤
12//! │ │
13//! │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
14//! │ │ Client Mgmt │ │ Auth Code │ │ Token Mgmt │ │
15//! │ │ 客户端管理 │ │ 授权码管理 │ │ 令牌管理 │ │
16//! │ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
17//! │ │ │ │ │
18//! │ ▼ ▼ ▼ │
19//! │ ┌────────────────────────────────────────────────┐ │
20//! │ │ Storage Backend (SaStorage) │ │
21//! │ │ 存储后端(Memory/Redis/Database) │ │
22//! │ └────────────────────────────────────────────────┘ │
23//! └─────────────────────────────────────────────────────────────────┘
24//! ```
25//!
26//! ### Core Processes | 核心流程
27//!
28//! #### 1. Authorization Code Flow | 授权码流程
29//!
30//! ```text
31//! User/Client OAuth2Manager Storage
32//! 用户/客户端 OAuth2管理器 存储
33//! │ │ │
34//! │ register_client() │ │
35//! │─────────────────────▶│ │
36//! │ │ store client │
37//! │ │─────────────────────▶│
38//! │ │ │
39//! │ authorize request │ │
40//! │─────────────────────▶│ │
41//! │ │ │
42//! │ generate_auth_code()│ │
43//! │ │ validate redirect │
44//! │ │ validate scope │
45//! │ │ create code │
46//! │ │ │
47//! │ │ store code (TTL) │
48//! │ │─────────────────────▶│
49//! │ │ │
50//! │ return code │ │
51//! │◀─────────────────────│ │
52//! │ │ │
53//! │ exchange_code() │ │
54//! │─────────────────────▶│ │
55//! │ │ verify client │
56//! │ │ consume code │
57//! │ │─────────────────────▶│
58//! │ │ delete code │
59//! │ │ │
60//! │ │ generate tokens │
61//! │ │ - access_token │
62//! │ │ - refresh_token │
63//! │ │ │
64//! │ │ store tokens (TTL) │
65//! │ │─────────────────────▶│
66//! │ │ │
67//! │ return tokens │ │
68//! │◀─────────────────────│ │
69//! │ │ │
70//! ```
71//!
72//! #### 2. Token Refresh Flow | 令牌刷新流程
73//!
74//! ```text
75//! Client OAuth2Manager Storage
76//! 客户端 OAuth2管理器 存储
77//! │ │ │
78//! │ refresh_token() │ │
79//! │─────────────────────▶│ │
80//! │ │ verify client │
81//! │ │ credentials │
82//! │ │ │
83//! │ │ get refresh_token │
84//! │ │─────────────────────▶│
85//! │ │ return data │
86//! │ │◀─────────────────────│
87//! │ │ │
88//! │ │ validate client_id │
89//! │ │ validate not expired│
90//! │ │ │
91//! │ │ generate new tokens │
92//! │ │ - new access_token │
93//! │ │ - new refresh_token │
94//! │ │ │
95//! │ │ store new tokens │
96//! │ │─────────────────────▶│
97//! │ │ │
98//! │ return new tokens │ │
99//! │◀─────────────────────│ │
100//! │ │ │
101//! ```
102//!
103//! ### Storage Keys | 存储键格式
104//!
105//! ```text
106//! oauth2:client:{client_id} - Client information | 客户端信息
107//! oauth2:code:{authorization_code} - Authorization code | 授权码 (TTL: 10 min)
108//! oauth2:token:{access_token} - Token info | 令牌信息 (TTL: 1 hour)
109//! oauth2:refresh:{refresh_token} - Refresh token | 刷新令牌 (TTL: 30 days)
110//! ```
111//!
112//! ### Security Validations | 安全验证
113//!
114//! ```text
115//! ┌────────────────────────────────────────────────────────┐
116//! │ 1. Client Verification | 客户端验证 │
117//! │ - client_id + client_secret match | 凭据匹配 │
118//! │ - client exists in registry | 客户端已注册 │
119//! ├────────────────────────────────────────────────────────┤
120//! │ 2. Redirect URI Validation | 回调URI验证 │
121//! │ - URI in whitelist | URI在白名单中 │
122//! │ - Exact match (no wildcards) | 精确匹配 │
123//! ├────────────────────────────────────────────────────────┤
124//! │ 3. Scope Validation | 权限范围验证 │
125//! │ - Requested scopes ⊆ client scopes | 请求范围子集 │
126//! │ - All scopes valid | 所有范围合法 │
127//! ├────────────────────────────────────────────────────────┤
128//! │ 4. Code Validation | 授权码验证 │
129//! │ - Code exists | 授权码存在 │
130//! │ - Not expired | 未过期 │
131//! │ - Single use (consumed after exchange) | 单次使用 │
132//! │ - Client ID match | 客户端ID匹配 │
133//! │ - Redirect URI match | 回调URI匹配 │
134//! ├────────────────────────────────────────────────────────┤
135//! │ 5. Token Validation | 令牌验证 │
136//! │ - Token exists | 令牌存在 │
137//! │ - Not expired | 未过期 │
138//! │ - Signature valid | 签名有效 │
139//! └────────────────────────────────────────────────────────┘
140//! ```
141//!
142//! ### Performance Considerations | 性能考虑
143//!
144//! 1. **Async Operations | 异步操作**
145//! - All storage operations are async | 所有存储操作异步
146//! - Non-blocking I/O | 非阻塞IO
147//!
148//! 2. **TTL Management | TTL管理**
149//! - Storage-level expiration | 存储层级过期
150//! - Automatic cleanup | 自动清理
151//! - No manual garbage collection | 无需手动垃圾回收
152//!
153//! 3. **Code Consumption | 授权码消费**
154//! - Read + Delete in one flow | 读取和删除一次性
155//! - Prevents replay attacks | 防止重放攻击
156//!
157//! ### Error Handling | 错误处理
158//!
159//! ```text
160//! Error Type When It Occurs | 发生时机
161//! ────────────────────────────────────────────────────────
162//! InvalidToken - Invalid client credentials | 无效客户端凭据
163//! - Code not found | 授权码不存在
164//! - Client ID mismatch | 客户端ID不匹配
165//! - Redirect URI mismatch | 回调URI不匹配
166//!
167//! TokenExpired - Authorization code expired | 授权码过期
168//! - Access token expired | 访问令牌过期
169//! - Refresh token expired | 刷新令牌过期
170//!
171//! StorageError - Storage operation failed | 存储操作失败
172//! - Network error | 网络错误
173//!
174//! SerializationError - JSON encode/decode failed | JSON序列化失败
175//! ```
176
177use std::sync::Arc;
178use chrono::{DateTime, Utc, Duration};
179use serde::{Deserialize, Serialize};
180use uuid::Uuid;
181use sa_token_adapter::storage::SaStorage;
182use crate::error::{SaTokenError, SaTokenResult};
183
184/// OAuth2 Client Information | OAuth2 客户端信息
185///
186/// Represents a registered OAuth2 client application with its credentials and configuration.
187/// 表示一个已注册的 OAuth2 客户端应用程序及其凭据和配置。
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct OAuth2Client {
190 /// Unique identifier for the client | 客户端的唯一标识符
191 pub client_id: String,
192
193 /// Secret key for client authentication | 客户端认证的密钥
194 pub client_secret: String,
195
196 /// Allowed redirect URIs (whitelist) | 允许的回调 URI(白名单)
197 pub redirect_uris: Vec<String>,
198
199 /// Supported grant types (e.g., "authorization_code", "refresh_token")
200 /// 支持的授权类型(例如:"authorization_code"、"refresh_token")
201 pub grant_types: Vec<String>,
202
203 /// Permitted scopes for this client | 此客户端允许的权限范围
204 pub scope: Vec<String>,
205}
206
207/// Authorization Code | 授权码
208///
209/// Temporary code issued after user authorization, exchanged for access token.
210/// 用户授权后颁发的临时代码,用于交换访问令牌。
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct AuthorizationCode {
213 /// The authorization code value | 授权码的值
214 pub code: String,
215
216 /// Client ID that requested the code | 请求授权码的客户端 ID
217 pub client_id: String,
218
219 /// User ID who authorized | 授权的用户 ID
220 pub user_id: String,
221
222 /// Redirect URI used in authorization request | 授权请求中使用的回调 URI
223 pub redirect_uri: String,
224
225 /// Granted scopes | 授予的权限范围
226 pub scope: Vec<String>,
227
228 /// Code creation timestamp | 授权码创建时间戳
229 pub created_at: DateTime<Utc>,
230
231 /// Code expiration timestamp | 授权码过期时间戳
232 pub expires_at: DateTime<Utc>,
233}
234
235/// Access Token Response | 访问令牌响应
236///
237/// Token response returned to the client after successful authorization.
238/// 成功授权后返回给客户端的令牌响应。
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct AccessToken {
241 /// The access token value | 访问令牌的值
242 pub access_token: String,
243
244 /// Token type (typically "Bearer") | 令牌类型(通常为 "Bearer")
245 pub token_type: String,
246
247 /// Token lifetime in seconds | 令牌有效期(秒)
248 pub expires_in: i64,
249
250 /// Optional refresh token for token renewal | 可选的刷新令牌,用于令牌续期
251 pub refresh_token: Option<String>,
252
253 /// Granted scopes | 授予的权限范围
254 pub scope: Vec<String>,
255}
256
257/// OAuth2 Token Information (for storage) | OAuth2 令牌信息(用于存储)
258///
259/// Internal structure for storing token details in the backend.
260/// 用于在后端存储令牌详细信息的内部结构。
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct OAuth2TokenInfo {
263 /// Access token value | 访问令牌值
264 pub access_token: String,
265
266 /// Client ID that owns this token | 拥有此令牌的客户端 ID
267 pub client_id: String,
268
269 /// User ID associated with this token | 与此令牌关联的用户 ID
270 pub user_id: String,
271
272 /// Granted scopes | 授予的权限范围
273 pub scope: Vec<String>,
274
275 /// Token creation timestamp | 令牌创建时间戳
276 pub created_at: DateTime<Utc>,
277
278 /// Token expiration timestamp | 令牌过期时间戳
279 pub expires_at: DateTime<Utc>,
280
281 /// Optional refresh token | 可选的刷新令牌
282 pub refresh_token: Option<String>,
283}
284
285/// OAuth2 Manager | OAuth2 管理器
286///
287/// Core manager for OAuth2 authorization code flow operations.
288/// OAuth2 授权码模式的核心管理器。
289///
290/// # Responsibilities | 职责
291/// - Client registration and verification | 客户端注册和验证
292/// - Authorization code generation and validation | 授权码生成和验证
293/// - Access token issuance and verification | 访问令牌颁发和验证
294/// - Refresh token management | 刷新令牌管理
295/// - Security validation (redirect URI, scope, etc.) | 安全验证(回调 URI、权限范围等)
296pub struct OAuth2Manager {
297 /// Storage backend for tokens and clients | 令牌和客户端的存储后端
298 storage: Arc<dyn SaStorage>,
299
300 /// Authorization code TTL in seconds (default: 600 = 10 minutes)
301 /// 授权码有效期(秒)(默认:600 = 10 分钟)
302 code_ttl: i64,
303
304 /// Access token TTL in seconds (default: 3600 = 1 hour)
305 /// 访问令牌有效期(秒)(默认:3600 = 1 小时)
306 token_ttl: i64,
307
308 /// Refresh token TTL in seconds (default: 2592000 = 30 days)
309 /// 刷新令牌有效期(秒)(默认:2592000 = 30 天)
310 refresh_token_ttl: i64,
311}
312
313impl OAuth2Manager {
314 /// Create a new OAuth2Manager with default TTL values
315 /// 使用默认 TTL 值创建新的 OAuth2Manager
316 ///
317 /// # Default TTL | 默认 TTL
318 /// - Authorization code: 600 seconds (10 minutes) | 授权码:600 秒(10 分钟)
319 /// - Access token: 3600 seconds (1 hour) | 访问令牌:3600 秒(1 小时)
320 /// - Refresh token: 2592000 seconds (30 days) | 刷新令牌:2592000 秒(30 天)
321 ///
322 /// # Arguments | 参数
323 /// * `storage` - Storage backend for persistence | 用于持久化的存储后端
324 pub fn new(storage: Arc<dyn SaStorage>) -> Self {
325 Self {
326 storage,
327 code_ttl: 600, // 10 minutes
328 token_ttl: 3600, // 1 hour
329 refresh_token_ttl: 2592000, // 30 days
330 }
331 }
332
333 /// Set custom TTL values for codes and tokens
334 /// 设置授权码和令牌的自定义 TTL 值
335 ///
336 /// # Arguments | 参数
337 /// * `code_ttl` - Authorization code TTL in seconds | 授权码 TTL(秒)
338 /// * `token_ttl` - Access token TTL in seconds | 访问令牌 TTL(秒)
339 /// * `refresh_token_ttl` - Refresh token TTL in seconds | 刷新令牌 TTL(秒)
340 ///
341 /// # Example | 示例
342 /// ```ignore
343 /// let oauth2 = OAuth2Manager::new(storage)
344 /// .with_ttl(300, 1800, 604800); // 5min, 30min, 7days
345 /// ```
346 pub fn with_ttl(mut self, code_ttl: i64, token_ttl: i64, refresh_token_ttl: i64) -> Self {
347 self.code_ttl = code_ttl;
348 self.token_ttl = token_ttl;
349 self.refresh_token_ttl = refresh_token_ttl;
350 self
351 }
352
353 /// Register a new OAuth2 client | 注册新的 OAuth2 客户端
354 ///
355 /// Stores client information in the backend for future authentication.
356 /// 将客户端信息存储在后端,用于未来的认证。
357 ///
358 /// # Arguments | 参数
359 /// * `client` - Client information to register | 要注册的客户端信息
360 ///
361 /// # Returns | 返回
362 /// * `Ok(())` on success | 成功时返回 `Ok(())`
363 /// * `Err(SaTokenError)` on storage or serialization error | 存储或序列化错误时返回错误
364 ///
365 /// # Example | 示例
366 /// ```ignore
367 /// let client = OAuth2Client {
368 /// client_id: "app_001".to_string(),
369 /// client_secret: "secret".to_string(),
370 /// redirect_uris: vec!["http://localhost/callback".to_string()],
371 /// grant_types: vec!["authorization_code".to_string()],
372 /// scope: vec!["read".to_string(), "write".to_string()],
373 /// };
374 /// oauth2.register_client(&client).await?;
375 /// ```
376 pub async fn register_client(&self, client: &OAuth2Client) -> SaTokenResult<()> {
377 let key = format!("oauth2:client:{}", client.client_id);
378 let value = serde_json::to_string(client)
379 .map_err(|e| SaTokenError::SerializationError(e))?;
380
381 self.storage.set(&key, &value, None).await
382 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
383
384 Ok(())
385 }
386
387 /// Retrieve client information by client ID | 通过客户端 ID 检索客户端信息
388 ///
389 /// # Arguments | 参数
390 /// * `client_id` - Client identifier | 客户端标识符
391 ///
392 /// # Returns | 返回
393 /// * `Ok(OAuth2Client)` if found | 找到时返回客户端信息
394 /// * `Err(OAuth2ClientNotFound)` if client doesn't exist | 客户端不存在时返回错误
395 pub async fn get_client(&self, client_id: &str) -> SaTokenResult<OAuth2Client> {
396 let key = format!("oauth2:client:{}", client_id);
397 let value = self.storage.get(&key).await
398 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
399 .ok_or_else(|| SaTokenError::OAuth2ClientNotFound)?;
400
401 serde_json::from_str(&value)
402 .map_err(|e| SaTokenError::SerializationError(e))
403 }
404
405 /// Verify client credentials | 验证客户端凭据
406 ///
407 /// Checks if the provided client_id and client_secret match.
408 /// 检查提供的 client_id 和 client_secret 是否匹配。
409 ///
410 /// # Arguments | 参数
411 /// * `client_id` - Client identifier | 客户端标识符
412 /// * `client_secret` - Client secret key | 客户端密钥
413 ///
414 /// # Returns | 返回
415 /// * `Ok(true)` if credentials are valid | 凭据有效时返回 `true`
416 /// * `Ok(false)` if credentials are invalid | 凭据无效时返回 `false`
417 /// * `Err(OAuth2ClientNotFound)` if client doesn't exist | 客户端不存在时返回错误
418 pub async fn verify_client(&self, client_id: &str, client_secret: &str) -> SaTokenResult<bool> {
419 let client = self.get_client(client_id).await?;
420 Ok(client.client_secret == client_secret)
421 }
422
423 /// Generate a new authorization code | 生成新的授权码
424 ///
425 /// Creates a temporary authorization code after user consent.
426 /// 在用户同意后创建临时授权码。
427 ///
428 /// # Arguments | 参数
429 /// * `client_id` - Client requesting authorization | 请求授权的客户端
430 /// * `user_id` - User granting authorization | 授予授权的用户
431 /// * `redirect_uri` - Callback URI for this authorization | 此授权的回调 URI
432 /// * `scope` - Granted permissions | 授予的权限
433 ///
434 /// # Returns | 返回
435 /// * `AuthorizationCode` with unique code and expiration | 带有唯一代码和过期时间的授权码
436 ///
437 /// # Note | 注意
438 /// This code must be stored using `store_authorization_code()` before returning to client.
439 /// 此代码必须在返回给客户端之前使用 `store_authorization_code()` 存储。
440 pub fn generate_authorization_code(
441 &self,
442 client_id: String,
443 user_id: String,
444 redirect_uri: String,
445 scope: Vec<String>,
446 ) -> AuthorizationCode {
447 let now = Utc::now();
448 let code = format!("code_{}", Uuid::new_v4().simple());
449
450 AuthorizationCode {
451 code,
452 client_id,
453 user_id,
454 redirect_uri,
455 scope,
456 created_at: now,
457 expires_at: now + Duration::seconds(self.code_ttl),
458 }
459 }
460
461 /// Store authorization code in backend | 在后端存储授权码
462 ///
463 /// Persists the authorization code with TTL for later exchange.
464 /// 使用 TTL 持久化授权码,以便稍后交换。
465 ///
466 /// # Arguments | 参数
467 /// * `auth_code` - Authorization code to store | 要存储的授权码
468 ///
469 /// # Storage Key Format | 存储键格式
470 /// `oauth2:code:{authorization_code}`
471 pub async fn store_authorization_code(&self, auth_code: &AuthorizationCode) -> SaTokenResult<()> {
472 let key = format!("oauth2:code:{}", auth_code.code);
473 let value = serde_json::to_string(auth_code)
474 .map_err(|e| SaTokenError::SerializationError(e))?;
475
476 let ttl = Some(std::time::Duration::from_secs(self.code_ttl as u64));
477 self.storage.set(&key, &value, ttl).await
478 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
479
480 Ok(())
481 }
482
483 /// Retrieve authorization code information | 获取授权码信息
484 ///
485 /// Fetches the stored authorization code and validates its expiration.
486 /// 获取存储的授权码并验证其过期状态。
487 ///
488 /// # Arguments | 参数
489 /// * `code` - Authorization code to retrieve | 要获取的授权码
490 ///
491 /// # Returns | 返回
492 /// * `Ok(AuthorizationCode)` if code exists and is valid | 授权码存在且有效时返回
493 /// * `Err(OAuth2CodeNotFound)` if code not found | 授权码未找到时
494 /// * `Err(TokenExpired)` if code has expired | 授权码已过期时
495 ///
496 /// # Note | 注意
497 /// Expired codes are automatically cleaned up from storage.
498 /// 过期的授权码会自动从存储中清理。
499 pub async fn get_authorization_code(&self, code: &str) -> SaTokenResult<AuthorizationCode> {
500 let key = format!("oauth2:code:{}", code);
501 let value = self.storage.get(&key).await
502 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
503 .ok_or_else(|| SaTokenError::OAuth2CodeNotFound)?;
504
505 let auth_code: AuthorizationCode = serde_json::from_str(&value)
506 .map_err(|e| SaTokenError::SerializationError(e))?;
507
508 // Check expiration and auto-cleanup if expired
509 if Utc::now() > auth_code.expires_at {
510 self.storage.delete(&key).await.ok();
511 return Err(SaTokenError::TokenExpired);
512 }
513
514 Ok(auth_code)
515 }
516
517 /// Consume authorization code (one-time use) | 消费授权码(一次性使用)
518 ///
519 /// Retrieves and deletes the authorization code in one operation.
520 /// 在一次操作中检索并删除授权码。
521 ///
522 /// # Arguments | 参数
523 /// * `code` - Authorization code to consume | 要消费的授权码
524 ///
525 /// # Returns | 返回
526 /// * `Ok(AuthorizationCode)` if code is valid and consumed | 授权码有效且已消费时返回
527 /// * `Err(OAuth2CodeNotFound)` if code not found | 授权码未找到时
528 /// * `Err(TokenExpired)` if code has expired | 授权码已过期时
529 ///
530 /// # Security | 安全性
531 /// This ensures the code can only be used once, preventing replay attacks.
532 /// 这确保授权码只能使用一次,防止重放攻击。
533 pub async fn consume_authorization_code(&self, code: &str) -> SaTokenResult<AuthorizationCode> {
534 // First, get and validate the code
535 let auth_code = self.get_authorization_code(code).await?;
536
537 // Then delete it (consume it)
538 let key = format!("oauth2:code:{}", code);
539 self.storage.delete(&key).await
540 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
541
542 Ok(auth_code)
543 }
544
545 /// Exchange authorization code for access token | 用授权码换取访问令牌
546 ///
547 /// Core of the authorization code flow. Validates the code and issues tokens.
548 /// 授权码流程的核心。验证授权码并颁发令牌。
549 ///
550 /// # Validations | 验证
551 /// 1. Client credentials (client_id + client_secret) | 客户端凭据
552 /// 2. Authorization code exists and not expired | 授权码存在且未过期
553 /// 3. Client ID matches the code | 客户端 ID 与授权码匹配
554 /// 4. Redirect URI matches the code | 回调 URI 与授权码匹配
555 ///
556 /// # Arguments | 参数
557 /// * `code` - Authorization code | 授权码
558 /// * `client_id` - Client identifier | 客户端标识符
559 /// * `client_secret` - Client secret | 客户端密钥
560 /// * `redirect_uri` - Redirect URI used in authorization | 授权时使用的回调 URI
561 ///
562 /// # Returns | 返回
563 /// * `Ok(AccessToken)` with access_token and optional refresh_token | 带有访问令牌和可选刷新令牌
564 /// * `Err(OAuth2InvalidCredentials)` if client credentials invalid | 客户端凭据无效时
565 /// * `Err(OAuth2CodeNotFound)` if code not found or expired | 授权码未找到或已过期时
566 /// * `Err(OAuth2ClientIdMismatch)` if client ID doesn't match | 客户端 ID 不匹配时
567 /// * `Err(OAuth2RedirectUriMismatch)` if redirect URI doesn't match | 回调 URI 不匹配时
568 ///
569 /// # Security | 安全性
570 /// The authorization code is consumed (deleted) after use to prevent replay attacks.
571 /// 授权码在使用后被消费(删除),以防止重放攻击。
572 pub async fn exchange_code_for_token(
573 &self,
574 code: &str,
575 client_id: &str,
576 client_secret: &str,
577 redirect_uri: &str,
578 ) -> SaTokenResult<AccessToken> {
579 // 1. Verify client credentials
580 if !self.verify_client(client_id, client_secret).await? {
581 return Err(SaTokenError::OAuth2InvalidCredentials);
582 }
583
584 // 2. Consume the authorization code (one-time use)
585 let auth_code = self.consume_authorization_code(code).await?;
586
587 // 3. Validate client ID matches
588 if auth_code.client_id != client_id {
589 return Err(SaTokenError::OAuth2ClientIdMismatch);
590 }
591
592 // 4. Validate redirect URI matches
593 if auth_code.redirect_uri != redirect_uri {
594 return Err(SaTokenError::OAuth2RedirectUriMismatch);
595 }
596
597 // 5. Generate and return access token
598 self.generate_access_token(&auth_code.client_id, &auth_code.user_id, auth_code.scope).await
599 }
600
601 /// Generate access token and refresh token | 生成访问令牌和刷新令牌
602 ///
603 /// Creates a new access token with an optional refresh token for the user.
604 /// 为用户创建新的访问令牌和可选的刷新令牌。
605 ///
606 /// # Arguments | 参数
607 /// * `client_id` - Client identifier | 客户端标识符
608 /// * `user_id` - User identifier | 用户标识符
609 /// * `scope` - Granted permissions | 授予的权限范围
610 ///
611 /// # Returns | 返回
612 /// * `Ok(AccessToken)` with access_token and refresh_token | 带有访问令牌和刷新令牌
613 ///
614 /// # Storage | 存储
615 /// - Access token: `oauth2:token:{access_token}` (TTL: token_ttl)
616 /// - Refresh token: `oauth2:refresh:{refresh_token}` (TTL: refresh_token_ttl)
617 ///
618 /// # Note | 注意
619 /// Both tokens are stored with TTL for automatic expiration cleanup.
620 /// 两个令牌都使用 TTL 存储,以便自动清理过期令牌。
621 pub async fn generate_access_token(
622 &self,
623 client_id: &str,
624 user_id: &str,
625 scope: Vec<String>,
626 ) -> SaTokenResult<AccessToken> {
627 let now = Utc::now();
628 let access_token = format!("at_{}", Uuid::new_v4().simple());
629 let refresh_token = format!("rt_{}", Uuid::new_v4().simple());
630
631 // Create token info for storage
632 let token_info = OAuth2TokenInfo {
633 access_token: access_token.clone(),
634 client_id: client_id.to_string(),
635 user_id: user_id.to_string(),
636 scope: scope.clone(),
637 created_at: now,
638 expires_at: now + Duration::seconds(self.token_ttl),
639 refresh_token: Some(refresh_token.clone()),
640 };
641
642 // Store access token with TTL
643 let key = format!("oauth2:token:{}", access_token);
644 let value = serde_json::to_string(&token_info)
645 .map_err(|e| SaTokenError::SerializationError(e))?;
646
647 let ttl = Some(std::time::Duration::from_secs(self.token_ttl as u64));
648 self.storage.set(&key, &value, ttl).await
649 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
650
651 // Store refresh token with longer TTL
652 let refresh_key = format!("oauth2:refresh:{}", refresh_token);
653 let refresh_value = serde_json::json!({
654 "user_id": user_id,
655 "client_id": client_id,
656 "scope": scope,
657 }).to_string();
658
659 let refresh_ttl = Some(std::time::Duration::from_secs(self.refresh_token_ttl as u64));
660 self.storage.set(&refresh_key, &refresh_value, refresh_ttl).await
661 .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
662
663 // Return the access token response
664 Ok(AccessToken {
665 access_token,
666 token_type: "Bearer".to_string(),
667 expires_in: self.token_ttl,
668 refresh_token: Some(refresh_token),
669 scope,
670 })
671 }
672
673 /// Verify access token and retrieve token information | 验证访问令牌并检索令牌信息
674 ///
675 /// Checks if the access token is valid and not expired.
676 /// 检查访问令牌是否有效且未过期。
677 ///
678 /// # Arguments | 参数
679 /// * `access_token` - Access token to verify | 要验证的访问令牌
680 ///
681 /// # Returns | 返回
682 /// * `Ok(OAuth2TokenInfo)` if token is valid | 令牌有效时返回令牌信息
683 /// * `Err(OAuth2AccessTokenNotFound)` if token not found | 令牌未找到时
684 /// * `Err(TokenExpired)` if token has expired | 令牌已过期时
685 ///
686 /// # Note | 注意
687 /// Expired tokens are automatically cleaned up from storage.
688 /// 过期的令牌会自动从存储中清理。
689 pub async fn verify_access_token(&self, access_token: &str) -> SaTokenResult<OAuth2TokenInfo> {
690 let key = format!("oauth2:token:{}", access_token);
691 let value = self.storage.get(&key).await
692 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
693 .ok_or_else(|| SaTokenError::OAuth2AccessTokenNotFound)?;
694
695 let token_info: OAuth2TokenInfo = serde_json::from_str(&value)
696 .map_err(|e| SaTokenError::SerializationError(e))?;
697
698 // Check expiration and auto-cleanup if expired
699 if Utc::now() > token_info.expires_at {
700 self.storage.delete(&key).await.ok();
701 return Err(SaTokenError::TokenExpired);
702 }
703
704 Ok(token_info)
705 }
706
707 /// Refresh access token using refresh token | 使用刷新令牌刷新访问令牌
708 ///
709 /// Issues a new access token (and optionally a new refresh token) when the old one expires.
710 /// 当旧令牌过期时颁发新的访问令牌(以及可选的新刷新令牌)。
711 ///
712 /// # Validations | 验证
713 /// 1. Client credentials are valid | 客户端凭据有效
714 /// 2. Refresh token exists and belongs to the client | 刷新令牌存在且属于该客户端
715 /// 3. Client ID matches the refresh token | 客户端 ID 与刷新令牌匹配
716 ///
717 /// # Arguments | 参数
718 /// * `refresh_token` - Refresh token | 刷新令牌
719 /// * `client_id` - Client identifier | 客户端标识符
720 /// * `client_secret` - Client secret | 客户端密钥
721 ///
722 /// # Returns | 返回
723 /// * `Ok(AccessToken)` with new access_token and refresh_token | 新的访问令牌和刷新令牌
724 /// * `Err(OAuth2InvalidCredentials)` if credentials invalid | 凭据无效时
725 /// * `Err(OAuth2RefreshTokenNotFound)` if refresh token not found | 刷新令牌未找到时
726 /// * `Err(OAuth2ClientIdMismatch)` if client ID doesn't match | 客户端 ID 不匹配时
727 pub async fn refresh_access_token(
728 &self,
729 refresh_token: &str,
730 client_id: &str,
731 client_secret: &str,
732 ) -> SaTokenResult<AccessToken> {
733 // 1. Verify client credentials
734 if !self.verify_client(client_id, client_secret).await? {
735 return Err(SaTokenError::OAuth2InvalidCredentials);
736 }
737
738 // 2. Get refresh token data from storage
739 let key = format!("oauth2:refresh:{}", refresh_token);
740 let value = self.storage.get(&key).await
741 .map_err(|e| SaTokenError::StorageError(e.to_string()))?
742 .ok_or_else(|| SaTokenError::OAuth2RefreshTokenNotFound)?;
743
744 let data: serde_json::Value = serde_json::from_str(&value)
745 .map_err(|e| SaTokenError::SerializationError(e))?;
746
747 // 3. Validate client ID matches
748 let stored_client_id = data["client_id"].as_str()
749 .ok_or_else(|| SaTokenError::OAuth2InvalidRefreshToken)?;
750
751 if stored_client_id != client_id {
752 return Err(SaTokenError::OAuth2ClientIdMismatch);
753 }
754
755 // 4. Extract user ID and scope
756 let user_id = data["user_id"].as_str()
757 .ok_or_else(|| SaTokenError::OAuth2InvalidRefreshToken)?;
758
759 let scope: Vec<String> = data["scope"].as_array()
760 .ok_or_else(|| SaTokenError::OAuth2InvalidScope)?
761 .iter()
762 .filter_map(|v| v.as_str().map(|s| s.to_string()))
763 .collect();
764
765 // 5. Generate new access token with same scope
766 self.generate_access_token(client_id, user_id, scope).await
767 }
768
769 /// Revoke an access token or refresh token | 撤销访问令牌或刷新令牌
770 ///
771 /// Deletes the token from storage, making it immediately invalid.
772 /// 从存储中删除令牌,使其立即失效。
773 ///
774 /// # Arguments | 参数
775 /// * `token` - Token to revoke (access or refresh) | 要撤销的令牌(访问或刷新)
776 ///
777 /// # Use Cases | 使用场景
778 /// - User logout | 用户登出
779 /// - Security breach | 安全漏洞
780 /// - Client revocation | 客户端撤销
781 pub async fn revoke_token(&self, token: &str) -> SaTokenResult<()> {
782 let access_key = format!("oauth2:token:{}", token);
783 let refresh_key = format!("oauth2:refresh:{}", token);
784
785 self.storage.delete(&access_key).await.ok();
786 self.storage.delete(&refresh_key).await.ok();
787
788 Ok(())
789 }
790
791 /// Validate redirect URI against client's whitelist | 根据客户端白名单验证回调 URI
792 ///
793 /// Security check to prevent redirect URI hijacking.
794 /// 安全检查以防止回调 URI 劫持。
795 ///
796 /// # Arguments | 参数
797 /// * `client` - Client information with registered URIs | 带有注册 URI 的客户端信息
798 /// * `redirect_uri` - URI to validate | 要验证的 URI
799 ///
800 /// # Returns | 返回
801 /// * `true` if URI is in the whitelist | URI 在白名单中时返回 `true`
802 /// * `false` if URI is not allowed | URI 不被允许时返回 `false`
803 pub fn validate_redirect_uri(&self, client: &OAuth2Client, redirect_uri: &str) -> bool {
804 client.redirect_uris.iter().any(|uri| uri == redirect_uri)
805 }
806
807 /// Validate requested scopes against client's permitted scopes | 根据客户端允许的范围验证请求的权限范围
808 ///
809 /// Ensures requested scopes are a subset of client's permitted scopes.
810 /// 确保请求的权限范围是客户端允许范围的子集。
811 ///
812 /// # Arguments | 参数
813 /// * `client` - Client information with permitted scopes | 带有允许权限范围的客户端信息
814 /// * `requested_scope` - Scopes being requested | 正在请求的权限范围
815 ///
816 /// # Returns | 返回
817 /// * `true` if all requested scopes are permitted | 所有请求的权限范围都被允许时返回 `true`
818 /// * `false` if any requested scope is not permitted | 任何请求的权限范围不被允许时返回 `false`
819 pub fn validate_scope(&self, client: &OAuth2Client, requested_scope: &[String]) -> bool {
820 requested_scope.iter().all(|s| client.scope.contains(s))
821 }
822}
823
824#[cfg(test)]
825mod tests {
826 use super::*;
827 use sa_token_storage_memory::MemoryStorage;
828
829 #[tokio::test]
830 async fn test_oauth2_authorization_code_flow() {
831 let storage = Arc::new(MemoryStorage::new());
832 let oauth2 = OAuth2Manager::new(storage);
833
834 let client = OAuth2Client {
835 client_id: "test_client".to_string(),
836 client_secret: "test_secret".to_string(),
837 redirect_uris: vec!["http://localhost:3000/callback".to_string()],
838 grant_types: vec!["authorization_code".to_string()],
839 scope: vec!["read".to_string(), "write".to_string()],
840 };
841
842 oauth2.register_client(&client).await.unwrap();
843
844 let auth_code = oauth2.generate_authorization_code(
845 "test_client".to_string(),
846 "user_123".to_string(),
847 "http://localhost:3000/callback".to_string(),
848 vec!["read".to_string()],
849 );
850
851 oauth2.store_authorization_code(&auth_code).await.unwrap();
852
853 let token = oauth2.exchange_code_for_token(
854 &auth_code.code,
855 "test_client",
856 "test_secret",
857 "http://localhost:3000/callback",
858 ).await.unwrap();
859
860 assert_eq!(token.token_type, "Bearer");
861 assert!(token.refresh_token.is_some());
862
863 let token_info = oauth2.verify_access_token(&token.access_token).await.unwrap();
864 assert_eq!(token_info.user_id, "user_123");
865 assert_eq!(token_info.client_id, "test_client");
866 }
867
868 #[tokio::test]
869 async fn test_refresh_token() {
870 let storage = Arc::new(MemoryStorage::new());
871 let oauth2 = OAuth2Manager::new(storage);
872
873 let client = OAuth2Client {
874 client_id: "test_client".to_string(),
875 client_secret: "test_secret".to_string(),
876 redirect_uris: vec!["http://localhost:3000/callback".to_string()],
877 grant_types: vec!["authorization_code".to_string(), "refresh_token".to_string()],
878 scope: vec!["read".to_string()],
879 };
880
881 oauth2.register_client(&client).await.unwrap();
882
883 let token = oauth2.generate_access_token(
884 "test_client",
885 "user_123",
886 vec!["read".to_string()],
887 ).await.unwrap();
888
889 let refresh_token = token.refresh_token.as_ref().unwrap();
890 let new_token = oauth2.refresh_access_token(
891 refresh_token,
892 "test_client",
893 "test_secret",
894 ).await.unwrap();
895
896 assert_ne!(new_token.access_token, token.access_token);
897 }
898}
899