sa_token_core/ws.rs
1//! WebSocket Authentication Module | WebSocket 认证模块
2//!
3//! # Code Flow Logic | 代码流程逻辑
4//!
5//! ## English
6//!
7//! ### Overview
8//! This module provides WebSocket authentication capabilities for sa-token-rust.
9//! It handles token extraction from various sources (headers, query parameters)
10//! and validates them against the token manager.
11//!
12//! ### Authentication Flow
13//! ```text
14//! 1. WebSocket Connection Request
15//! ↓
16//! 2. WsAuthManager.authenticate(headers, query)
17//! ↓
18//! 3. WsTokenExtractor.extract_token()
19//! ├─→ Check Authorization Header (Bearer Token)
20//! ├─→ Check Sec-WebSocket-Protocol Header
21//! └─→ Check Query Parameter (?token=xxx)
22//! ↓
23//! 4. Found Token → Create TokenValue
24//! ↓
25//! 5. SaTokenManager.get_token_info(token)
26//! ↓
27//! 6. Validate Token Expiration
28//! ├─→ Expired → Return TokenExpired Error
29//! └─→ Valid → Continue
30//! ↓
31//! 7. Generate WebSocket Session ID
32//! Format: ws:{login_id}:{uuid}
33//! ↓
34//! 8. Create WsAuthInfo
35//! - login_id: User identifier
36//! - token: Original token string
37//! - session_id: Unique WebSocket session ID
38//! - connect_time: Connection timestamp
39//! - metadata: Custom key-value data
40//! ↓
41//! 9. Publish Login Event
42//! SaTokenEvent::login(login_id, token)
43//! └─→ Mark as "websocket" login type
44//! └─→ Trigger all registered event listeners
45//! ↓
46//! 10. Return WsAuthInfo
47//! ```
48//!
49//! ### Token Extraction Priority
50//! 1. Authorization Header: `Bearer {token}`
51//! 2. Sec-WebSocket-Protocol Header: `{token}`
52//! 3. Query Parameter: `?token={token}`
53//!
54//! ### Extension Points
55//! - Custom WsTokenExtractor: Implement your own token extraction logic
56//! - WsAuthInfo.metadata: Store custom connection data
57//!
58//! ## 中文
59//!
60//! ### 概述
61//! 本模块为 sa-token-rust 提供 WebSocket 认证功能。
62//! 它负责从多种来源(请求头、查询参数)提取 Token 并通过 Token 管理器进行验证。
63//!
64//! ### 认证流程
65//! ```text
66//! 1. WebSocket 连接请求
67//! ↓
68//! 2. WsAuthManager.authenticate(headers, query)
69//! ↓
70//! 3. WsTokenExtractor.extract_token()
71//! ├─→ 检查 Authorization 请求头 (Bearer Token)
72//! ├─→ 检查 Sec-WebSocket-Protocol 请求头
73//! └─→ 检查查询参数 (?token=xxx)
74//! ↓
75//! 4. 找到 Token → 创建 TokenValue
76//! ↓
77//! 5. SaTokenManager.get_token_info(token)
78//! ↓
79//! 6. 验证 Token 过期时间
80//! ├─→ 已过期 → 返回 TokenExpired 错误
81//! └─→ 有效 → 继续
82//! ↓
83//! 7. 生成 WebSocket 会话 ID
84//! 格式: ws:{login_id}:{uuid}
85//! ↓
86//! 8. 创建 WsAuthInfo
87//! - login_id: 用户标识
88//! - token: 原始 Token 字符串
89//! - session_id: 唯一的 WebSocket 会话 ID
90//! - connect_time: 连接时间戳
91//! - metadata: 自定义键值数据
92//! ↓
93//! 9. 发布 Login 事件
94//! SaTokenEvent::login(login_id, token)
95//! └─→ 标记为 "websocket" 登录类型
96//! └─→ 触发所有已注册的事件监听器
97//! ↓
98//! 10. 返回 WsAuthInfo
99//! ```
100//!
101//! ### Token 提取优先级
102//! 1. Authorization 请求头: `Bearer {token}`
103//! 2. Sec-WebSocket-Protocol 请求头: `{token}`
104//! 3. 查询参数: `?token={token}`
105//!
106//! ### 扩展点
107//! - 自定义 WsTokenExtractor: 实现自己的 Token 提取逻辑
108//! - WsAuthInfo.metadata: 存储自定义连接数据
109
110use crate::error::SaTokenError;
111use crate::manager::SaTokenManager;
112use crate::token::TokenValue;
113use crate::event::SaTokenEvent;
114use async_trait::async_trait;
115use std::collections::HashMap;
116use std::sync::Arc;
117
118/// WebSocket authentication information
119/// WebSocket 认证信息
120///
121/// Contains all the information about an authenticated WebSocket connection
122/// 包含已认证的 WebSocket 连接的所有信息
123#[derive(Debug, Clone)]
124pub struct WsAuthInfo {
125 /// User login ID | 用户登录 ID
126 pub login_id: String,
127
128 /// Authentication token | 认证 Token
129 pub token: String,
130
131 /// Unique WebSocket session ID | 唯一的 WebSocket 会话 ID
132 /// Format: ws:{login_id}:{uuid}
133 pub session_id: String,
134
135 /// Connection timestamp | 连接时间戳
136 pub connect_time: chrono::DateTime<chrono::Utc>,
137
138 /// Custom metadata for this connection | 该连接的自定义元数据
139 pub metadata: HashMap<String, String>,
140}
141
142/// Token extractor trait for WebSocket connections
143/// WebSocket 连接的 Token 提取器 trait
144///
145/// Implement this trait to customize token extraction logic
146/// 实现此 trait 以自定义 Token 提取逻辑
147#[async_trait]
148pub trait WsTokenExtractor: Send + Sync {
149 /// Extract token from headers and query parameters
150 /// 从请求头和查询参数中提取 Token
151 ///
152 /// # Arguments | 参数
153 /// * `headers` - HTTP headers | HTTP 请求头
154 /// * `query` - Query parameters | 查询参数
155 ///
156 /// # Returns | 返回值
157 /// * `Some(token)` - Token found | 找到 Token
158 /// * `None` - No token found | 未找到 Token
159 async fn extract_token(&self, headers: &HashMap<String, String>, query: &HashMap<String, String>) -> Option<String>;
160}
161
/// Default token extractor.
///
/// Sources are consulted in this order; the first one that yields a token wins:
/// 1. `Authorization` header (Bearer token)
/// 2. `Sec-WebSocket-Protocol` header
/// 3. Query parameter `token`
///
/// The type is a stateless unit struct, so it derives the cheap standard
/// traits (`Debug`, `Clone`, `Copy`, `Default`) to be usable in logs, in
/// derived `Debug` impls of containing types, and via `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultWsTokenExtractor;
171
172#[async_trait]
173impl WsTokenExtractor for DefaultWsTokenExtractor {
174 async fn extract_token(&self, headers: &HashMap<String, String>, query: &HashMap<String, String>) -> Option<String> {
175 // Priority 1: Authorization header with Bearer scheme
176 // 优先级 1: Authorization 请求头(Bearer 方式)
177 if let Some(token) = headers.get("Authorization") {
178 return Some(token.trim_start_matches("Bearer ").to_string());
179 }
180
181 // Priority 2: WebSocket Protocol header
182 // 优先级 2: WebSocket Protocol 请求头
183 if let Some(token) = headers.get("Sec-WebSocket-Protocol") {
184 return Some(token.to_string());
185 }
186
187 // Priority 3: Query parameter
188 // 优先级 3: 查询参数
189 if let Some(token) = query.get("token") {
190 return Some(token.to_string());
191 }
192
193 None
194 }
195}
196
197/// WebSocket authentication manager
198/// WebSocket 认证管理器
199///
200/// Provides authentication and verification for WebSocket connections
201/// 为 WebSocket 连接提供认证和验证功能
202pub struct WsAuthManager {
203 /// Reference to the token manager | Token 管理器引用
204 manager: Arc<SaTokenManager>,
205
206 /// Token extractor implementation | Token 提取器实现
207 extractor: Arc<dyn WsTokenExtractor>,
208}
209
210impl WsAuthManager {
211 /// Create a new WebSocket authentication manager with default extractor
212 /// 使用默认提取器创建新的 WebSocket 认证管理器
213 ///
214 /// # Arguments | 参数
215 /// * `manager` - SaToken manager instance | SaToken 管理器实例
216 ///
217 /// # Example | 示例
218 /// ```rust,ignore
219 /// let ws_auth = WsAuthManager::new(manager);
220 /// ```
221 pub fn new(manager: Arc<SaTokenManager>) -> Self {
222 Self {
223 manager,
224 extractor: Arc::new(DefaultWsTokenExtractor),
225 }
226 }
227
228 /// Create a new WebSocket authentication manager with custom extractor
229 /// 使用自定义提取器创建新的 WebSocket 认证管理器
230 ///
231 /// # Arguments | 参数
232 /// * `manager` - SaToken manager instance | SaToken 管理器实例
233 /// * `extractor` - Custom token extractor | 自定义 Token 提取器
234 ///
235 /// # Example | 示例
236 /// ```rust,ignore
237 /// let custom_extractor = Arc::new(MyCustomExtractor);
238 /// let ws_auth = WsAuthManager::with_extractor(manager, custom_extractor);
239 /// ```
240 pub fn with_extractor(manager: Arc<SaTokenManager>, extractor: Arc<dyn WsTokenExtractor>) -> Self {
241 Self {
242 manager,
243 extractor,
244 }
245 }
246
247 /// Authenticate a WebSocket connection
248 /// 认证 WebSocket 连接
249 ///
250 /// This method will trigger a Login event after successful authentication
251 /// 此方法在认证成功后会触发 Login 事件
252 ///
253 /// # Arguments | 参数
254 /// * `headers` - HTTP headers from the WebSocket handshake | WebSocket 握手的 HTTP 请求头
255 /// * `query` - Query parameters from the connection URL | 连接 URL 的查询参数
256 ///
257 /// # Returns | 返回值
258 /// * `Ok(WsAuthInfo)` - Authentication successful | 认证成功
259 /// * `Err(SaTokenError)` - Authentication failed | 认证失败
260 ///
261 /// # Errors | 错误
262 /// * `NotLogin` - No token found | 未找到 Token
263 /// * `TokenNotFound` - Token not found in storage | 存储中未找到 Token
264 /// * `TokenExpired` - Token has expired | Token 已过期
265 ///
266 /// # Events | 事件
267 /// Publishes `SaTokenEvent::Login` with login_type = "websocket"
268 /// 发布 `SaTokenEvent::Login` 事件,login_type = "websocket"
269 ///
270 /// # Example | 示例
271 /// ```rust,ignore
272 /// let mut headers = HashMap::new();
273 /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
274 ///
275 /// let auth_info = ws_auth.authenticate(&headers, &HashMap::new()).await?;
276 /// println!("User {} connected", auth_info.login_id);
277 ///
278 /// // Event listeners will be notified of WebSocket authentication
279 /// // 事件监听器将收到 WebSocket 认证通知
280 /// ```
281 pub async fn authenticate(
282 &self,
283 headers: &HashMap<String, String>,
284 query: &HashMap<String, String>,
285 ) -> Result<WsAuthInfo, SaTokenError> {
286 // Step 1: Extract token from request
287 // 步骤 1: 从请求中提取 Token
288 let token_str = self.extractor.extract_token(headers, query).await
289 .ok_or(SaTokenError::NotLogin)?;
290
291 // Step 2: Convert to TokenValue and get token info
292 // 步骤 2: 转换为 TokenValue 并获取 Token 信息
293 let token = TokenValue::new(token_str.clone());
294 let token_info = self.manager.get_token_info(&token).await?;
295
296 // Step 3: Validate token expiration
297 // 步骤 3: 验证 Token 过期时间
298 if let Some(expire_time) = token_info.expire_time {
299 if chrono::Utc::now() > expire_time {
300 return Err(SaTokenError::TokenExpired);
301 }
302 }
303
304 // Step 4: Generate unique WebSocket session ID
305 // 步骤 4: 生成唯一的 WebSocket 会话 ID
306 let login_id = token_info.login_id.clone();
307 let session_id = format!("ws:{}:{}", login_id, uuid::Uuid::new_v4());
308
309 // Step 5: Create authentication info
310 // 步骤 5: 创建认证信息
311 let auth_info = WsAuthInfo {
312 login_id: login_id.clone(),
313 token: token_str.clone(),
314 session_id,
315 connect_time: chrono::Utc::now(),
316 metadata: HashMap::new(),
317 };
318
319 // Step 6: Publish WebSocket authentication event (Login event with websocket type)
320 // 步骤 6: 发布 WebSocket 认证事件(标记为 websocket 类型的 Login 事件)
321 let event = SaTokenEvent::login(login_id, &token_str)
322 .with_login_type("websocket");
323 self.manager.event_bus().publish(event).await;
324
325 // Step 7: Return authentication info
326 // 步骤 7: 返回认证信息
327 Ok(auth_info)
328 }
329
330 /// Verify a token and return the login ID
331 /// 验证 Token 并返回登录 ID
332 ///
333 /// # Arguments | 参数
334 /// * `token` - Token string to verify | 要验证的 Token 字符串
335 ///
336 /// # Returns | 返回值
337 /// * `Ok(login_id)` - Token is valid | Token 有效
338 /// * `Err(SaTokenError)` - Token is invalid or expired | Token 无效或已过期
339 ///
340 /// # Example | 示例
341 /// ```rust,ignore
342 /// let login_id = ws_auth.verify_token("token123").await?;
343 /// println!("Token belongs to user: {}", login_id);
344 /// ```
345 pub async fn verify_token(&self, token: &str) -> Result<String, SaTokenError> {
346 let token_value = TokenValue::new(token.to_string());
347 let token_info = self.manager.get_token_info(&token_value).await?;
348
349 // Validate expiration | 验证过期时间
350 if let Some(expire_time) = token_info.expire_time {
351 if chrono::Utc::now() > expire_time {
352 return Err(SaTokenError::TokenExpired);
353 }
354 }
355
356 Ok(token_info.login_id)
357 }
358
359 /// Refresh a WebSocket session by verifying its token
360 /// 通过验证 Token 刷新 WebSocket 会话
361 ///
362 /// # Arguments | 参数
363 /// * `auth_info` - WebSocket authentication info | WebSocket 认证信息
364 ///
365 /// # Returns | 返回值
366 /// * `Ok(())` - Session refreshed successfully | 会话刷新成功
367 /// * `Err(SaTokenError)` - Token is invalid or expired | Token 无效或已过期
368 ///
369 /// # Example | 示例
370 /// ```rust,ignore
371 /// ws_auth.refresh_ws_session(&auth_info).await?;
372 /// ```
373 pub async fn refresh_ws_session(&self, auth_info: &WsAuthInfo) -> Result<(), SaTokenError> {
374 self.verify_token(&auth_info.token).await?;
375 Ok(())
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SaTokenConfig;
    use sa_token_storage_memory::MemoryStorage;

    /// Creates a token manager on in-memory storage with the default
    /// configuration, plus a `WsAuthManager` that shares it.
    fn fixture() -> (Arc<SaTokenManager>, WsAuthManager) {
        let config = SaTokenConfig::default();
        let storage = Arc::new(MemoryStorage::new());
        let manager = Arc::new(SaTokenManager::new(storage, config));
        let ws_manager = WsAuthManager::new(Arc::clone(&manager));
        (manager, ws_manager)
    }

    /// A token sent as `Authorization: Bearer ...` authenticates its owner.
    #[tokio::test]
    async fn test_ws_auth_manager() {
        let (manager, ws_manager) = fixture();
        let token = manager.login("user123").await.unwrap();

        let headers = HashMap::from([(
            "Authorization".to_string(),
            format!("Bearer {}", token.as_str()),
        )]);

        let auth_info = ws_manager
            .authenticate(&headers, &HashMap::new())
            .await
            .unwrap();
        assert_eq!(auth_info.login_id, "user123");
    }

    /// A token sent as the `token` query parameter authenticates its owner.
    #[tokio::test]
    async fn test_token_extraction_from_query() {
        let (manager, ws_manager) = fixture();
        let token = manager.login("user456").await.unwrap();

        let query = HashMap::from([("token".to_string(), token.as_str().to_string())]);

        let auth_info = ws_manager
            .authenticate(&HashMap::new(), &query)
            .await
            .unwrap();
        assert_eq!(auth_info.login_id, "user456");
    }

    /// `verify_token` resolves a freshly issued token to its login ID.
    #[tokio::test]
    async fn test_verify_token() {
        let (manager, ws_manager) = fixture();
        let token = manager.login("user789").await.unwrap();

        let login_id = ws_manager.verify_token(token.as_str()).await.unwrap();
        assert_eq!(login_id, "user789");
    }
}
432}