Skip to main content

sh_layer1/streaming/
websocket.rs

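//! WebSocket adapter.
//!
//! Provides WebSocket connection setup and streaming message handling built on
//! `tokio-tungstenite`.
//!
//! ## Features
//! - Asynchronous connection establishment with a configurable timeout
//! - Conversion between this module's message type and `tungstenite` messages
//! - Streaming message reception (`next_message` and a `futures::Stream` receiver)
//!
//! Automatic reconnection and heartbeat keep-alive are configurable in
//! `WebSocketConfig` but are not implemented in this module yet.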
11
12use anyhow::{anyhow, Result};
13use async_trait::async_trait;
14use futures::{SinkExt, Stream, StreamExt};
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21use tokio::sync::{mpsc, Mutex};
22use tokio_tungstenite::{
23    connect_async, tungstenite::Message as WsMessage, MaybeTlsStream, WebSocketStream,
24};
25
/// Tuning knobs for a WebSocket connection.
///
/// Every duration is expressed in milliseconds.
///
/// NOTE(review): `heartbeat_interval_ms`, `max_reconnect_attempts` and
/// `reconnect_interval_ms` are not read anywhere in this module; confirm whether
/// another component consumes them.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Upper bound on how long the opening handshake may take, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Interval between heartbeats, in milliseconds.
    pub heartbeat_interval_ms: u64,
    /// Maximum number of reconnection attempts.
    pub max_reconnect_attempts: u32,
    /// Pause between reconnection attempts, in milliseconds.
    pub reconnect_interval_ms: u64,
    /// Capacity hint for the receive buffer.
    pub receive_buffer_size: usize,
}
40
41impl Default for WebSocketConfig {
42    fn default() -> Self {
43        Self {
44            connect_timeout_ms: 10000,
45            heartbeat_interval_ms: 30000,
46            max_reconnect_attempts: 3,
47            reconnect_interval_ms: 1000,
48            receive_buffer_size: 100,
49        }
50    }
51}
52
/// A WebSocket message, decoupled from the `tungstenite` message type.
///
/// All payloads are owned (`String` / `Vec<u8>`), so values can be cloned,
/// compared and moved across tasks freely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketMessage {
    /// UTF-8 text message.
    Text(String),
    /// Binary message.
    Binary(Vec<u8>),
    /// Ping control frame with its payload.
    Ping(Vec<u8>),
    /// Pong control frame with its payload.
    Pong(Vec<u8>),
    /// Close control frame with an optional human-readable reason.
    Close(Option<String>),
}
67
68impl From<WsMessage> for WebSocketMessage {
69    fn from(msg: WsMessage) -> Self {
70        match msg {
71            WsMessage::Text(t) => WebSocketMessage::Text(t.to_string()),
72            WsMessage::Binary(b) => WebSocketMessage::Binary(b.to_vec()),
73            WsMessage::Ping(p) => WebSocketMessage::Ping(p.to_vec()),
74            WsMessage::Pong(p) => WebSocketMessage::Pong(p.to_vec()),
75            WsMessage::Close(_) => WebSocketMessage::Close(None),
76            WsMessage::Frame(_) => WebSocketMessage::Text(String::new()),
77        }
78    }
79}
80
81impl From<WebSocketMessage> for WsMessage {
82    fn from(msg: WebSocketMessage) -> Self {
83        match msg {
84            WebSocketMessage::Text(t) => WsMessage::Text(t.into()),
85            WebSocketMessage::Binary(b) => WsMessage::Binary(b.into()),
86            WebSocketMessage::Ping(p) => WsMessage::Ping(p.into()),
87            WebSocketMessage::Pong(p) => WsMessage::Pong(p.into()),
88            WebSocketMessage::Close(_) => WsMessage::Close(None),
89        }
90    }
91}
92
/// Lifecycle state tracked by a WebSocket adapter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection is established.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The handshake completed successfully.
    Connected,
    /// A reconnection is in progress.
    Reconnecting,
    /// The adapter has been closed explicitly.
    Closed,
}
107
108/// WebSocket 适配器
109///
110/// 提供高级 WebSocket 连接管理,包括:
111/// - 自动连接和重连
112/// - 消息流式接收
113/// - 线程安全的发送
114pub struct WebSocketAdapter {
115    /// WebSocket 配置
116    config: WebSocketConfig,
117    /// 连接 URL
118    url: String,
119    /// 连接状态
120    state: Arc<Mutex<ConnectionState>>,
121    /// 发送通道
122    sender: mpsc::Sender<WebSocketMessage>,
123    /// 中断标志
124    abort_flag: Arc<AtomicBool>,
125}
126
127impl WebSocketAdapter {
128    /// 创建新的 WebSocket 适配器
129    ///
130    /// 注意:需要调用 `connect()` 建立连接
131    pub fn new(url: impl Into<String>) -> Self {
132        Self::with_config(url, WebSocketConfig::default())
133    }
134
135    /// 创建带配置的 WebSocket 适配器
136    pub fn with_config(url: impl Into<String>, config: WebSocketConfig) -> Self {
137        let (sender, _) = mpsc::channel(config.receive_buffer_size);
138        Self {
139            config,
140            url: url.into(),
141            state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
142            sender,
143            abort_flag: Arc::new(AtomicBool::new(false)),
144        }
145    }
146
147    /// 获取连接状态
148    pub async fn state(&self) -> ConnectionState {
149        *self.state.lock().await
150    }
151
152    /// 获取中断标志
153    pub fn abort_flag(&self) -> Arc<AtomicBool> {
154        Arc::clone(&self.abort_flag)
155    }
156
157    /// 请求中断
158    pub fn abort(&self) {
159        self.abort_flag.store(true, Ordering::Relaxed);
160    }
161
162    /// 建立 WebSocket 连接
163    ///
164    /// 返回消息接收流
165    pub async fn connect(&self) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>> {
166        {
167            let mut state = self.state.lock().await;
168            if *state == ConnectionState::Connected {
169                return Err(anyhow!("Already connected"));
170            }
171            *state = ConnectionState::Connecting;
172        }
173
174        let url = self.url.clone();
175        let timeout = Duration::from_millis(self.config.connect_timeout_ms);
176
177        let connect_future = async { connect_async(&url).await };
178
179        let result = tokio::time::timeout(timeout, connect_future).await;
180
181        match result {
182            Ok(Ok((stream, _))) => {
183                let mut state = self.state.lock().await;
184                *state = ConnectionState::Connected;
185                tracing::info!("WebSocket connected to {}", self.url);
186                Ok(stream)
187            }
188            Ok(Err(e)) => {
189                let mut state = self.state.lock().await;
190                *state = ConnectionState::Disconnected;
191                Err(anyhow!("WebSocket connection failed: {}", e))
192            }
193            Err(_) => {
194                let mut state = self.state.lock().await;
195                *state = ConnectionState::Disconnected;
196                Err(anyhow!("WebSocket connection timeout"))
197            }
198        }
199    }
200
201    /// 发送消息
202    pub async fn send(&self, message: WebSocketMessage) -> Result<()> {
203        self.sender.send(message).await?;
204        Ok(())
205    }
206
207    /// 创建消息流
208    ///
209    /// 连接 WebSocket 并返回消息接收流
210    pub async fn create_stream(&self) -> Result<WebSocketMessageStream> {
211        let stream = self.connect().await?;
212        Ok(WebSocketMessageStream::new(stream, self.abort_flag.clone()))
213    }
214
215    /// 关闭连接
216    pub async fn close(&self) -> Result<()> {
217        let mut state = self.state.lock().await;
218        *state = ConnectionState::Closed;
219        tracing::info!("WebSocket closed");
220        Ok(())
221    }
222}
223
224/// WebSocket 消息流
225///
226/// 包装 WebSocketStream,提供便捷的消息接收接口
227pub struct WebSocketMessageStream {
228    inner: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
229    abort_flag: Arc<AtomicBool>,
230    pending: VecDeque<WebSocketMessage>,
231}
232
233impl WebSocketMessageStream {
234    fn new(
235        inner: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
236        abort_flag: Arc<AtomicBool>,
237    ) -> Self {
238        Self {
239            inner,
240            abort_flag,
241            pending: VecDeque::new(),
242        }
243    }
244
245    /// 获取下一个消息
246    pub async fn next_message(&mut self) -> Result<Option<WebSocketMessage>> {
247        if self.abort_flag.load(Ordering::Relaxed) {
248            return Ok(None);
249        }
250
251        loop {
252            if let Some(msg) = self.pending.pop_front() {
253                return Ok(Some(msg));
254            }
255
256            match self.inner.next().await {
257                Some(Ok(ws_msg)) => {
258                    let msg: WebSocketMessage = ws_msg.into();
259                    match msg {
260                        WebSocketMessage::Ping(p) => {
261                            // 自动响应 Pong
262                            let _ = self.inner.send(WsMessage::Pong(p.into())).await;
263                        }
264                        WebSocketMessage::Close(_) => {
265                            return Ok(None);
266                        }
267                        other => {
268                            self.pending.push_back(other);
269                        }
270                    }
271                }
272                Some(Err(e)) => {
273                    tracing::error!("WebSocket error: {}", e);
274                    return Err(anyhow!("WebSocket error: {}", e));
275                }
276                None => return Ok(None),
277            }
278        }
279    }
280
281    /// 发送消息
282    pub async fn send(&mut self, message: WebSocketMessage) -> Result<()> {
283        let ws_msg: WsMessage = message.into();
284        self.inner.send(ws_msg).await?;
285        Ok(())
286    }
287
288    /// 收集所有文本消息
289    pub async fn collect_text(&mut self) -> Result<String> {
290        let mut result = String::new();
291        while let Some(msg) = self.next_message().await? {
292            if let WebSocketMessage::Text(t) = msg {
293                result.push_str(&t);
294            }
295        }
296        Ok(result)
297    }
298}
299
300/// 流式 WebSocket 接收器
301///
302/// 实现 Stream trait,可以与 async 迭代器一起使用
303pub struct WebSocketReceiver {
304    stream: WebSocketMessageStream,
305}
306
307impl WebSocketReceiver {
308    /// 创建接收器
309    pub fn new(stream: WebSocketMessageStream) -> Self {
310        Self { stream }
311    }
312}
313
314impl Stream for WebSocketReceiver {
315    type Item = Result<WebSocketMessage>;
316
317    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
318        // 使用 futures 库的 StreamExt 来轮询
319        let abort_flag = self.stream.abort_flag.clone();
320        if abort_flag.load(Ordering::Relaxed) {
321            return Poll::Ready(None);
322        }
323
324        // 委托给 inner stream
325        Pin::new(&mut self.stream.inner).poll_next(cx).map(|opt| {
326            opt.map(|result| {
327                result
328                    .map(WebSocketMessage::from)
329                    .map_err(|e| anyhow::anyhow!("WebSocket error: {}", e))
330            })
331        })
332    }
333}
334
335/// WebSocket 适配器 trait
336///
337/// 定义 WebSocket 连接的标准接口
338#[async_trait]
339pub trait WebSocketAdapterTrait: Send + Sync {
340    /// 建立 WebSocket 连接
341    async fn connect(&self) -> Result<()>;
342
343    /// 发送消息
344    async fn send(&self, message: &str) -> Result<()>;
345
346    /// 接收消息
347    async fn receive(&self) -> Result<Option<String>>;
348
349    /// 关闭连接
350    async fn close(&self) -> Result<()>;
351
352    /// 检查是否已连接
353    async fn is_connected(&self) -> bool;
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert_eq!(config.connect_timeout_ms, 10000);
        assert_eq!(config.heartbeat_interval_ms, 30000);
        assert_eq!(config.max_reconnect_attempts, 3);
        assert_eq!(config.reconnect_interval_ms, 1000);
        assert_eq!(config.receive_buffer_size, 100);
    }

    #[test]
    fn test_websocket_message_conversion() {
        let ws_msg = WsMessage::Text("hello".into());
        let msg: WebSocketMessage = ws_msg.into();
        assert!(matches!(msg, WebSocketMessage::Text(t) if t == "hello"));
    }

    #[test]
    fn test_websocket_message_to_ws_message() {
        let msg = WebSocketMessage::Binary(vec![1, 2, 3]);
        let ws_msg: WsMessage = msg.into();
        assert!(matches!(ws_msg, WsMessage::Binary(b) if b == vec![1, 2, 3]));
    }

    #[test]
    fn test_websocket_message_round_trip() {
        // Compared through `Debug` so this test does not rely on `PartialEq`.
        let samples = vec![
            WebSocketMessage::Text("hi".to_string()),
            WebSocketMessage::Binary(vec![0, 255]),
            WebSocketMessage::Ping(vec![1]),
            WebSocketMessage::Pong(vec![2]),
            WebSocketMessage::Close(None),
        ];
        for original in samples {
            let expected = format!("{:?}", original);
            let ws_msg: WsMessage = original.into();
            let back = WebSocketMessage::from(ws_msg);
            assert_eq!(format!("{:?}", back), expected);
        }
    }

    #[tokio::test]
    async fn test_websocket_adapter_creation() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        assert_eq!(adapter.state().await, ConnectionState::Disconnected);
    }

    #[tokio::test]
    async fn test_websocket_adapter_abort() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        assert!(!adapter.abort_flag().load(Ordering::Relaxed));
        adapter.abort();
        assert!(adapter.abort_flag().load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_websocket_adapter_close_marks_closed() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        adapter.close().await.expect("close only updates local state");
        assert_eq!(adapter.state().await, ConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_websocket_adapter_invalid_url_resets_state() {
        // The URL fails to parse before any network I/O, so this is fast and offline.
        let adapter = WebSocketAdapter::new("not a url");
        assert!(adapter.connect().await.is_err());
        assert_eq!(adapter.state().await, ConnectionState::Disconnected);
    }

    #[tokio::test]
    async fn test_websocket_adapter_send_fails_without_transport() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        let result = adapter
            .send(WebSocketMessage::Text("ping".to_string()))
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_connection_state() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
        assert_ne!(ConnectionState::Disconnected, ConnectionState::Connected);
    }
}