sa_token_core/
ws.rs

1//! WebSocket Authentication Module | WebSocket 认证模块
2//!
3//! # Code Flow Logic | 代码流程逻辑
4//!
5//! ## English
6//! 
7//! ### Overview
8//! This module provides WebSocket authentication capabilities for sa-token-rust.
9//! It handles token extraction from various sources (headers, query parameters)
10//! and validates them against the token manager.
11//! 
12//! ### Authentication Flow
13//! ```text
14//! 1. WebSocket Connection Request
15//!    ↓
16//! 2. WsAuthManager.authenticate(headers, query)
17//!    ↓
18//! 3. WsTokenExtractor.extract_token()
19//!    ├─→ Check Authorization Header (Bearer Token)
20//!    ├─→ Check Sec-WebSocket-Protocol Header
21//!    └─→ Check Query Parameter (?token=xxx)
22//!    ↓
23//! 4. Found Token → Create TokenValue
24//!    ↓
25//! 5. SaTokenManager.get_token_info(token)
26//!    ↓
27//! 6. Validate Token Expiration
28//!    ├─→ Expired → Return TokenExpired Error
29//!    └─→ Valid → Continue
30//!    ↓
31//! 7. Generate WebSocket Session ID
32//!    Format: ws:{login_id}:{uuid}
33//!    ↓
34//! 8. Create WsAuthInfo
35//!    - login_id: User identifier
36//!    - token: Original token string
37//!    - session_id: Unique WebSocket session ID
38//!    - connect_time: Connection timestamp
39//!    - metadata: Custom key-value data
40//!    ↓
41//! 9. Publish Login Event
42//!    SaTokenEvent::login(login_id, token)
43//!    └─→ Mark as "websocket" login type
44//!    └─→ Trigger all registered event listeners
45//!    ↓
46//! 10. Return WsAuthInfo
47//! ```
48//! 
49//! ### Token Extraction Priority
50//! 1. Authorization Header: `Bearer {token}`
51//! 2. Sec-WebSocket-Protocol Header: `{token}`
52//! 3. Query Parameter: `?token={token}`
53//! 
54//! ### Extension Points
55//! - Custom WsTokenExtractor: Implement your own token extraction logic
56//! - WsAuthInfo.metadata: Store custom connection data
57//!
58//! ## 中文
59//! 
60//! ### 概述
61//! 本模块为 sa-token-rust 提供 WebSocket 认证功能。
62//! 它负责从多种来源(请求头、查询参数)提取 Token 并通过 Token 管理器进行验证。
63//! 
64//! ### 认证流程
65//! ```text
66//! 1. WebSocket 连接请求
67//!    ↓
68//! 2. WsAuthManager.authenticate(headers, query)
69//!    ↓
70//! 3. WsTokenExtractor.extract_token()
71//!    ├─→ 检查 Authorization 请求头 (Bearer Token)
72//!    ├─→ 检查 Sec-WebSocket-Protocol 请求头
73//!    └─→ 检查查询参数 (?token=xxx)
74//!    ↓
75//! 4. 找到 Token → 创建 TokenValue
76//!    ↓
77//! 5. SaTokenManager.get_token_info(token)
78//!    ↓
79//! 6. 验证 Token 过期时间
80//!    ├─→ 已过期 → 返回 TokenExpired 错误
81//!    └─→ 有效 → 继续
82//!    ↓
83//! 7. 生成 WebSocket 会话 ID
84//!    格式: ws:{login_id}:{uuid}
85//!    ↓
86//! 8. 创建 WsAuthInfo
87//!    - login_id: 用户标识
88//!    - token: 原始 Token 字符串
89//!    - session_id: 唯一的 WebSocket 会话 ID
90//!    - connect_time: 连接时间戳
91//!    - metadata: 自定义键值数据
92//!    ↓
93//! 9. 发布 Login 事件
94//!    SaTokenEvent::login(login_id, token)
95//!    └─→ 标记为 "websocket" 登录类型
96//!    └─→ 触发所有已注册的事件监听器
97//!    ↓
98//! 10. 返回 WsAuthInfo
99//! ```
100//! 
101//! ### Token 提取优先级
102//! 1. Authorization 请求头: `Bearer {token}`
103//! 2. Sec-WebSocket-Protocol 请求头: `{token}`
104//! 3. 查询参数: `?token={token}`
105//! 
106//! ### 扩展点
107//! - 自定义 WsTokenExtractor: 实现自己的 Token 提取逻辑
108//! - WsAuthInfo.metadata: 存储自定义连接数据
109
110use crate::error::SaTokenError;
111use crate::manager::SaTokenManager;
112use crate::token::TokenValue;
113use crate::event::SaTokenEvent;
114use async_trait::async_trait;
115use std::collections::HashMap;
116use std::sync::Arc;
117
/// Information describing one authenticated WebSocket connection.
///
/// A value of this type is produced by `WsAuthManager::authenticate` once the
/// token taken from the handshake has been looked up and found not to be
/// expired.
#[derive(Debug, Clone)]
pub struct WsAuthInfo {
    /// Login ID of the user that owns the token, copied from the token info
    /// returned by the token manager.
    pub login_id: String,

    /// The token string exactly as returned by the token extractor.
    pub token: String,

    /// Unique ID for this WebSocket session.
    /// Format: `ws:{login_id}:{uuid}`, where the UUID is a freshly generated v4 UUID.
    pub session_id: String,

    /// UTC time at which the connection was authenticated.
    pub connect_time: chrono::DateTime<chrono::Utc>,

    /// Custom key-value data attached to this connection.
    /// `WsAuthManager::authenticate` leaves it empty; callers may fill it in.
    pub metadata: HashMap<String, String>,
}
141
/// Strategy for pulling a token out of a WebSocket handshake request.
///
/// Implement this trait to customize where the token is read from, then pass
/// the implementation to `WsAuthManager::with_extractor`. Implementations must
/// be `Send + Sync` because the manager stores them behind an `Arc`.
#[async_trait]
pub trait WsTokenExtractor: Send + Sync {
    /// Extracts a token from the handshake headers and query parameters.
    ///
    /// # Arguments
    /// * `headers` - HTTP headers of the handshake request, keyed by header name
    /// * `query` - Query parameters of the connection URL, keyed by parameter name
    ///
    /// # Returns
    /// * `Some(token)` - A token was found
    /// * `None` - No token was found; `WsAuthManager::authenticate` reports this
    ///   as `SaTokenError::NotLogin`
    async fn extract_token(&self, headers: &HashMap<String, String>, query: &HashMap<String, String>) -> Option<String>;
}
161
/// Default token extractor used by `WsAuthManager::new`.
///
/// Looks for a token in the following places, in priority order:
/// 1. `Authorization` header (Bearer token)
/// 2. `Sec-WebSocket-Protocol` header
/// 3. Query parameter `token`
///
/// The type carries no state, so it derives the usual traits for a unit
/// struct: it can be debug-printed, copied freely and built via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultWsTokenExtractor;
171
172#[async_trait]
173impl WsTokenExtractor for DefaultWsTokenExtractor {
174    async fn extract_token(&self, headers: &HashMap<String, String>, query: &HashMap<String, String>) -> Option<String> {
175        // Priority 1: Authorization header with Bearer scheme
176        // 优先级 1: Authorization 请求头(Bearer 方式)
177        if let Some(token) = headers.get("Authorization") {
178            return Some(token.trim_start_matches("Bearer ").to_string());
179        }
180        
181        // Priority 2: WebSocket Protocol header
182        // 优先级 2: WebSocket Protocol 请求头
183        if let Some(token) = headers.get("Sec-WebSocket-Protocol") {
184            return Some(token.to_string());
185        }
186        
187        // Priority 3: Query parameter
188        // 优先级 3: 查询参数
189        if let Some(token) = query.get("token") {
190            return Some(token.to_string());
191        }
192        
193        None
194    }
195}
196
/// WebSocket authentication manager.
///
/// Combines a token extractor with the token manager to authenticate
/// WebSocket handshakes and to verify tokens of established connections.
pub struct WsAuthManager {
    /// Shared token manager used to look up token information and to reach
    /// the event bus.
    manager: Arc<SaTokenManager>,

    /// Strategy used to find the token in the handshake headers and query
    /// parameters.
    extractor: Arc<dyn WsTokenExtractor>,
}
209
210impl WsAuthManager {
211    /// Create a new WebSocket authentication manager with default extractor
212    /// 使用默认提取器创建新的 WebSocket 认证管理器
213    ///
214    /// # Arguments | 参数
215    /// * `manager` - SaToken manager instance | SaToken 管理器实例
216    ///
217    /// # Example | 示例
218    /// ```rust,ignore
219    /// let ws_auth = WsAuthManager::new(manager);
220    /// ```
221    pub fn new(manager: Arc<SaTokenManager>) -> Self {
222        Self {
223            manager,
224            extractor: Arc::new(DefaultWsTokenExtractor),
225        }
226    }
227
228    /// Create a new WebSocket authentication manager with custom extractor
229    /// 使用自定义提取器创建新的 WebSocket 认证管理器
230    ///
231    /// # Arguments | 参数
232    /// * `manager` - SaToken manager instance | SaToken 管理器实例
233    /// * `extractor` - Custom token extractor | 自定义 Token 提取器
234    ///
235    /// # Example | 示例
236    /// ```rust,ignore
237    /// let custom_extractor = Arc::new(MyCustomExtractor);
238    /// let ws_auth = WsAuthManager::with_extractor(manager, custom_extractor);
239    /// ```
240    pub fn with_extractor(manager: Arc<SaTokenManager>, extractor: Arc<dyn WsTokenExtractor>) -> Self {
241        Self {
242            manager,
243            extractor,
244        }
245    }
246
247    /// Authenticate a WebSocket connection
248    /// 认证 WebSocket 连接
249    ///
250    /// This method will trigger a Login event after successful authentication
251    /// 此方法在认证成功后会触发 Login 事件
252    ///
253    /// # Arguments | 参数
254    /// * `headers` - HTTP headers from the WebSocket handshake | WebSocket 握手的 HTTP 请求头
255    /// * `query` - Query parameters from the connection URL | 连接 URL 的查询参数
256    ///
257    /// # Returns | 返回值
258    /// * `Ok(WsAuthInfo)` - Authentication successful | 认证成功
259    /// * `Err(SaTokenError)` - Authentication failed | 认证失败
260    ///
261    /// # Errors | 错误
262    /// * `NotLogin` - No token found | 未找到 Token
263    /// * `TokenNotFound` - Token not found in storage | 存储中未找到 Token
264    /// * `TokenExpired` - Token has expired | Token 已过期
265    ///
266    /// # Events | 事件
267    /// Publishes `SaTokenEvent::Login` with login_type = "websocket"
268    /// 发布 `SaTokenEvent::Login` 事件,login_type = "websocket"
269    ///
270    /// # Example | 示例
271    /// ```rust,ignore
272    /// let mut headers = HashMap::new();
273    /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
274    /// 
275    /// let auth_info = ws_auth.authenticate(&headers, &HashMap::new()).await?;
276    /// println!("User {} connected", auth_info.login_id);
277    /// 
278    /// // Event listeners will be notified of WebSocket authentication
279    /// // 事件监听器将收到 WebSocket 认证通知
280    /// ```
281    pub async fn authenticate(
282        &self,
283        headers: &HashMap<String, String>,
284        query: &HashMap<String, String>,
285    ) -> Result<WsAuthInfo, SaTokenError> {
286        // Step 1: Extract token from request
287        // 步骤 1: 从请求中提取 Token
288        let token_str = self.extractor.extract_token(headers, query).await
289            .ok_or(SaTokenError::NotLogin)?;
290
291        // Step 2: Convert to TokenValue and get token info
292        // 步骤 2: 转换为 TokenValue 并获取 Token 信息
293        let token = TokenValue::new(token_str.clone());
294        let token_info = self.manager.get_token_info(&token).await?;
295        
296        // Step 3: Validate token expiration
297        // 步骤 3: 验证 Token 过期时间
298        if let Some(expire_time) = token_info.expire_time {
299            if chrono::Utc::now() > expire_time {
300                return Err(SaTokenError::TokenExpired);
301            }
302        }
303
304        // Step 4: Generate unique WebSocket session ID
305        // 步骤 4: 生成唯一的 WebSocket 会话 ID
306        let login_id = token_info.login_id.clone();
307        let session_id = format!("ws:{}:{}", login_id, uuid::Uuid::new_v4());
308
309        // Step 5: Create authentication info
310        // 步骤 5: 创建认证信息
311        let auth_info = WsAuthInfo {
312            login_id: login_id.clone(),
313            token: token_str.clone(),
314            session_id,
315            connect_time: chrono::Utc::now(),
316            metadata: HashMap::new(),
317        };
318
319        // Step 6: Publish WebSocket authentication event (Login event with websocket type)
320        // 步骤 6: 发布 WebSocket 认证事件(标记为 websocket 类型的 Login 事件)
321        let event = SaTokenEvent::login(login_id, &token_str)
322            .with_login_type("websocket");
323        self.manager.event_bus().publish(event).await;
324
325        // Step 7: Return authentication info
326        // 步骤 7: 返回认证信息
327        Ok(auth_info)
328    }
329
330    /// Verify a token and return the login ID
331    /// 验证 Token 并返回登录 ID
332    ///
333    /// # Arguments | 参数
334    /// * `token` - Token string to verify | 要验证的 Token 字符串
335    ///
336    /// # Returns | 返回值
337    /// * `Ok(login_id)` - Token is valid | Token 有效
338    /// * `Err(SaTokenError)` - Token is invalid or expired | Token 无效或已过期
339    ///
340    /// # Example | 示例
341    /// ```rust,ignore
342    /// let login_id = ws_auth.verify_token("token123").await?;
343    /// println!("Token belongs to user: {}", login_id);
344    /// ```
345    pub async fn verify_token(&self, token: &str) -> Result<String, SaTokenError> {
346        let token_value = TokenValue::new(token.to_string());
347        let token_info = self.manager.get_token_info(&token_value).await?;
348        
349        // Validate expiration | 验证过期时间
350        if let Some(expire_time) = token_info.expire_time {
351            if chrono::Utc::now() > expire_time {
352                return Err(SaTokenError::TokenExpired);
353            }
354        }
355
356        Ok(token_info.login_id)
357    }
358
359    /// Refresh a WebSocket session by verifying its token
360    /// 通过验证 Token 刷新 WebSocket 会话
361    ///
362    /// # Arguments | 参数
363    /// * `auth_info` - WebSocket authentication info | WebSocket 认证信息
364    ///
365    /// # Returns | 返回值
366    /// * `Ok(())` - Session refreshed successfully | 会话刷新成功
367    /// * `Err(SaTokenError)` - Token is invalid or expired | Token 无效或已过期
368    ///
369    /// # Example | 示例
370    /// ```rust,ignore
371    /// ws_auth.refresh_ws_session(&auth_info).await?;
372    /// ```
373    pub async fn refresh_ws_session(&self, auth_info: &WsAuthInfo) -> Result<(), SaTokenError> {
374        self.verify_token(&auth_info.token).await?;
375        Ok(())
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SaTokenConfig;
    use sa_token_storage_memory::MemoryStorage;

    /// Builds a token manager with default config on in-memory storage,
    /// together with a WebSocket auth manager that shares it.
    fn fixture() -> (Arc<SaTokenManager>, WsAuthManager) {
        let config = SaTokenConfig::default();
        let storage = Arc::new(MemoryStorage::new());
        let manager = Arc::new(SaTokenManager::new(storage, config));
        let ws_manager = WsAuthManager::new(Arc::clone(&manager));
        (manager, ws_manager)
    }

    /// A token sent as `Authorization: Bearer …` authenticates its owner.
    #[tokio::test]
    async fn test_ws_auth_manager() {
        let (manager, ws_manager) = fixture();
        let token = manager.login("user123").await.unwrap();

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", token.as_str()),
        );
        let no_query = HashMap::new();

        let auth_info = ws_manager.authenticate(&headers, &no_query).await.unwrap();
        assert_eq!(auth_info.login_id, "user123");
    }

    /// A token sent as the `token` query parameter authenticates its owner.
    #[tokio::test]
    async fn test_token_extraction_from_query() {
        let (manager, ws_manager) = fixture();
        let token = manager.login("user456").await.unwrap();

        let no_headers = HashMap::new();
        let mut query = HashMap::new();
        query.insert("token".to_string(), token.as_str().to_string());

        let auth_info = ws_manager.authenticate(&no_headers, &query).await.unwrap();
        assert_eq!(auth_info.login_id, "user456");
    }

    /// Verifying a freshly issued token yields the login ID it was issued for.
    #[tokio::test]
    async fn test_verify_token() {
        let (manager, ws_manager) = fixture();
        let token = manager.login("user789").await.unwrap();

        let login_id = ws_manager.verify_token(token.as_str()).await.unwrap();
        assert_eq!(login_id, "user789");
    }
}