// sh_layer4/channel_gateway/adapter/websocket.rs

//! # WebSocket Channel Adapter
//!
//! WebSocket 实时通信渠道适配器。
//! (WebSocket real-time communication channel adapter.)

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use parking_lot::RwLock;
use tokio::net::TcpStream;
use tokio::sync::Mutex as AsyncMutex;
use tokio_tungstenite::{
    connect_async,
    tungstenite::{Error as WsError, Message as WsMessage},
    MaybeTlsStream, WebSocketStream,
};

use crate::channel_gateway::{Channel, ChannelType, InboundMessage, OutboundMessage};
use crate::types::Layer4Result;

/// Connection settings for a [`WebSocketChannel`].
#[derive(Debug, Clone)]
pub struct WebSocketChannelConfig {
    /// Server URL handed to `connect_async`.
    pub url: String,
    /// Maximum number of connection attempts made by
    /// `WebSocketChannel::connect_with_retry` (one attempt is always made,
    /// even when this is 0).
    pub reconnect_attempts: u32,
    /// Pause between two consecutive connection attempts, in milliseconds.
    pub reconnect_interval_ms: u64,
    /// Keep-alive ping interval, in milliseconds.
    ///
    /// NOTE(review): nothing in this adapter reads this value yet, so no
    /// ping frames are ever sent.
    pub ping_interval_ms: u64,
    /// Upper bound for a single connection handshake, in milliseconds.
    pub connect_timeout_ms: u64,
}

impl Default for WebSocketChannelConfig {
    /// Defaults: `ws://localhost:8080/ws`, 3 attempts spaced 1 s apart,
    /// a 30 s ping interval and a 10 s handshake timeout.
    fn default() -> Self {
        let url = String::from("ws://localhost:8080/ws");
        Self {
            url,
            connect_timeout_ms: 10_000,
            ping_interval_ms: 30_000,
            reconnect_interval_ms: 1_000,
            reconnect_attempts: 3,
        }
    }
}

44/// WebSocket 连接类型
45type WsConnection = WebSocketStream<MaybeTlsStream<TcpStream>>;
46
47/// WebSocket 渠道适配器
48pub struct WebSocketChannel {
49    channel_id: String,
50    config: WebSocketChannelConfig,
51    connected: RwLock<bool>,
52    message_queue: RwLock<VecDeque<InboundMessage>>,
53    sessions: RwLock<HashMap<String, String>>, // session_id -> user_id
54    /// WebSocket 连接(用于发送)
55    ws_sender: Arc<AsyncMutex<Option<futures::stream::SplitSink<WsConnection, WsMessage>>>>,
56}
57
58impl WebSocketChannel {
59    /// 创建新的 WebSocket 渠道
60    pub fn new(channel_id: impl Into<String>, config: WebSocketChannelConfig) -> Self {
61        Self {
62            channel_id: channel_id.into(),
63            config,
64            connected: RwLock::new(false),
65            message_queue: RwLock::new(VecDeque::new()),
66            sessions: RwLock::new(HashMap::new()),
67            ws_sender: Arc::new(AsyncMutex::new(None)),
68        }
69    }
70
71    /// 创建默认 WebSocket 渠道
72    pub fn default_channel() -> Self {
73        Self::new("ws-default", WebSocketChannelConfig::default())
74    }
75
76    /// 建立 WebSocket 连接
77    pub async fn connect(&self) -> Layer4Result<()> {
78        let url = self.config.url.clone();
79        let timeout = Duration::from_millis(self.config.connect_timeout_ms);
80
81        let connect_future = async { connect_async(&url).await };
82
83        let result = tokio::time::timeout(timeout, connect_future).await;
84
85        match result {
86            Ok(Ok((stream, _))) => {
87                // 分离读写
88                let (sink, _stream) = stream.split();
89                *self.ws_sender.lock().await = Some(sink);
90                *self.connected.write() = true;
91                tracing::info!("WebSocket connected to {}", url);
92                Ok(())
93            }
94            Ok(Err(e)) => {
95                tracing::error!("WebSocket connection failed: {}", e);
96                Err(anyhow::anyhow!("WebSocket connection failed: {}", e))
97            }
98            Err(_) => {
99                tracing::error!("WebSocket connection timeout");
100                Err(anyhow::anyhow!("WebSocket connection timeout"))
101            }
102        }
103    }
104
105    /// 带重连的连接
106    pub async fn connect_with_retry(&self) -> Layer4Result<()> {
107        let mut attempts = 0;
108        let max_attempts = self.config.reconnect_attempts;
109        let interval = Duration::from_millis(self.config.reconnect_interval_ms);
110
111        loop {
112            match self.connect().await {
113                Ok(_) => return Ok(()),
114                Err(e) => {
115                    attempts += 1;
116                    if attempts >= max_attempts {
117                        return Err(e);
118                    }
119                    tracing::warn!(
120                        "WebSocket connection attempt {}/{} failed, retrying...",
121                        attempts,
122                        max_attempts
123                    );
124                    tokio::time::sleep(interval).await;
125                }
126            }
127        }
128    }
129
130    /// 发送原始 WebSocket 消息
131    pub async fn send_raw(&self, message: WsMessage) -> Layer4Result<()> {
132        let mut sender = self.ws_sender.lock().await;
133        if let Some(ref mut sink) = *sender {
134            sink.send(message).await?;
135            Ok(())
136        } else {
137            Err(anyhow::anyhow!("WebSocket not connected"))
138        }
139    }
140
141    /// 发送文本消息
142    pub async fn send_text(&self, text: &str) -> Layer4Result<()> {
143        self.send_raw(WsMessage::Text(text.into())).await
144    }
145
146    /// 发送二进制消息
147    pub async fn send_binary(&self, data: Vec<u8>) -> Layer4Result<()> {
148        self.send_raw(WsMessage::Binary(data.into())).await
149    }
150
151    /// 注册会话
152    pub fn register_session(&self, session_id: &str, user_id: &str) {
153        self.sessions
154            .write()
155            .insert(session_id.to_string(), user_id.to_string());
156    }
157
158    /// 注销会话
159    pub fn unregister_session(&self, session_id: &str) {
160        self.sessions.write().remove(session_id);
161    }
162
163    /// 接收 WebSocket 消息(模拟)
164    pub fn receive_message(&self, session_id: &str, content: &str) {
165        let user_id = self
166            .sessions
167            .read()
168            .get(session_id)
169            .cloned()
170            .unwrap_or_default();
171        let message = InboundMessage::new(&self.channel_id, &user_id, content)
172            .with_session(session_id)
173            .with_metadata(serde_json::json!({
174                "source": "websocket",
175                "session_id": session_id
176            }));
177        self.message_queue.write().push_back(message);
178    }
179
180    /// 获取活跃会话数量
181    pub fn active_sessions(&self) -> usize {
182        self.sessions.read().len()
183    }
184}
185
186#[async_trait]
187impl Channel for WebSocketChannel {
188    fn id(&self) -> &str {
189        &self.channel_id
190    }
191
192    fn channel_type(&self) -> ChannelType {
193        ChannelType::WebSocket
194    }
195
196    async fn send(&self, message: &OutboundMessage) -> Layer4Result<()> {
197        if !*self.connected.read() {
198            return Err(anyhow::anyhow!("Channel not connected"));
199        }
200
201        // 序列化消息为 JSON
202        let payload = serde_json::json!({
203            "message_id": message.message_id,
204            "content": message.content,
205            "message_type": message.message_type,
206            "target": message.target,
207            "metadata": message.metadata,
208            "timestamp": message.timestamp.to_rfc3339(),
209        });
210
211        // 发送 WebSocket 文本消息
212        self.send_text(&payload.to_string()).await?;
213
214        tracing::debug!("WebSocket channel sent message {}", message.message_id);
215        Ok(())
216    }
217
218    async fn try_receive(&self) -> Layer4Result<Option<InboundMessage>> {
219        if !*self.connected.read() {
220            return Err(anyhow::anyhow!("Channel not connected"));
221        }
222
223        Ok(self.message_queue.write().pop_front())
224    }
225
226    fn is_connected(&self) -> bool {
227        *self.connected.read()
228    }
229
230    async fn close(&self) -> Layer4Result<()> {
231        // 发送 Close 帧
232        let mut sender = self.ws_sender.lock().await;
233        if let Some(ref mut sink) = *sender {
234            sink.close().await?;
235        }
236        *sender = None;
237
238        *self.connected.write() = false;
239        self.message_queue.write().clear();
240        self.sessions.write().clear();
241        tracing::info!("WebSocket channel closed");
242        Ok(())
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_channel_creation() {
        let channel = WebSocketChannel::default_channel();
        assert_eq!(channel.id(), "ws-default");
        // A freshly created channel starts out disconnected.
        assert!(!channel.is_connected());
    }

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketChannelConfig::default();
        assert_eq!(config.reconnect_attempts, 3);
        assert_eq!(config.ping_interval_ms, 30000);
        assert_eq!(config.connect_timeout_ms, 10000);
    }

    #[test]
    fn test_websocket_session_management() {
        let channel = WebSocketChannel::default_channel();
        channel.register_session("session-1", "user-1");

        assert_eq!(channel.active_sessions(), 1);

        channel.unregister_session("session-1");
        assert_eq!(channel.active_sessions(), 0);
    }

    #[test]
    fn test_websocket_receive_message() {
        let channel = WebSocketChannel::default_channel();
        // Flip the connection flag by hand so the receive path can be exercised.
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "Hello");

        let count = channel.message_queue.read().len();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_receive_message_from_unregistered_session_is_queued() {
        let channel = WebSocketChannel::default_channel();
        channel.receive_message("unknown-session", "Hello");
        assert_eq!(channel.message_queue.read().len(), 1);
    }

    #[tokio::test]
    async fn test_try_receive_drains_queue() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "first");
        channel.receive_message("session-1", "second");

        assert!(channel.try_receive().await.unwrap().is_some());
        assert_eq!(channel.message_queue.read().len(), 1);
        assert!(channel.try_receive().await.unwrap().is_some());
        assert!(channel.try_receive().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_try_receive_without_connection() {
        let channel = WebSocketChannel::default_channel();
        channel.receive_message("session-1", "Hello");
        // Queued messages are not handed out while disconnected.
        assert!(channel.try_receive().await.is_err());
    }

    #[tokio::test]
    async fn test_websocket_channel_close() {
        let channel = WebSocketChannel::default_channel();
        // Mark the channel connected by hand; no real socket is involved.
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.close().await.unwrap();

        assert!(!channel.is_connected());
        assert_eq!(channel.active_sessions(), 0);
    }

    #[tokio::test]
    async fn test_close_is_idempotent_and_clears_queue() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.receive_message("session-1", "pending");

        channel.close().await.unwrap();
        channel.close().await.unwrap();

        assert!(!channel.is_connected());
        assert_eq!(channel.message_queue.read().len(), 0);
    }

    #[tokio::test]
    async fn test_send_without_connection() {
        let channel = WebSocketChannel::default_channel();
        // Sending while disconnected must fail.
        let msg = OutboundMessage::to_user("test-user", "hello");
        let result = channel.send(&msg).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_text_without_connection() {
        let channel = WebSocketChannel::default_channel();
        assert!(channel.send_text("hello").await.is_err());
    }
}